From 93cb2538351184aba5f93245083cbce5a308fb1a Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Tue, 31 Mar 2026 12:00:23 -0700 Subject: [PATCH 01/60] Updated Cerner and FHIR connection --- playground/scenario_post_visit.py | 29 ++--- src/agents/fhir_cerner_mcp.py | 8 +- src/agents/fhir_epic_mcp.py | 8 +- src/agents/mcp_entrypoint.py | 52 ++++----- src/bindings/factory.py | 18 +-- src/bindings/mcp_server/server.py | 6 +- src/bindings/rest_api/app.py | 10 +- src/connectors/fhir_cerner/logic.py | 158 +++++++++------------------ src/connectors/fhir_cerner/schema.py | 114 ++++++++++--------- src/connectors/fhir_epic/logic.py | 114 ++++--------------- src/connectors/fhir_epic/schema.py | 123 ++++++++++----------- src/connectors/manifest.py | 47 ++++++-- tests/test_fhir_cerner.py | 137 ++++++++++++++--------- tests/test_fhir_epic.py | 102 +++++++++-------- tests/test_toolhive_agent.py | 10 +- 15 files changed, 440 insertions(+), 496 deletions(-) diff --git a/playground/scenario_post_visit.py b/playground/scenario_post_visit.py index 9581c66..f47b681 100644 --- a/playground/scenario_post_visit.py +++ b/playground/scenario_post_visit.py @@ -61,10 +61,9 @@ async def run_scenario(): logger.info(f"Searching for patient: {patient_search_params}") try: - patient_action = connector.get_action("read_patient") - patient_result = await patient_action.internal_execute( - FhirPatientReadInput(search_params=patient_search_params), - trace_id=trace_id + patient_result = await connector.internal_execute( + FhirPatientReadInput(action="read_patient", search_params=patient_search_params), + trace_id=trace_id, ) patient_id = patient_result.resource.get("id") logger.info(f"Found Patient ID: {patient_id}") @@ -82,17 +81,19 @@ async def run_scenario(): logger.info(f"Finding encounter for patient {patient_id} on {today}") try: - encounter_action = connector.get_action("search_encounter") - enc_result = await encounter_action.internal_execute( - FhirEncounterSearchInput(search_params=encounter_params), - trace_id=trace_id + enc_result = await connector.internal_execute( + FhirEncounterSearchInput(action="search_encounter", search_params=encounter_params), + trace_id=trace_id, ) - + if not enc_result.resources: logger.warning("No encounters found for this patient today. Falling back to most recent.") - enc_result = await encounter_action.internal_execute( - FhirEncounterSearchInput(search_params={"patient": patient_id, "status": "finished"}), - trace_id=trace_id + enc_result = await connector.internal_execute( + FhirEncounterSearchInput( + action="search_encounter", + search_params={"patient": patient_id, "status": "finished"}, + ), + trace_id=trace_id, ) if not enc_result.resources: @@ -110,6 +111,7 @@ async def run_scenario(): encoded_note = base64.b64encode(note_content.encode('utf-8')).decode('utf-8') doc_input = FhirDocumentReferenceCreateInput( + action="create_document_reference", identifier=[{"system": "urn:oid:1.2.3", "value": f"NOTE-{datetime.now().timestamp()}"}], status="current", type={"coding": [{"system": "http://loinc.org", "code": "11506-3", "display": "Progress Note"}]}, @@ -125,8 +127,7 @@ async def run_scenario(): logger.info(f"Uploading clinical note for Encounter {encounter_id}") try: - doc_action = connector.get_action("create_document_reference") - doc_result = await doc_action.internal_execute(doc_input, trace_id=trace_id) + doc_result = await connector.internal_execute(doc_input, trace_id=trace_id) logger.info(f"SUCCESS! Created DocumentReference: {doc_result.resource_id}") print(f"\nWorkflow Complete. Resource Created: {doc_result.resource_id}") except Exception as e: diff --git a/src/agents/fhir_cerner_mcp.py b/src/agents/fhir_cerner_mcp.py index 5628bd6..d329170 100644 --- a/src/agents/fhir_cerner_mcp.py +++ b/src/agents/fhir_cerner_mcp.py @@ -55,10 +55,8 @@ async def fhir_cerner_read_patient( if not cerner: raise RuntimeError("fhir_cerner connector not configured") - action = cerner.get_action("read_patient") - if patient_id: - params = FhirCernerPatientReadInput(resource_id=patient_id) + params = FhirCernerPatientReadInput(action="read_patient", resource_id=patient_id) elif family_name or given_name: search = { k: v @@ -69,11 +67,11 @@ async def fhir_cerner_read_patient( }.items() if v } - params = FhirCernerPatientReadInput(search_params=search) + params = FhirCernerPatientReadInput(action="read_patient", search_params=search) else: raise ValueError("Provide patient_id OR at least family_name/given_name") - result = await action.internal_execute(params, trace_id=trace_id) + result = await cerner.internal_execute(params, trace_id=trace_id) resource = result.resource name_parts = resource.get("name", [{}])[0] diff --git a/src/agents/fhir_epic_mcp.py b/src/agents/fhir_epic_mcp.py index d7f6335..b196b7a 100644 --- a/src/agents/fhir_epic_mcp.py +++ b/src/agents/fhir_epic_mcp.py @@ -56,10 +56,8 @@ async def fhir_epic_read_patient( if not epic: raise RuntimeError("fhir_epic connector not configured") - action = epic.get_action("read_patient") - if patient_id: - params = FhirEpicPatientReadInput(resource_id=patient_id) + params = FhirEpicPatientReadInput(action="read_patient", resource_id=patient_id) elif family_name or given_name: search = { k: v @@ -70,11 +68,11 @@ async def fhir_epic_read_patient( }.items() if v } - params = FhirEpicPatientReadInput(search_params=search) + params = FhirEpicPatientReadInput(action="read_patient", search_params=search) else: raise ValueError("Provide patient_id OR at least family_name/given_name") - result = await action.internal_execute(params, trace_id=trace_id) + result = await epic.internal_execute(params, trace_id=trace_id) resource = result.resource name_parts = resource.get("name", [{}])[0] diff --git a/src/agents/mcp_entrypoint.py b/src/agents/mcp_entrypoint.py index dee264e..9d974eb 100644 --- a/src/agents/mcp_entrypoint.py +++ b/src/agents/mcp_entrypoint.py @@ -4,12 +4,14 @@ This module is the main entrypoint for the Node Wire MCP server. When run, it exposes healthcare workflow tools via the MCP stdio transport: - • fhir_cerner_read_patient — fetch a single patient from Cerner FHIR R4 - • fhir_cerner_search_patients — fetch multiple patients from Cerner (multi-ID or name) - • fhir_epic_read_patient — fetch a single patient from Epic FHIR R4 - • fhir_epic_search_patients — fetch multiple patients from Epic (multi-ID or name) - • google_drive_upload_file — write a file to Google Drive - • smtp_send_email — send an email via SMTP + • fhir_cerner_read_patient — fetch a patient from Cerner FHIR R4 + • fhir_cerner_search_patients — search multiple patients in Cerner + • fhir_cerner_search_encounters — search encounters in Cerner + • fhir_epic_read_patient — fetch a patient from Epic FHIR R4 + • fhir_epic_search_patients — search multiple patients in Epic + • fhir_epic_search_encounters — search encounters in Epic + • google_drive_upload_file — write a file to Google Drive + • smtp_send_email — send an email via SMTP ToolHive manages the container lifecycle, injects secrets as environment variables, and proxies the stdio MCP stream to HTTP/SSE for clients. @@ -107,12 +109,11 @@ async def fhir_cerner_read_patient( if not cerner: raise RuntimeError("fhir_cerner connector not configured") - action = cerner.get_action("read_patient") - if patient_id: - params = FhirCernerPatientReadInput(resource_id=patient_id) + params = FhirCernerPatientReadInput(action="read_patient", resource_id=patient_id) elif family_name or given_name or name: params = FhirCernerPatientReadInput( + action="read_patient", given_name=given_name or None, family_name=family_name or None, name=name or None, @@ -121,7 +122,7 @@ async def fhir_cerner_read_patient( else: raise ValueError("Provide patient_id OR at least family_name / given_name / name") - result = await action.internal_execute(params, trace_id=trace_id) + result = await cerner.internal_execute(params, trace_id=trace_id) resource = result.resource # Extract a clean summary for the LLM @@ -184,12 +185,11 @@ async def fhir_epic_read_patient( if not epic: raise RuntimeError("fhir_epic connector not configured") - action = epic.get_action("read_patient") - if patient_id: - params = FhirEpicPatientReadInput(resource_id=patient_id) + params = FhirEpicPatientReadInput(action="read_patient", resource_id=patient_id) elif family_name or given_name or name: params = FhirEpicPatientReadInput( + action="read_patient", given_name=given_name or None, family_name=family_name or None, name=name or None, @@ -198,7 +198,7 @@ async def fhir_epic_read_patient( else: raise ValueError("Provide patient_id OR at least family_name / given_name / name") - result = await action.internal_execute(params, trace_id=trace_id) + result = await epic.internal_execute(params, trace_id=trace_id) resource = result.resource # Clean extract for LLM @@ -258,13 +258,12 @@ async def fhir_cerner_search_patients( if not cerner: raise RuntimeError("fhir_cerner connector not configured") - action = cerner.get_action("search_patients") - if patient_ids.strip(): ids = [i.strip() for i in patient_ids.split(",") if i.strip()] - params = FhirCernerPatientSearchInput(resource_ids=ids) + params = FhirCernerPatientSearchInput(action="search_patients", resource_ids=ids) elif family_name or given_name or name or birthdate: params = FhirCernerPatientSearchInput( + action="search_patients", given_name=given_name or None, family_name=family_name or None, name=name or None, @@ -276,7 +275,7 @@ async def fhir_cerner_search_patients( "family_name / given_name / name / birthdate" ) - result = await action.internal_execute(params, trace_id=trace_id) + result = await cerner.internal_execute(params, trace_id=trace_id) summaries = [] for resource in result.resources: @@ -337,13 +336,12 @@ async def fhir_epic_search_patients( if not epic: raise RuntimeError("fhir_epic connector not configured") - action = epic.get_action("search_patients") - if patient_ids.strip(): ids = [i.strip() for i in patient_ids.split(",") if i.strip()] - params = FhirEpicPatientSearchInput(resource_ids=ids) + params = FhirEpicPatientSearchInput(action="search_patients", resource_ids=ids) elif family_name or given_name or name or birthdate: params = FhirEpicPatientSearchInput( + action="search_patients", given_name=given_name or None, family_name=family_name or None, name=name or None, @@ -355,7 +353,7 @@ async def fhir_epic_search_patients( "family_name / given_name / name / birthdate" ) - result = await action.internal_execute(params, trace_id=trace_id) + result = await epic.internal_execute(params, trace_id=trace_id) summaries = [] for resource in result.resources: @@ -408,15 +406,14 @@ async def fhir_cerner_search_encounters( if not cerner: raise RuntimeError("fhir_cerner connector not configured") - action = cerner.get_action("search_encounter") - params = FhirCernerEncounterSearchInput( + action="search_encounter", patient_id=patient_id or None, status=status or None, date=date or None, ) - result = await action.internal_execute(params, trace_id=trace_id) + result = await cerner.internal_execute(params, trace_id=trace_id) summaries = [] for resource in result.resources: @@ -466,15 +463,14 @@ async def fhir_epic_search_encounters( if not epic: raise RuntimeError("fhir_epic connector not configured") - action = epic.get_action("search_encounter") - params = FhirEpicEncounterSearchInput( + action="search_encounter", patient_id=patient_id or None, status=status or None, date=date or None, ) - result = await action.internal_execute(params, trace_id=trace_id) + result = await epic.internal_execute(params, trace_id=trace_id) summaries = [] for resource in result.resources: diff --git a/src/bindings/factory.py b/src/bindings/factory.py index 8a28256..76e4df8 100644 --- a/src/bindings/factory.py +++ b/src/bindings/factory.py @@ -133,7 +133,9 @@ def _instantiate(self, connector_id: str) -> Any: raise ValueError(f"Unknown connector id {connector_id!r}") - def get_for_protocol(self, connector_id: str, protocol: str, action: Optional[str] = None) -> Optional[BaseConnector[Any, Any]]: + def get_for_protocol( + self, connector_id: str, protocol: str, action: Optional[str] = None + ) -> Optional[BaseConnector[Any, Any]]: cfg = self._configs.get(connector_id) if cfg is None: logger.warning( @@ -160,9 +162,11 @@ def get_for_protocol(self, connector_id: str, protocol: str, action: Optional[st if connector is None: return None - # Multi-action connectors (e.g. fhir_epic) expose a get_action() helper. - if action and hasattr(connector, "get_action"): - return connector.get_action(action) + if action: + logger.debug( + "get_for_protocol resolved connector (action from URL is merged into payload by REST)", + extra={"connector_id": connector_id, "protocol": protocol, "action": action}, + ) return connector # type: ignore[return-value] @@ -170,9 +174,5 @@ def list_for_protocol(self, protocol: str) -> List[BaseConnector[Any, Any]]: result: List[BaseConnector[Any, Any]] = [] for connector_id, connector in self._connectors.items(): if protocol in self._configs[connector_id].exposed_via: - # Multi-action connectors expose all their actions via list_actions(). - if hasattr(connector, "list_actions"): - result.extend(connector.list_actions()) - else: - result.append(connector) # type: ignore[arg-type] + result.append(connector) # type: ignore[arg-type] return result diff --git a/src/bindings/mcp_server/server.py b/src/bindings/mcp_server/server.py index ce98707..6b9f0a2 100644 --- a/src/bindings/mcp_server/server.py +++ b/src/bindings/mcp_server/server.py @@ -50,7 +50,11 @@ async def invoke_tool(self, name: str, arguments: Dict[str, Any]) -> Dict[str, A if connector is None: raise ValueError(f"Connector {connector_id!r} is not available via MCP.") - response = await connector.run(arguments) + run_args = dict(arguments) + if connector_id in ("fhir_cerner", "fhir_epic"): + run_args.setdefault("action", action) + + response = await connector.run(run_args) return response.model_dump() diff --git a/src/bindings/rest_api/app.py b/src/bindings/rest_api/app.py index 7d27dcc..f442734 100644 --- a/src/bindings/rest_api/app.py +++ b/src/bindings/rest_api/app.py @@ -73,7 +73,10 @@ def _http_status_for_category(category: ErrorCategory | None) -> int: return 503 return 500 -def _make_endpoint(cid: str, act: str) -> Any: +_FHIR_REST_IDS = frozenset({"fhir_cerner", "fhir_epic"}) + + +def _make_endpoint(cid: str, act: str) -> Any: async def endpoint( payload: Dict[str, Any], factory_dep: ConnectorFactory = Depends(get_factory), @@ -89,9 +92,12 @@ async def endpoint( connector = factory_dep.get_for_protocol(cid, "rest", action=act) if connector is None: raise HTTPException(status_code=404, detail="Connector not available for REST") + run_payload = dict(payload) + if cid in _FHIR_REST_IDS: + run_payload.setdefault("action", act) # Let the runtime (Layer A) perform full schema validation. # Any validation errors will be mapped into ConnectorResponse. - response: ConnectorResponse = await connector.run(payload) + response: ConnectorResponse = await connector.run(run_payload) status = _http_status_for_category(response.error_category) if not response.success: diff --git a/src/connectors/fhir_cerner/logic.py b/src/connectors/fhir_cerner/logic.py index 03cc6b0..c05281c 100644 --- a/src/connectors/fhir_cerner/logic.py +++ b/src/connectors/fhir_cerner/logic.py @@ -6,7 +6,7 @@ import logging import uuid from datetime import datetime, timedelta, timezone -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional import httpx import jwt @@ -21,6 +21,8 @@ FhirCernerDocumentReferenceSearchOutput, FhirCernerEncounterSearchInput, FhirCernerEncounterSearchOutput, + FhirCernerOperationInput, + FhirCernerOperationOutput, FhirCernerPatientReadInput, FhirCernerPatientReadOutput, FhirCernerPatientSearchInput, @@ -30,60 +32,14 @@ logger = logging.getLogger("connectors.fhir_cerner") -class _FhirCernerAction(BaseConnector[Any, Any]): - """ - Lightweight BaseConnector that delegates execution to a FhirCernerConnector - instance method. One of these is created per action so that the manifest - and REST router can discover each action's schema and route automatically. - """ - - connector_id = "fhir_cerner" - - def __init__( - self, - action: str, - input_model: type, - output_model: type, - handler: Callable, - *, - secret_provider: Optional[SecretProvider] = None, - ) -> None: - super().__init__(input_model, output_model, secret_provider=secret_provider) - self.action = action - self._handler = handler - - async def internal_execute(self, params: Any, *, trace_id: str) -> Any: - return await self._handler(params, trace_id=trace_id) - - -class FhirCernerConnector: +class FhirCernerConnector(BaseConnector[FhirCernerOperationInput, FhirCernerOperationOutput]): """ Single FHIR/Cerner connector. - ``connector_id = "fhir_cerner"``. All authentication helpers and action - implementations live here. The factory registers ONE instance of this - class; ``list_actions()`` and ``get_action()`` are used by the factory to - expose each action to the manifest and REST router. - Authentication uses Cerner's SMART Backend Services (private_key_jwt) flow, identical to Epic's implementation — RS384-signed JWT exchanged for an OAuth2 access token at the configured token endpoint. - Supported actions: - • read_patient — fetch a single Patient by ID or name search - • search_patients — fetch multiple Patients by list of IDs or name search - • search_encounter - • create_document_reference - • search_document_reference - - Name-based search parameters (``given_name``, ``family_name``, ``name``, - ``birthdate``) are prioritised over the raw ``search_params`` dict. - - .. note:: - Cerner's sandbox name search is case-sensitive. Supply names exactly - as stored in the system. Special characters in search values should be - URL-encoded (httpx handles this automatically). - Required secrets (configured via SecretProvider): - cerner_fhir_base_url : Cerner FHIR R4 base URL - cerner_private_key : RSA private key PEM (newlines may be escaped) @@ -94,44 +50,26 @@ class FhirCernerConnector: """ connector_id = "fhir_cerner" + action = "execute" def __init__(self, *, secret_provider: SecretProvider) -> None: + super().__init__(FhirCernerOperationInput, FhirCernerOperationOutput, secret_provider=secret_provider) self._secret_provider = secret_provider - self._actions: Dict[str, _FhirCernerAction] = { - "read_patient": _FhirCernerAction( - "read_patient", FhirCernerPatientReadInput, FhirCernerPatientReadOutput, - self._read_patient, secret_provider=secret_provider, - ), - "search_patients": _FhirCernerAction( - "search_patients", FhirCernerPatientSearchInput, FhirCernerPatientSearchOutput, - self._search_patients, secret_provider=secret_provider, - ), - "search_encounter": _FhirCernerAction( - "search_encounter", FhirCernerEncounterSearchInput, FhirCernerEncounterSearchOutput, - self._search_encounter, secret_provider=secret_provider, - ), - "create_document_reference": _FhirCernerAction( - "create_document_reference", FhirCernerDocumentReferenceCreateInput, FhirCernerDocumentReferenceCreateOutput, - self._create_document_reference, secret_provider=secret_provider, - ), - "search_document_reference": _FhirCernerAction( - "search_document_reference", FhirCernerDocumentReferenceSearchInput, FhirCernerDocumentReferenceSearchOutput, - self._search_document_reference, secret_provider=secret_provider, - ), - } - - # ------------------------------------------------------------------ - # Action discovery — consumed by ConnectorFactory - # ------------------------------------------------------------------ - - def list_actions(self) -> List[_FhirCernerAction]: - """Return all registered action connectors (used by list_for_protocol).""" - return list(self._actions.values()) - - def get_action(self, name: str) -> Optional[_FhirCernerAction]: - """Return the action connector for the given action name.""" - return self._actions.get(name) + async def internal_execute(self, params: Any, *, trace_id: str) -> Any: + # Back-compat: allow calling with either the RootModel union or a concrete action input model. + op = params.root if hasattr(params, "root") else params + if op.action == "read_patient": + return await self._read_patient(op, trace_id=trace_id) + if op.action == "search_patients": + return await self._search_patients(op, trace_id=trace_id) + if op.action == "search_encounter": + return await self._search_encounter(op, trace_id=trace_id) + if op.action == "create_document_reference": + return await self._create_document_reference(op, trace_id=trace_id) + if op.action == "search_document_reference": + return await self._search_document_reference(op, trace_id=trace_id) + raise ValueError(f"Unsupported action: {op.action!r}") # ------------------------------------------------------------------ # Shared authentication helpers @@ -251,15 +189,7 @@ def _build_name_search_params( birthdate: Optional[str], extra: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: - """Build a FHIR search params dict from explicit name/date fields. - - Priority: given_name/family_name > name > (nothing). - The ``extra`` dict (raw search_params) is merged at lowest priority. - - .. note:: - Cerner's sandbox name matching is case-sensitive — supply names - with the same capitalisation as stored in the system. - """ + """Build a FHIR search params dict from explicit name/date fields.""" params: Dict[str, str] = dict(extra or {}) if given_name and given_name.strip(): @@ -443,7 +373,7 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ raise data = response.json() - resources = [] + resources: List[Dict[str, Any]] = [] total = data.get("total") if data.get("resourceType") == "Bundle" and data.get("entry"): resources = [e["resource"] for e in data["entry"] if "resource" in e] @@ -508,19 +438,13 @@ async def _create_document_reference( base_url = self._get_base_url() auth_header = await self._get_auth_header() - # Validate context early so callers get the most actionable error. - if params.context: - ctx = dict(params.context) - if ctx.get("encounter") and not ctx.get("period"): - raise ValueError("Cerner requires 'context.period' when 'context.encounter' is provided.") - # Cerner sandbox strictly requires a charset (lowercase, no space) for text types. # Failing to provide it results in: "a character set must be specified" (422). content_type = (params.content_type or "text/plain").strip().lower() if content_type.startswith("text/"): + content_type = content_type.replace(" ", "") if "charset=" not in content_type: - # Match the formatting expected by tests and common HTTP conventions. - content_type = f"{content_type}; charset=UTF-8" + content_type = f"{content_type};charset=utf-8" attachment: Dict[str, Any] = {"contentType": content_type} if params.data: @@ -530,8 +454,12 @@ async def _create_document_reference( else: raise ValueError("Either 'text' or 'data' must be provided") - # Some Cerner tenants require title/creation; default safely when omitted. - attachment["title"] = params.attachment_title or "Document" + # Cerner requires title and creation on the attachment + if not params.attachment_title: + raise ValueError( + "Cerner requires 'attachment_title' on DocumentReference create." + ) + attachment["title"] = params.attachment_title attachment["creation"] = params.attachment_creation or datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z") doc_ref: Dict[str, Any] = { @@ -593,7 +521,22 @@ async def _create_document_reference( # Note: 'description' is intentionally omitted by default # as Cerner can reject it depending on tenant configuration. if params.context: - doc_ref["context"] = dict(params.context) + context = dict(params.context) + # Cerner REQUIRES context.period whenever context.encounter is set. + # Auto-inject a period using the document date if the caller didn't supply one. + if context.get("encounter") and not context.get("period"): + # Force .000Z precision and provide a 1-hour clinical window + start_dt = datetime.now(tz=timezone.utc) + end_dt = start_dt + timedelta(hours=1) + context["period"] = { + "start": start_dt.strftime("%Y-%m-%dT%H:%M:%S.000Z"), + "end": end_dt.strftime("%Y-%m-%dT%H:%M:%S.000Z"), + } + logger.debug( + "Auto-injected context.period (required by Cerner when encounter is set)", + extra={"trace_id": trace_id}, + ) + doc_ref["context"] = context if params.additional_fields: doc_ref.update(params.additional_fields) @@ -602,9 +545,12 @@ async def _create_document_reference( for field in ["text", "data", "content_type", "attachment_title", "attachment_creation", "doc_status"]: doc_ref.pop(field, None) - # Note: Some Cerner tenants require author/authenticator. The connector does not - # enforce those fields universally; tenants that require them will return 4xx - # with OperationOutcome diagnostics. + # Cerner requires at least one author for clinical note document types. + if not params.author: + raise ValueError( + "Cerner requires 'author' for clinical note document types. " + "Provide at least one author reference, e.g. [{'reference': 'Practitioner/{id}'}]" + ) logger.info("FHIR DocumentReference create", extra={"trace_id": trace_id}) diff --git a/src/connectors/fhir_cerner/schema.py b/src/connectors/fhir_cerner/schema.py index eba29c1..e24d915 100644 --- a/src/connectors/fhir_cerner/schema.py +++ b/src/connectors/fhir_cerner/schema.py @@ -1,16 +1,19 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Annotated, Any, Dict, List, Literal, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field, RootModel # --------------------------------------------------------------------------- -# Patient – Read (single patient by ID or name search) +# Patient – Read # --------------------------------------------------------------------------- class FhirCernerPatientReadInput(BaseModel): - """Input for reading a single FHIR Patient resource from Cerner.""" + """Input for reading a FHIR Patient resource from Cerner.""" + + action: Literal["read_patient"] = "read_patient" + """Action discriminator (one endpoint, multiple actions pattern).""" resource_id: Optional[str] = None """Direct Patient ID lookup (e.g. '12345678').""" @@ -23,21 +26,13 @@ class FhirCernerPatientReadInput(BaseModel): """Patient family / last name (used in name-based search).""" name: Optional[str] = None - """Full or partial name string — mapped to FHIR 'name' search parameter. - - Use this when you only have a single combined name string. When both - ``name`` and ``given_name``/``family_name`` are set, the explicit given/ - family fields take precedence. - """ + """Full or partial name string — mapped to FHIR 'name' search parameter.""" birthdate: Optional[str] = None """Date of birth in YYYY-MM-DD format — used alongside name search.""" search_params: Optional[Dict[str, str]] = None - """Raw FHIR search parameters (e.g. {\"family\": \"Smith\", \"given\": \"John\"}). - - Lowest priority — used only when no ID or explicit name fields are set. - """ + """Raw FHIR search parameters (e.g. {"family": "Smith", "given": "John"}).""" class FhirCernerPatientReadOutput(BaseModel): @@ -52,65 +47,27 @@ class FhirCernerPatientReadOutput(BaseModel): # --------------------------------------------------------------------------- class FhirCernerPatientSearchInput(BaseModel): - """Input for searching / fetching multiple FHIR Patient resources from Cerner. - - Two modes are supported: - - 1. **Multi-ID lookup** — pass ``resource_ids`` (list of Patient IDs). - Each ID is fetched concurrently; partial failures are captured in - ``FhirCernerPatientSearchOutput.errors`` rather than raising globally. - - 2. **Name-based search** — pass ``given_name``, ``family_name``, ``name``, - and/or ``birthdate``. A single FHIR search request is issued and all - matching Bundle entries are returned. - - Only one mode should be used per request. If ``resource_ids`` is set it - takes priority over the name/search fields. + """Input for searching / fetching multiple FHIR Patient resources from Cerner.""" - .. note:: - Cerner's sandbox name search is case-sensitive. Use the exact - capitalisation stored in the system (e.g. ``family_name="Smith"`` not - ``"smith"``). The ``name`` parameter maps to the standard FHIR - ``name`` token which Cerner supports as a partial-match. - """ + action: Literal["search_patients"] = "search_patients" + """Action discriminator (one endpoint, multiple actions pattern).""" resource_ids: Optional[List[str]] = None - """List of Cerner Patient IDs to fetch concurrently (e.g. ['12345678', '87654321']).""" + """List of Cerner Patient IDs to fetch concurrently.""" given_name: Optional[str] = None - """Patient given / first name.""" - family_name: Optional[str] = None - """Patient family / last name.""" - name: Optional[str] = None - """Full or partial name string — mapped to FHIR 'name' search parameter.""" - birthdate: Optional[str] = None - """Date of birth in YYYY-MM-DD format.""" - search_params: Optional[Dict[str, str]] = None - """Additional raw FHIR search parameters merged with the name fields.""" class FhirCernerPatientSearchOutput(BaseModel): """Output for searching multiple FHIR Patient resources from Cerner.""" resources: List[Dict[str, Any]] - """List of successfully retrieved FHIR Patient JSON objects.""" - total: Optional[int] = None - """Total number of matches reported by the server Bundle (name-search mode).""" - - errors: List[Dict[str, Any]] = [] - """Per-ID errors encountered during multi-ID fan-out. - - Each entry has the shape:: - - {"resource_id": "", "error": ""} - - An empty list means all lookups succeeded. - """ + errors: List[Dict[str, Any]] = Field(default_factory=list) # --------------------------------------------------------------------------- @@ -120,6 +77,9 @@ class FhirCernerPatientSearchOutput(BaseModel): class FhirCernerEncounterSearchInput(BaseModel): """Input for searching FHIR Encounter resources in Cerner.""" + action: Literal["search_encounter"] = "search_encounter" + """Action discriminator (one endpoint, multiple actions pattern).""" + patient_id: Optional[str] = None """Cerner Patient ID to find encounters for (maps to 'patient' FHIR param).""" @@ -150,6 +110,9 @@ class FhirCernerEncounterSearchOutput(BaseModel): class FhirCernerDocumentReferenceCreateInput(BaseModel): """Input for creating a FHIR DocumentReference resource in Cerner.""" + action: Literal["create_document_reference"] = "create_document_reference" + """Action discriminator (one endpoint, multiple actions pattern).""" + identifier: Optional[list[Dict[str, Any]]] = None """Document identifier. @@ -295,8 +258,11 @@ class FhirCernerDocumentReferenceCreateOutput(BaseModel): class FhirCernerDocumentReferenceSearchInput(BaseModel): """Input for searching FHIR DocumentReference resources in Cerner.""" + action: Literal["search_document_reference"] = "search_document_reference" + """Action discriminator (one endpoint, multiple actions pattern).""" + search_params: Dict[str, str] - """Search parameters (e.g. {\"patient\": \"12345678\"}).""" + """Search parameters (e.g. {"patient": "12345678"}).""" class FhirCernerDocumentReferenceSearchOutput(BaseModel): @@ -307,3 +273,35 @@ class FhirCernerDocumentReferenceSearchOutput(BaseModel): total: Optional[int] = None """Total number of results reported by the Bundle.""" + + +# --------------------------------------------------------------------------- +# Unified operation input/output (one endpoint, multiple actions) +# --------------------------------------------------------------------------- + +_FhirCernerOperationUnion = Annotated[ + Union[ + FhirCernerPatientReadInput, + FhirCernerPatientSearchInput, + FhirCernerEncounterSearchInput, + FhirCernerDocumentReferenceCreateInput, + FhirCernerDocumentReferenceSearchInput, + ], + Field(discriminator="action"), +] + +FhirCernerOperationInput = RootModel[_FhirCernerOperationUnion] + + +class FhirCernerOperationOutput(BaseModel): + """ + Combined output shape for schema documentation/manifest generation. + + Individual handlers still return their specific output models. + """ + + resource: Optional[Dict[str, Any]] = None + resources: Optional[list[Dict[str, Any]]] = None + total: Optional[int] = None + resource_id: Optional[str] = None + errors: Optional[list[Dict[str, Any]]] = None diff --git a/src/connectors/fhir_epic/logic.py b/src/connectors/fhir_epic/logic.py index e9cc615..5cbe8c3 100644 --- a/src/connectors/fhir_epic/logic.py +++ b/src/connectors/fhir_epic/logic.py @@ -6,7 +6,7 @@ import logging import uuid from datetime import datetime, timezone -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional import httpx import jwt @@ -20,6 +20,8 @@ FhirDocumentReferenceSearchOutput, FhirEncounterSearchInput, FhirEncounterSearchOutput, + FhirEpicOperationInput, + FhirEpicOperationOutput, FhirPatientReadInput, FhirPatientReadOutput, FhirPatientSearchInput, @@ -29,92 +31,35 @@ logger = logging.getLogger("connectors.fhir_epic") -class _FhirAction(BaseConnector[Any, Any]): - """ - Lightweight BaseConnector that delegates execution to a FhirEpicConnector - instance method. One of these is created per action so that the manifest - and REST router can discover each action's schema and route automatically. - """ - - connector_id = "fhir_epic" - - def __init__( - self, - action: str, - input_model: type, - output_model: type, - handler: Callable, - *, - secret_provider: Optional[SecretProvider] = None, - ) -> None: - super().__init__(input_model, output_model, secret_provider=secret_provider) - self.action = action # instance attribute, overrides absent class-level action - self._handler = handler - - async def internal_execute(self, params: Any, *, trace_id: str) -> Any: - return await self._handler(params, trace_id=trace_id) - - -class FhirEpicConnector: +class FhirEpicConnector(BaseConnector[FhirEpicOperationInput, FhirEpicOperationOutput]): """ Single FHIR/Epic connector. - ``connector_id = "fhir_epic"``. All authentication helpers and action - implementations live here. The factory registers ONE instance of this - class; ``list_actions()`` and ``get_action()`` are used by the factory to - expose each action to the manifest and REST router. - - Supported actions: - • read_patient — fetch a single Patient by ID or name search - • search_patients — fetch multiple Patients by list of IDs or name search - • search_encounter - • create_document_reference - • search_document_reference - - Name-based search parameters (``given_name``, ``family_name``, ``name``, - ``birthdate``) are prioritised over the raw ``search_params`` dict and are - normalised (stripped, lowercased for ``name`` token search). + Exposes one endpoint (`execute`) and dispatches actions via the + `action` discriminator on the request payload. """ connector_id = "fhir_epic" + action = "execute" def __init__(self, *, secret_provider: SecretProvider) -> None: + super().__init__(FhirEpicOperationInput, FhirEpicOperationOutput, secret_provider=secret_provider) self._secret_provider = secret_provider - self._actions: Dict[str, _FhirAction] = { - "read_patient": _FhirAction( - "read_patient", FhirPatientReadInput, FhirPatientReadOutput, - self._read_patient, secret_provider=secret_provider, - ), - "search_patients": _FhirAction( - "search_patients", FhirPatientSearchInput, FhirPatientSearchOutput, - self._search_patients, secret_provider=secret_provider, - ), - "search_encounter": _FhirAction( - "search_encounter", FhirEncounterSearchInput, FhirEncounterSearchOutput, - self._search_encounter, secret_provider=secret_provider, - ), - "create_document_reference": _FhirAction( - "create_document_reference", FhirDocumentReferenceCreateInput, FhirDocumentReferenceCreateOutput, - self._create_document_reference, secret_provider=secret_provider, - ), - "search_document_reference": _FhirAction( - "search_document_reference", FhirDocumentReferenceSearchInput, FhirDocumentReferenceSearchOutput, - self._search_document_reference, secret_provider=secret_provider, - ), - } - - # ------------------------------------------------------------------ - # Action discovery — consumed by ConnectorFactory - # ------------------------------------------------------------------ - - def list_actions(self) -> List[_FhirAction]: - """Return all registered action connectors (used by list_for_protocol).""" - return list(self._actions.values()) - - def get_action(self, name: str) -> Optional[_FhirAction]: - """Return the action connector for the given action name.""" - return self._actions.get(name) + async def internal_execute(self, params: Any, *, trace_id: str) -> Any: + # Back-compat: allow calling with either the RootModel union or a concrete action input model. + op = params.root if hasattr(params, "root") else params + if op.action == "read_patient": + return await self._read_patient(op, trace_id=trace_id) + if op.action == "search_patients": + return await self._search_patients(op, trace_id=trace_id) + if op.action == "search_encounter": + return await self._search_encounter(op, trace_id=trace_id) + if op.action == "create_document_reference": + return await self._create_document_reference(op, trace_id=trace_id) + if op.action == "search_document_reference": + return await self._search_document_reference(op, trace_id=trace_id) + raise ValueError(f"Unsupported action: {op.action!r}") # ------------------------------------------------------------------ # Shared authentication helpers @@ -194,22 +139,14 @@ def _build_name_search_params( birthdate: Optional[str], extra: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: - """Build a FHIR search params dict from explicit name/date fields. - - Priority: given_name/family_name > name > (nothing). - The ``extra`` dict (raw search_params) is merged at lowest priority so - callers can pass additional filters without overriding name fields. - """ + """Build a FHIR search params dict from explicit name/date fields.""" params: Dict[str, str] = dict(extra or {}) - # Normalize: strip whitespace; FHIR name search is typically case-insensitive - # on compliant servers but we preserve original case per FHIR spec. if given_name and given_name.strip(): params["given"] = given_name.strip() if family_name and family_name.strip(): params["family"] = family_name.strip() if name and name.strip() and "given" not in params and "family" not in params: - # Only fall back to the combined 'name' token when no split fields given params["name"] = name.strip() if birthdate and birthdate.strip(): params["birthdate"] = birthdate.strip() @@ -296,7 +233,6 @@ async def _search_patients( base_url = self._get_base_url() auth_header = await self._get_auth_header() - # ---- Mode 1: Multi-ID fan-out ---- if params.resource_ids: ids = [rid.strip() for rid in params.resource_ids if rid.strip()] if not ids: @@ -309,7 +245,6 @@ async def _search_patients( ) async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[str]]: - """Return (rid, resource_or_None, error_or_None).""" try: async with httpx.AsyncClient() as client: resp = await client.get( @@ -344,7 +279,6 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ ) return FhirPatientSearchOutput(resources=resources, total=len(resources), errors=errors) - # ---- Mode 2: Name-based search (returns Bundle) ---- name_params = self._build_name_search_params( params.given_name, params.family_name, params.name, params.birthdate, params.search_params, @@ -386,7 +320,7 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ raise data = response.json() - resources = [] + resources: List[Dict[str, Any]] = [] total = data.get("total") if data.get("resourceType") == "Bundle" and data.get("entry"): resources = [e["resource"] for e in data["entry"] if "resource" in e] diff --git a/src/connectors/fhir_epic/schema.py b/src/connectors/fhir_epic/schema.py index 99aa9b5..eaeef26 100644 --- a/src/connectors/fhir_epic/schema.py +++ b/src/connectors/fhir_epic/schema.py @@ -1,43 +1,30 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Annotated, Any, Dict, List, Literal, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field, RootModel # --------------------------------------------------------------------------- -# Patient – Read (single patient by ID or name search) +# Patient – Read # --------------------------------------------------------------------------- class FhirPatientReadInput(BaseModel): - """Input for reading a single FHIR Patient resource from Epic.""" + """Input for reading a FHIR Patient resource.""" + + action: Literal["read_patient"] = "read_patient" + """Action discriminator (one endpoint, multiple actions pattern).""" resource_id: Optional[str] = None """Direct Patient ID lookup (e.g. 'eXYZ123').""" - # Convenience name fields — take priority over raw search_params when set. given_name: Optional[str] = None - """Patient given / first name (used in name-based search).""" - family_name: Optional[str] = None - """Patient family / last name (used in name-based search).""" - name: Optional[str] = None - """Full or partial name string — mapped to FHIR 'name' search parameter. - - Use this when you only have a single combined name string. When both - ``name`` and ``given_name``/``family_name`` are set, the explicit given/ - family fields take precedence. - """ - birthdate: Optional[str] = None - """Date of birth in YYYY-MM-DD format — used alongside name search.""" search_params: Optional[Dict[str, str]] = None - """Raw FHIR search parameters (e.g. {\"family\": \"Smith\", \"given\": \"John\"}). - - Lowest priority — used only when no ID or explicit name fields are set. - """ + """Search parameters (e.g. {"family": "Smith", "given": "John"}).""" class FhirPatientReadOutput(BaseModel): @@ -52,59 +39,25 @@ class FhirPatientReadOutput(BaseModel): # --------------------------------------------------------------------------- class FhirPatientSearchInput(BaseModel): - """Input for searching / fetching multiple FHIR Patient resources from Epic. - - Two modes are supported: + """Input for searching / fetching multiple FHIR Patient resources from Epic.""" - 1. **Multi-ID lookup** — pass ``resource_ids`` (list of Patient IDs). - Each ID is fetched concurrently; partial failures are captured in - ``FhirPatientSearchOutput.errors`` rather than raising globally. - - 2. **Name-based search** — pass ``given_name``, ``family_name``, ``name``, - and/or ``birthdate``. A single FHIR search request is issued and all - matching Bundle entries are returned. - - Only one mode should be used per request. If ``resource_ids`` is set it - takes priority over the name/search fields. - """ + action: Literal["search_patients"] = "search_patients" + """Action discriminator (one endpoint, multiple actions pattern).""" resource_ids: Optional[List[str]] = None - """List of Epic Patient IDs to fetch concurrently (e.g. ['eABC', 'eDEF']).""" - given_name: Optional[str] = None - """Patient given / first name.""" - family_name: Optional[str] = None - """Patient family / last name.""" - name: Optional[str] = None - """Full or partial name string — mapped to FHIR 'name' search parameter.""" - birthdate: Optional[str] = None - """Date of birth in YYYY-MM-DD format.""" - search_params: Optional[Dict[str, str]] = None - """Additional raw FHIR search parameters merged with the name fields.""" class FhirPatientSearchOutput(BaseModel): """Output for searching multiple FHIR Patient resources.""" resources: List[Dict[str, Any]] - """List of successfully retrieved FHIR Patient JSON objects.""" - total: Optional[int] = None - """Total number of matches reported by the server Bundle (name-search mode).""" - - errors: List[Dict[str, Any]] = [] - """Per-ID errors encountered during multi-ID fan-out. - - Each entry has the shape:: - - {"resource_id": "", "error": ""} - - An empty list means all lookups succeeded. - """ + errors: List[Dict[str, Any]] = Field(default_factory=list) # --------------------------------------------------------------------------- @@ -114,17 +67,13 @@ class FhirPatientSearchOutput(BaseModel): class FhirEncounterSearchInput(BaseModel): """Input for searching FHIR Encounter resources.""" - patient_id: Optional[str] = None - """FHIR Patient ID to find encounters for (maps to 'patient' FHIR param).""" + action: Literal["search_encounter"] = "search_encounter" + """Action discriminator (one endpoint, multiple actions pattern).""" + patient_id: Optional[str] = None status: Optional[str] = None - """Status of the encounters to find (e.g. 'finished', 'arrived').""" - date: Optional[str] = None - """Date or date range for the encounters (e.g. '2024', 'gt2023-01-01').""" - search_params: Optional[Dict[str, str]] = None - """Raw FHIR search parameters. Used if explicit fields above are not provided.""" class FhirEncounterSearchOutput(BaseModel): @@ -144,6 +93,9 @@ class FhirEncounterSearchOutput(BaseModel): class FhirDocumentReferenceCreateInput(BaseModel): """Input for creating a FHIR DocumentReference resource.""" + action: Literal["create_document_reference"] = "create_document_reference" + """Action discriminator (one endpoint, multiple actions pattern).""" + identifier: list[Dict[str, Any]] """Document identifier.""" @@ -213,8 +165,11 @@ class FhirDocumentReferenceCreateOutput(BaseModel): class FhirDocumentReferenceSearchInput(BaseModel): """Input for searching FHIR DocumentReference resources.""" + action: Literal["search_document_reference"] = "search_document_reference" + """Action discriminator (one endpoint, multiple actions pattern).""" + search_params: Dict[str, str] - """Search parameters (e.g. {\"patient\": \"eXYZ123\"}).""" + """Search parameters (e.g. {"patient": "eXYZ123"}).""" class FhirDocumentReferenceSearchOutput(BaseModel): @@ -224,4 +179,36 @@ class FhirDocumentReferenceSearchOutput(BaseModel): """The list of raw FHIR DocumentReference JSON objects found.""" total: Optional[int] = None - """Total number of results reported by the Bundle.""" \ No newline at end of file + """Total number of results reported by the Bundle.""" + + +# --------------------------------------------------------------------------- +# Unified operation input/output (one endpoint, multiple actions) +# --------------------------------------------------------------------------- + +_FhirEpicOperationUnion = Annotated[ + Union[ + FhirPatientReadInput, + FhirPatientSearchInput, + FhirEncounterSearchInput, + FhirDocumentReferenceCreateInput, + FhirDocumentReferenceSearchInput, + ], + Field(discriminator="action"), +] + +FhirEpicOperationInput = RootModel[_FhirEpicOperationUnion] + + +class FhirEpicOperationOutput(BaseModel): + """ + Combined output shape for schema documentation/manifest generation. + + Individual handlers still return their specific output models. + """ + + resource: Optional[Dict[str, Any]] = None + resources: Optional[list[Dict[str, Any]]] = None + total: Optional[int] = None + resource_id: Optional[str] = None + errors: Optional[list[Dict[str, Any]]] = None diff --git a/src/connectors/manifest.py b/src/connectors/manifest.py index a13f9f5..2033234 100644 --- a/src/connectors/manifest.py +++ b/src/connectors/manifest.py @@ -6,6 +6,25 @@ from runtime import BaseConnector +# FHIR connectors expose a single `execute` entrypoint with a discriminated `action` +# field; expand these for REST/MCP discovery so routes remain per-operation. +_FHIR_DISCRIMINATED_ACTIONS: Dict[str, List[str]] = { + "fhir_cerner": [ + "read_patient", + "search_patients", + "search_encounter", + "create_document_reference", + "search_document_reference", + ], + "fhir_epic": [ + "read_patient", + "search_patients", + "search_encounter", + "create_document_reference", + "search_document_reference", + ], +} + def _schema_for(model: Type[BaseModel]) -> Dict[str, Any]: return model.model_json_schema() @@ -23,13 +42,25 @@ def build_manifest(connectors: List[BaseConnector[Any, Any]]) -> List[Dict[str, for connector in connectors: input_model = connector._input_model_cls # type: ignore[attr-defined] output_model = connector._output_model_cls # type: ignore[attr-defined] - manifest.append( - { - "connector_id": connector.connector_id, - "action": connector.action, - "input_schema": _schema_for(input_model), - "output_schema": _schema_for(output_model), - } - ) + cid = connector.connector_id + if cid in _FHIR_DISCRIMINATED_ACTIONS and getattr(connector, "action", None) == "execute": + for sub_action in _FHIR_DISCRIMINATED_ACTIONS[cid]: + manifest.append( + { + "connector_id": cid, + "action": sub_action, + "input_schema": _schema_for(input_model), + "output_schema": _schema_for(output_model), + } + ) + else: + manifest.append( + { + "connector_id": cid, + "action": connector.action, + "input_schema": _schema_for(input_model), + "output_schema": _schema_for(output_model), + } + ) return manifest diff --git a/tests/test_fhir_cerner.py b/tests/test_fhir_cerner.py index a48eb72..903a927 100644 --- a/tests/test_fhir_cerner.py +++ b/tests/test_fhir_cerner.py @@ -36,18 +36,13 @@ def _connector() -> FhirCernerConnector: # --------------------------------------------------------------------------- -# Sanity: connector exposes all 5 actions +# Sanity: unified connector (single execute entrypoint) # --------------------------------------------------------------------------- -def test_fhir_cerner_connector_exposes_five_actions(): +def test_fhir_cerner_connector_is_unified_execute(): c = _connector() - actions = {a.action for a in c.list_actions()} - assert actions == { - "read_patient", "search_patients", - "search_encounter", "create_document_reference", "search_document_reference", - } - for name in actions: - assert c.get_action(name) is not None + assert c.connector_id == "fhir_cerner" + assert c.action == "execute" # --------------------------------------------------------------------------- @@ -56,9 +51,9 @@ def test_fhir_cerner_connector_exposes_five_actions(): @pytest.mark.asyncio async def test_fhir_cerner_read_patient_by_id(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput(resource_id="12345678") + params = FhirCernerPatientReadInput(action="read_patient", resource_id="12345678") patient_response = MagicMock() patient_response.status_code = 200 @@ -67,7 +62,7 @@ async def test_fhir_cerner_read_patient_by_id(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "12345678" assert result.resource["name"][0]["family"] == "Smith" @@ -79,9 +74,12 @@ async def test_fhir_cerner_read_patient_by_id(): @pytest.mark.asyncio async def test_fhir_cerner_read_patient_by_search(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput(search_params={"family": "Smith", "given": "John"}) + params = FhirCernerPatientReadInput( + action="read_patient", + search_params={"family": "Smith", "given": "John"}, + ) patient_response = MagicMock() patient_response.status_code = 200 @@ -93,7 +91,7 @@ async def test_fhir_cerner_read_patient_by_search(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "99887766" @@ -104,9 +102,14 @@ async def test_fhir_cerner_read_patient_by_search(): @pytest.mark.asyncio async def test_fhir_cerner_read_patient_by_explicit_name_fields(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput(given_name=" Jane ", family_name="Doe", birthdate="1990-06-15") + params = FhirCernerPatientReadInput( + action="read_patient", + given_name=" Jane ", + family_name="Doe", + birthdate="1990-06-15", + ) patient_response = MagicMock() patient_response.status_code = 200 @@ -118,7 +121,7 @@ async def test_fhir_cerner_read_patient_by_explicit_name_fields(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "55551234" call_kwargs = mock_get.call_args @@ -134,9 +137,9 @@ async def test_fhir_cerner_read_patient_by_explicit_name_fields(): @pytest.mark.asyncio async def test_fhir_cerner_read_patient_by_name_field(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput(name="Johnson") + params = FhirCernerPatientReadInput(action="read_patient", name="Johnson") patient_response = MagicMock() patient_response.status_code = 200 @@ -148,7 +151,7 @@ async def test_fhir_cerner_read_patient_by_name_field(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "99990001" call_kwargs = mock_get.call_args @@ -162,14 +165,14 @@ async def test_fhir_cerner_read_patient_by_name_field(): @pytest.mark.asyncio async def test_fhir_cerner_read_patient_no_params_raises(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput() - + params = FhirCernerPatientReadInput(action="read_patient") + with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): with pytest.raises(ValueError, match="Provide resource_id"): - await action.internal_execute(params, trace_id="test-trace") + await c.internal_execute(params, trace_id="test-trace") # --------------------------------------------------------------------------- @@ -178,9 +181,12 @@ async def test_fhir_cerner_read_patient_no_params_raises(): @pytest.mark.asyncio async def test_fhir_cerner_search_patients_multi_id(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientSearchInput - params = FhirCernerPatientSearchInput(resource_ids=["11111111", "22222222"]) + params = FhirCernerPatientSearchInput( + action="search_patients", + resource_ids=["11111111", "22222222"], + ) def _patient_resp(pid: str) -> MagicMock: m = MagicMock() @@ -193,7 +199,7 @@ def _patient_resp(pid: str) -> MagicMock: with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=responses): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") ids = {r["id"] for r in result.resources} assert ids == {"11111111", "22222222"} @@ -207,9 +213,12 @@ def _patient_resp(pid: str) -> MagicMock: @pytest.mark.asyncio async def test_fhir_cerner_search_patients_partial_failure(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientSearchInput - params = FhirCernerPatientSearchInput(resource_ids=["99999999", "00000000"]) + params = FhirCernerPatientSearchInput( + action="search_patients", + resource_ids=["99999999", "00000000"], + ) good_resp = MagicMock() good_resp.status_code = 200 @@ -222,7 +231,7 @@ async def test_fhir_cerner_search_patients_partial_failure(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=[good_resp, bad_resp]): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert len(result.resources) == 1 assert result.resources[0]["id"] == "99999999" @@ -236,9 +245,9 @@ async def test_fhir_cerner_search_patients_partial_failure(): @pytest.mark.asyncio async def test_fhir_cerner_search_patients_by_name(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientSearchInput - params = FhirCernerPatientSearchInput(family_name="Smith") + params = FhirCernerPatientSearchInput(action="search_patients", family_name="Smith") bundle_resp = MagicMock() bundle_resp.status_code = 200 @@ -254,7 +263,7 @@ async def test_fhir_cerner_search_patients_by_name(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=bundle_resp) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 2 assert len(result.resources) == 2 @@ -270,14 +279,14 @@ async def test_fhir_cerner_search_patients_by_name(): @pytest.mark.asyncio async def test_fhir_cerner_search_patients_no_params_raises(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientSearchInput - params = FhirCernerPatientSearchInput() - + params = FhirCernerPatientSearchInput(action="search_patients") + with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): with pytest.raises(ValueError): - await action.internal_execute(params, trace_id="test-trace") + await c.internal_execute(params, trace_id="test-trace") # --------------------------------------------------------------------------- @@ -286,9 +295,12 @@ async def test_fhir_cerner_search_patients_no_params_raises(): @pytest.mark.asyncio async def test_fhir_cerner_search_encounter(): - action = _connector().get_action("search_encounter") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerEncounterSearchInput - params = FhirCernerEncounterSearchInput(search_params={"patient": "12345678", "status": "finished"}) + params = FhirCernerEncounterSearchInput( + action="search_encounter", + search_params={"patient": "12345678", "status": "finished"}, + ) enc_response = MagicMock() enc_response.status_code = 200 @@ -303,7 +315,7 @@ async def test_fhir_cerner_search_encounter(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=enc_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 2 assert result.resources[0]["id"] == "enc-1" @@ -311,9 +323,9 @@ async def test_fhir_cerner_search_encounter(): @pytest.mark.asyncio async def test_fhir_cerner_search_encounter_by_patient(): - action = _connector().get_action("search_encounter") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerEncounterSearchInput - params = FhirCernerEncounterSearchInput(patient_id="12345678") + params = FhirCernerEncounterSearchInput(action="search_encounter", patient_id="12345678") enc_response = MagicMock() enc_response.status_code = 200 @@ -325,7 +337,7 @@ async def test_fhir_cerner_search_encounter_by_patient(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=enc_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 1 assert result.resources[0]["id"] == "enc-1" @@ -340,14 +352,26 @@ async def test_fhir_cerner_search_encounter_by_patient(): @pytest.mark.asyncio async def test_fhir_cerner_create_document_reference(): - action = _connector().get_action("create_document_reference") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerDocumentReferenceCreateInput params = FhirCernerDocumentReferenceCreateInput( + action="create_document_reference", identifier=[{"system": "urn:oid:1.2.3", "value": "ID.123"}], status="current", - type={"coding": [{"system": "urn:oid:4.5.6", "code": "18100", "display": "Employer Group Scan"}]}, + doc_status="final", + type={ + "coding": [{ + "system": "urn:oid:4.5.6", + "code": "18100", + "display": "Employer Group Scan", + "userSelected": True, + }], + "text": "Employer Group Scan", + }, subject="Patient/12724066", data="dGVzdA==", + attachment_title="Document", + author=[{"reference": "Practitioner/p1"}], context={ "encounter": [{"reference": "Encounter/enc-1"}], "period": {"start": "2024-01-01T00:00:00Z", "end": "2024-01-01T01:00:00Z"}, @@ -363,14 +387,14 @@ async def test_fhir_cerner_create_document_reference(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post: mock_post.side_effect = [_token_mock(), create_response] - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource_id == "doc-456" _, kwargs = mock_post.call_args_list[1] assert kwargs["json"]["resourceType"] == "DocumentReference" assert kwargs["json"]["subject"] == {"reference": "Patient/12724066"} # Verify that charset was added to contentType - assert kwargs["json"]["content"][0]["attachment"]["contentType"] == "text/plain; charset=UTF-8" + assert kwargs["json"]["content"][0]["attachment"]["contentType"] == "text/plain;charset=utf-8" # --------------------------------------------------------------------------- @@ -379,9 +403,12 @@ async def test_fhir_cerner_create_document_reference(): @pytest.mark.asyncio async def test_fhir_cerner_search_document_reference(): - action = _connector().get_action("search_document_reference") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerDocumentReferenceSearchInput - params = FhirCernerDocumentReferenceSearchInput(search_params={"patient": "12345678"}) + params = FhirCernerDocumentReferenceSearchInput( + action="search_document_reference", + search_params={"patient": "12345678"}, + ) search_response = MagicMock() search_response.status_code = 200 @@ -394,7 +421,7 @@ async def test_fhir_cerner_search_document_reference(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=search_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 1 assert result.resources[0]["id"] == "doc-789" @@ -407,18 +434,22 @@ async def test_fhir_cerner_search_document_reference(): @pytest.mark.asyncio async def test_fhir_cerner_create_document_reference_validation(): """Verify that ValueError is raised when period is missing but encounter is present.""" - action = _connector().get_action("create_document_reference") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerDocumentReferenceCreateInput params = FhirCernerDocumentReferenceCreateInput( + action="create_document_reference", identifier=[{"system": "urn:oid:1.2.3", "value": "ID.123"}], status="current", + doc_status="final", type={"coding": [{"system": "http://loinc.org", "code": "11488-4"}]}, subject="Patient/12724066", data="dGVzdA==", + attachment_title="Doc", + author=[{"reference": "Practitioner/p1"}], context={"encounter": [{"reference": "Encounter/enc-1"}]}, ) with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): with pytest.raises(ValueError, match="Cerner requires the proprietary CodeSet 72 system"): - await action.internal_execute(params, trace_id="test-trace") + await c.internal_execute(params, trace_id="test-trace") diff --git a/tests/test_fhir_epic.py b/tests/test_fhir_epic.py index b076da5..38eee40 100644 --- a/tests/test_fhir_epic.py +++ b/tests/test_fhir_epic.py @@ -36,18 +36,13 @@ def _connector() -> FhirEpicConnector: # --------------------------------------------------------------------------- -# Sanity: connector exposes all 5 actions +# Sanity: unified connector (single execute entrypoint) # --------------------------------------------------------------------------- -def test_fhir_epic_connector_exposes_five_actions(): +def test_fhir_epic_connector_is_unified_execute(): c = _connector() - actions = {a.action for a in c.list_actions()} - assert actions == { - "read_patient", "search_patients", - "search_encounter", "create_document_reference", "search_document_reference", - } - for name in actions: - assert c.get_action(name) is not None + assert c.connector_id == "fhir_epic" + assert c.action == "execute" # --------------------------------------------------------------------------- @@ -56,9 +51,9 @@ def test_fhir_epic_connector_exposes_five_actions(): @pytest.mark.asyncio async def test_fhir_epic_read_patient_by_id(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput(resource_id="eXYZ123") + params = FhirPatientReadInput(action="read_patient", resource_id="eXYZ123") patient_response = MagicMock() patient_response.status_code = 200 @@ -67,7 +62,7 @@ async def test_fhir_epic_read_patient_by_id(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "eXYZ123" assert result.resource["name"][0]["family"] == "Smith" @@ -79,9 +74,12 @@ async def test_fhir_epic_read_patient_by_id(): @pytest.mark.asyncio async def test_fhir_epic_read_patient_by_search(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput(search_params={"family": "Smith", "given": "John"}) + params = FhirPatientReadInput( + action="read_patient", + search_params={"family": "Smith", "given": "John"}, + ) patient_response = MagicMock() patient_response.status_code = 200 @@ -93,7 +91,7 @@ async def test_fhir_epic_read_patient_by_search(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "eABC" @@ -104,9 +102,14 @@ async def test_fhir_epic_read_patient_by_search(): @pytest.mark.asyncio async def test_fhir_epic_read_patient_by_explicit_name_fields(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput(given_name=" John ", family_name="Smith", birthdate="1980-01-01") + params = FhirPatientReadInput( + action="read_patient", + given_name=" John ", + family_name="Smith", + birthdate="1980-01-01", + ) patient_response = MagicMock() patient_response.status_code = 200 @@ -118,7 +121,7 @@ async def test_fhir_epic_read_patient_by_explicit_name_fields(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "eDEF" # Verify the correct FHIR params were built (stripped whitespace) @@ -135,9 +138,9 @@ async def test_fhir_epic_read_patient_by_explicit_name_fields(): @pytest.mark.asyncio async def test_fhir_epic_read_patient_by_name_field(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput(name="Johnson") + params = FhirPatientReadInput(action="read_patient", name="Johnson") patient_response = MagicMock() patient_response.status_code = 200 @@ -149,7 +152,7 @@ async def test_fhir_epic_read_patient_by_name_field(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "eGHI" call_kwargs = mock_get.call_args @@ -163,14 +166,14 @@ async def test_fhir_epic_read_patient_by_name_field(): @pytest.mark.asyncio async def test_fhir_epic_read_patient_no_params_raises(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput() # nothing provided - + params = FhirPatientReadInput(action="read_patient") + with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): with pytest.raises(ValueError, match="Provide resource_id"): - await action.internal_execute(params, trace_id="test-trace") + await c.internal_execute(params, trace_id="test-trace") # --------------------------------------------------------------------------- @@ -179,9 +182,9 @@ async def test_fhir_epic_read_patient_no_params_raises(): @pytest.mark.asyncio async def test_fhir_epic_search_patients_multi_id(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_epic.schema import FhirPatientSearchInput - params = FhirPatientSearchInput(resource_ids=["eABC", "eDEF"]) + params = FhirPatientSearchInput(action="search_patients", resource_ids=["eABC", "eDEF"]) def _patient_resp(pid: str) -> MagicMock: m = MagicMock() @@ -194,7 +197,7 @@ def _patient_resp(pid: str) -> MagicMock: with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=responses): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") ids = {r["id"] for r in result.resources} assert ids == {"eABC", "eDEF"} @@ -208,9 +211,9 @@ def _patient_resp(pid: str) -> MagicMock: @pytest.mark.asyncio async def test_fhir_epic_search_patients_partial_failure(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_epic.schema import FhirPatientSearchInput - params = FhirPatientSearchInput(resource_ids=["eGOOD", "eBAD"]) + params = FhirPatientSearchInput(action="search_patients", resource_ids=["eGOOD", "eBAD"]) good_resp = MagicMock() good_resp.status_code = 200 @@ -223,7 +226,7 @@ async def test_fhir_epic_search_patients_partial_failure(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=[good_resp, bad_resp]): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert len(result.resources) == 1 assert result.resources[0]["id"] == "eGOOD" @@ -237,9 +240,9 @@ async def test_fhir_epic_search_patients_partial_failure(): @pytest.mark.asyncio async def test_fhir_epic_search_patients_by_name(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_epic.schema import FhirPatientSearchInput - params = FhirPatientSearchInput(family_name="Smith") + params = FhirPatientSearchInput(action="search_patients", family_name="Smith") bundle_resp = MagicMock() bundle_resp.status_code = 200 @@ -255,7 +258,7 @@ async def test_fhir_epic_search_patients_by_name(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=bundle_resp) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 2 assert len(result.resources) == 2 @@ -272,14 +275,14 @@ async def test_fhir_epic_search_patients_by_name(): @pytest.mark.asyncio async def test_fhir_epic_search_patients_no_params_raises(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_epic.schema import FhirPatientSearchInput - params = FhirPatientSearchInput() - + params = FhirPatientSearchInput(action="search_patients") + with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): with pytest.raises(ValueError): - await action.internal_execute(params, trace_id="test-trace") + await c.internal_execute(params, trace_id="test-trace") # --------------------------------------------------------------------------- @@ -288,9 +291,12 @@ async def test_fhir_epic_search_patients_no_params_raises(): @pytest.mark.asyncio async def test_fhir_epic_search_encounter(): - action = _connector().get_action("search_encounter") + c = _connector() from connectors.fhir_epic.schema import FhirEncounterSearchInput - params = FhirEncounterSearchInput(search_params={"patient": "eXYZ123", "status": "finished"}) + params = FhirEncounterSearchInput( + action="search_encounter", + search_params={"patient": "eXYZ123", "status": "finished"}, + ) enc_response = MagicMock() enc_response.status_code = 200 @@ -305,7 +311,7 @@ async def test_fhir_epic_search_encounter(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=enc_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 2 assert result.resources[0]["id"] == "enc-1" @@ -317,9 +323,10 @@ async def test_fhir_epic_search_encounter(): @pytest.mark.asyncio async def test_fhir_epic_create_document_reference(): - action = _connector().get_action("create_document_reference") + c = _connector() from connectors.fhir_epic.schema import FhirDocumentReferenceCreateInput params = FhirDocumentReferenceCreateInput( + action="create_document_reference", identifier=[{"system": "urn:oid:1.2.3", "value": "ID.123"}], status="current", type={"coding": [{"system": "urn:oid:4.5.6", "code": "18100", "display": "Employer Group Scan"}]}, @@ -337,7 +344,7 @@ async def test_fhir_epic_create_document_reference(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post: mock_post.side_effect = [_token_mock(), create_response] - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource_id == "doc-456" _, kwargs = mock_post.call_args_list[1] @@ -351,9 +358,12 @@ async def test_fhir_epic_create_document_reference(): @pytest.mark.asyncio async def test_fhir_epic_search_document_reference(): - action = _connector().get_action("search_document_reference") + c = _connector() from connectors.fhir_epic.schema import FhirDocumentReferenceSearchInput - params = FhirDocumentReferenceSearchInput(search_params={"patient": "eXYZ123"}) + params = FhirDocumentReferenceSearchInput( + action="search_document_reference", + search_params={"patient": "eXYZ123"}, + ) search_response = MagicMock() search_response.status_code = 200 @@ -366,7 +376,7 @@ async def test_fhir_epic_search_document_reference(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=search_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 1 assert result.resources[0]["id"] == "doc-789" diff --git a/tests/test_toolhive_agent.py b/tests/test_toolhive_agent.py index 5f92366..50aa61a 100644 --- a/tests/test_toolhive_agent.py +++ b/tests/test_toolhive_agent.py @@ -248,8 +248,8 @@ async def test_agent_fails_when_mcp_unreachable() -> None: # MCP entrypoint smoke test # --------------------------------------------------------------------------- -def test_mcp_entrypoint_registers_four_tools() -> None: - """The FastMCP server should expose exactly 4 tools.""" +def test_mcp_entrypoint_registers_eight_tools() -> None: + """The FastMCP server should expose the full FHIR + integration tool surface.""" # We patch all external deps before importing the module to avoid side effects with ( patch("bindings.factory.ConnectorFactory") as mock_factory_cls, @@ -280,9 +280,13 @@ def fake_tool(*args: Any, **kwargs: Any): from agents.mcp_entrypoint import _make_server _make_server() - assert len(registered_tools) == 4 + assert len(registered_tools) == 8 assert "fhir_cerner_read_patient" in registered_tools + assert "fhir_cerner_search_patients" in registered_tools + assert "fhir_cerner_search_encounters" in registered_tools assert "fhir_epic_read_patient" in registered_tools + assert "fhir_epic_search_patients" in registered_tools + assert "fhir_epic_search_encounters" in registered_tools assert "google_drive_upload_file" in registered_tools assert "smtp_send_email" in registered_tools From 018ba6296ec4c5e4330aa489581949a18a5f169c Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Tue, 31 Mar 2026 12:19:59 -0700 Subject: [PATCH 02/60] cleanup and update mcp --- src/agents/fhir_cerner_mcp.py | 22 ++++++------ src/agents/fhir_epic_mcp.py | 22 ++++++------ src/agents/mcp_entrypoint.py | 8 ++++- src/connectors/manifest.py | 66 +++++++++++++++++++++++------------ tests/test_fhir_cerner.py | 2 +- 5 files changed, 71 insertions(+), 49 deletions(-) diff --git a/src/agents/fhir_cerner_mcp.py b/src/agents/fhir_cerner_mcp.py index d329170..fd2067c 100644 --- a/src/agents/fhir_cerner_mcp.py +++ b/src/agents/fhir_cerner_mcp.py @@ -48,6 +48,7 @@ async def fhir_cerner_read_patient( patient_id: str = "", family_name: str = "", given_name: str = "", + name: str = "", birthdate: str = "", ) -> dict: trace_id = str(uuid.uuid4()) @@ -57,19 +58,16 @@ async def fhir_cerner_read_patient( if patient_id: params = FhirCernerPatientReadInput(action="read_patient", resource_id=patient_id) - elif family_name or given_name: - search = { - k: v - for k, v in { - "family": family_name, - "given": given_name, - "birthdate": birthdate, - }.items() - if v - } - params = FhirCernerPatientReadInput(action="read_patient", search_params=search) + elif family_name or given_name or name: + params = FhirCernerPatientReadInput( + action="read_patient", + given_name=given_name or None, + family_name=family_name or None, + name=name or None, + birthdate=birthdate or None, + ) else: - raise ValueError("Provide patient_id OR at least family_name/given_name") + raise ValueError("Provide patient_id OR at least family_name / given_name / name") result = await cerner.internal_execute(params, trace_id=trace_id) resource = result.resource diff --git a/src/agents/fhir_epic_mcp.py b/src/agents/fhir_epic_mcp.py index b196b7a..5e6798e 100644 --- a/src/agents/fhir_epic_mcp.py +++ b/src/agents/fhir_epic_mcp.py @@ -49,6 +49,7 @@ async def fhir_epic_read_patient( patient_id: str = "", family_name: str = "", given_name: str = "", + name: str = "", birthdate: str = "", ) -> dict: trace_id = str(uuid.uuid4()) @@ -58,19 +59,16 @@ async def fhir_epic_read_patient( if patient_id: params = FhirEpicPatientReadInput(action="read_patient", resource_id=patient_id) - elif family_name or given_name: - search = { - k: v - for k, v in { - "family": family_name, - "given": given_name, - "birthdate": birthdate, - }.items() - if v - } - params = FhirEpicPatientReadInput(action="read_patient", search_params=search) + elif family_name or given_name or name: + params = FhirEpicPatientReadInput( + action="read_patient", + given_name=given_name or None, + family_name=family_name or None, + name=name or None, + birthdate=birthdate or None, + ) else: - raise ValueError("Provide patient_id OR at least family_name/given_name") + raise ValueError("Provide patient_id OR at least family_name / given_name / name") result = await epic.internal_execute(params, trace_id=trace_id) resource = result.resource diff --git a/src/agents/mcp_entrypoint.py b/src/agents/mcp_entrypoint.py index 9d974eb..ba9ac46 100644 --- a/src/agents/mcp_entrypoint.py +++ b/src/agents/mcp_entrypoint.py @@ -406,6 +406,9 @@ async def fhir_cerner_search_encounters( if not cerner: raise RuntimeError("fhir_cerner connector not configured") + if not (patient_id or status or date): + raise ValueError("Provide at least one of patient_id / status / date") + params = FhirCernerEncounterSearchInput( action="search_encounter", patient_id=patient_id or None, @@ -463,6 +466,9 @@ async def fhir_epic_search_encounters( if not epic: raise RuntimeError("fhir_epic connector not configured") + if not (patient_id or status or date): + raise ValueError("Provide at least one of patient_id / status / date") + params = FhirEpicEncounterSearchInput( action="search_encounter", patient_id=patient_id or None, @@ -544,7 +550,7 @@ async def google_drive_upload_file( } # ------------------------------------------------------------------ - # Tool 4: Send email via SMTP + # Tool 8: Send email via SMTP # ------------------------------------------------------------------ @mcp.tool( diff --git a/src/connectors/manifest.py b/src/connectors/manifest.py index 2033234..ffb5cfa 100644 --- a/src/connectors/manifest.py +++ b/src/connectors/manifest.py @@ -6,30 +6,46 @@ from runtime import BaseConnector -# FHIR connectors expose a single `execute` entrypoint with a discriminated `action` -# field; expand these for REST/MCP discovery so routes remain per-operation. -_FHIR_DISCRIMINATED_ACTIONS: Dict[str, List[str]] = { - "fhir_cerner": [ - "read_patient", - "search_patients", - "search_encounter", - "create_document_reference", - "search_document_reference", - ], - "fhir_epic": [ - "read_patient", - "search_patients", - "search_encounter", - "create_document_reference", - "search_document_reference", - ], -} - def _schema_for(model: Type[BaseModel]) -> Dict[str, Any]: return model.model_json_schema() +def _fhir_action_schemas() -> Dict[str, Dict[str, Type[BaseModel]]]: + """Return per-action input model classes for FHIR connectors (lazy import).""" + from connectors.fhir_cerner.schema import ( + FhirCernerDocumentReferenceCreateInput, + FhirCernerDocumentReferenceSearchInput, + FhirCernerEncounterSearchInput, + FhirCernerPatientReadInput, + FhirCernerPatientSearchInput, + ) + from connectors.fhir_epic.schema import ( + FhirDocumentReferenceCreateInput, + FhirDocumentReferenceSearchInput, + FhirEncounterSearchInput, + FhirPatientReadInput, + FhirPatientSearchInput, + ) + + return { + "fhir_cerner": { + "read_patient": FhirCernerPatientReadInput, + "search_patients": FhirCernerPatientSearchInput, + "search_encounter": FhirCernerEncounterSearchInput, + "create_document_reference": FhirCernerDocumentReferenceCreateInput, + "search_document_reference": FhirCernerDocumentReferenceSearchInput, + }, + "fhir_epic": { + "read_patient": FhirPatientReadInput, + "search_patients": FhirPatientSearchInput, + "search_encounter": FhirEncounterSearchInput, + "create_document_reference": FhirDocumentReferenceCreateInput, + "search_document_reference": FhirDocumentReferenceSearchInput, + }, + } + + def build_manifest(connectors: List[BaseConnector[Any, Any]]) -> List[Dict[str, Any]]: """ Build a simple manifest for discovery. @@ -39,21 +55,25 @@ def build_manifest(connectors: List[BaseConnector[Any, Any]]) -> List[Dict[str, REST route generation and MCP tool manifests. """ manifest: List[Dict[str, Any]] = [] + fhir_schemas: Dict[str, Dict[str, Type[BaseModel]]] | None = None + for connector in connectors: - input_model = connector._input_model_cls # type: ignore[attr-defined] output_model = connector._output_model_cls # type: ignore[attr-defined] cid = connector.connector_id - if cid in _FHIR_DISCRIMINATED_ACTIONS and getattr(connector, "action", None) == "execute": - for sub_action in _FHIR_DISCRIMINATED_ACTIONS[cid]: + if getattr(connector, "action", None) == "execute" and cid in ("fhir_cerner", "fhir_epic"): + if fhir_schemas is None: + fhir_schemas = _fhir_action_schemas() + for sub_action, input_cls in fhir_schemas[cid].items(): manifest.append( { "connector_id": cid, "action": sub_action, - "input_schema": _schema_for(input_model), + "input_schema": _schema_for(input_cls), "output_schema": _schema_for(output_model), } ) else: + input_model = connector._input_model_cls # type: ignore[attr-defined] manifest.append( { "connector_id": cid, diff --git a/tests/test_fhir_cerner.py b/tests/test_fhir_cerner.py index 903a927..9aa7fe7 100644 --- a/tests/test_fhir_cerner.py +++ b/tests/test_fhir_cerner.py @@ -16,7 +16,7 @@ class MockSecretProvider(SecretProvider): def get_secret(self, key: str) -> str: return { "cerner_fhir_base_url": "https://fhir-myrecord.cerner.com/r4/tenant-id", - "cerner_private_key": "-----BEGIN RSA PRIVATE KEY-----\\\\nMEowIQ...dummy\\\\n-----END RSA PRIVATE KEY-----", + "cerner_private_key": "-----BEGIN RSA PRIVATE KEY-----\\nMEowIQ...dummy\\n-----END RSA PRIVATE KEY-----", "cerner_kid": "dummy-kid", "cerner_client_id": "dummy-client-id", "cerner_token_url": "https://authorization.cerner.com/tenants/tenant-id/protocols/oauth2/profiles/smart-v1/token", From cc803d261bc104f3d67d55597ca2087f994d0245 Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Tue, 31 Mar 2026 19:08:43 -0700 Subject: [PATCH 03/60] SDK connector added --- src/bindings/factory.py | 51 +-- src/bindings/mcp_server/server.py | 3 +- src/bindings/rest_api/app.py | 11 +- src/connectors/__init__.py | 27 +- src/connectors/fhir_cerner/__init__.py | 1 + src/connectors/fhir_cerner/logic.py | 79 +++-- src/connectors/fhir_cerner/schema.py | 26 +- src/connectors/fhir_epic/__init__.py | 1 + src/connectors/fhir_epic/logic.py | 292 ++++++++++------- src/connectors/fhir_epic/schema.py | 26 +- src/connectors/google_drive/logic.py | 413 ++++++++++++++----------- src/connectors/google_drive/schema.py | 18 +- src/connectors/http_generic/logic.py | 2 - src/connectors/manifest.py | 61 +--- src/connectors/stripe/logic.py | 42 ++- src/connectors/stripe/schema.py | 4 +- src/runtime/__init__.py | 4 + src/runtime/base.py | 3 +- src/runtime/sdk_connector.py | 217 +++++++++++++ tests/test_connectors_basic.py | 6 +- tests/test_google_drive.py | 9 +- tests/test_sdk_connector_manifest.py | 58 ++++ 22 files changed, 836 insertions(+), 518 deletions(-) create mode 100644 src/connectors/fhir_cerner/__init__.py create mode 100644 src/connectors/fhir_epic/__init__.py create mode 100644 src/runtime/sdk_connector.py create mode 100644 tests/test_sdk_connector_manifest.py diff --git a/src/bindings/factory.py b/src/bindings/factory.py index 76e4df8..2f87234 100644 --- a/src/bindings/factory.py +++ b/src/bindings/factory.py @@ -7,24 +7,15 @@ import yaml -from connectors.fhir_epic.logic import FhirEpicConnector -from connectors.fhir_cerner.logic import FhirCernerConnector from connectors.http_generic.logic import HttpGenericConnector from connectors.http_generic.schema import HttpRequestInput, HttpResponseOutput -from connectors.google_drive.logic import GoogleDriveConnector -from connectors.google_drive.schema import ( - GoogleDriveOperationInput, - GoogleDriveOperationOutput, -) from connectors.smtp.logic import SmtpConnector from connectors.smtp.schema import SmtpSendInput, SmtpSendOutput -from connectors.stripe.logic import StripeChargeConnector -from connectors.stripe.schema import ChargeInput, ChargeOutput from runtime import BaseConnector, SecretProvider +from runtime.sdk_connector import _CONNECTOR_REGISTRY logger = logging.getLogger("bindings.factory") -# Resolve default config relative to platform root so it works from any cwd. _PLATFORM_ROOT = Path(__file__).resolve().parent.parent.parent _DEFAULT_CONFIG_PATH = _PLATFORM_ROOT / "config" / "connectors.yaml" @@ -38,11 +29,7 @@ class ConnectorConfig: class EnvSecretProvider(SecretProvider): - """ - Simple SecretProvider implementation backed by environment variables. - - Keys are looked up directly from os.environ for the POC. - """ + """SecretProvider backed by environment variables.""" def __init__(self) -> None: import os @@ -56,16 +43,13 @@ def get_secret(self, key: str) -> str: val = self._env.get(key.upper()) if val is not None: return val.strip(" '\"") - # Return empty string instead of raising RuntimeError for zero-config/local testing. return "" class ConnectorFactory: """ - Factory responsible for: - - Loading connector configuration from config/connectors.yaml - - Instantiating connector adapters - - Enforcing exposed_via rules per protocol + Loads config/connectors.yaml, instantiates connectors from the SDK registry + or legacy explicit constructors. """ def __init__(self, config_path: str | Path | None = None) -> None: @@ -74,7 +58,6 @@ def __init__(self, config_path: str | Path | None = None) -> None: elif _DEFAULT_CONFIG_PATH.is_file(): self._config_path = str(_DEFAULT_CONFIG_PATH) else: - # Fallback when run from platform dir (e.g. package installed from wheel) cwd_config = Path.cwd() / "config" / "connectors.yaml" self._config_path = str(cwd_config) self._secret_provider: SecretProvider = EnvSecretProvider() @@ -114,22 +97,22 @@ def load(self) -> None: self._connectors[connector_id] = self._instantiate(connector_id) def _instantiate(self, connector_id: str) -> Any: + sdk_cls = _CONNECTOR_REGISTRY.get(connector_id) + if sdk_cls is not None: + return sdk_cls(secret_provider=self._secret_provider) + if connector_id == "http_generic": - return HttpGenericConnector(HttpRequestInput, HttpResponseOutput, secret_provider=self._secret_provider) + return HttpGenericConnector( + HttpRequestInput, + HttpResponseOutput, + secret_provider=self._secret_provider, + ) if connector_id == "smtp": - return SmtpConnector(SmtpSendInput, SmtpSendOutput, secret_provider=self._secret_provider) - if connector_id == "stripe": - return StripeChargeConnector(ChargeInput, ChargeOutput, secret_provider=self._secret_provider) - if connector_id == "google_drive": - return GoogleDriveConnector( - GoogleDriveOperationInput, - GoogleDriveOperationOutput, + return SmtpConnector( + SmtpSendInput, + SmtpSendOutput, secret_provider=self._secret_provider, ) - if connector_id == "fhir_epic": - return FhirEpicConnector(secret_provider=self._secret_provider) - if connector_id == "fhir_cerner": - return FhirCernerConnector(secret_provider=self._secret_provider) raise ValueError(f"Unknown connector id {connector_id!r}") @@ -164,7 +147,7 @@ def get_for_protocol( if action: logger.debug( - "get_for_protocol resolved connector (action from URL is merged into payload by REST)", + "get_for_protocol resolved connector", extra={"connector_id": connector_id, "protocol": protocol, "action": action}, ) diff --git a/src/bindings/mcp_server/server.py b/src/bindings/mcp_server/server.py index 6b9f0a2..5bbed57 100644 --- a/src/bindings/mcp_server/server.py +++ b/src/bindings/mcp_server/server.py @@ -7,6 +7,7 @@ from bindings.factory import ConnectorFactory from connectors import auto_register from connectors.manifest import build_manifest +from runtime import SDKConnector logger = logging.getLogger("bindings.mcp_server") @@ -51,7 +52,7 @@ async def invoke_tool(self, name: str, arguments: Dict[str, Any]) -> Dict[str, A raise ValueError(f"Connector {connector_id!r} is not available via MCP.") run_args = dict(arguments) - if connector_id in ("fhir_cerner", "fhir_epic"): + if isinstance(connector, SDKConnector): run_args.setdefault("action", action) response = await connector.run(run_args) diff --git a/src/bindings/rest_api/app.py b/src/bindings/rest_api/app.py index f442734..4fd283a 100644 --- a/src/bindings/rest_api/app.py +++ b/src/bindings/rest_api/app.py @@ -6,7 +6,6 @@ from fastapi import Depends, FastAPI, HTTPException from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel, create_model from dotenv import load_dotenv load_dotenv() # Load environmental variables from .env @@ -14,7 +13,7 @@ from bindings.factory import ConnectorFactory from connectors import auto_register from connectors.manifest import build_manifest -from runtime import ConnectorResponse, ErrorCategory +from runtime import ConnectorResponse, ErrorCategory, SDKConnector from opentelemetry import trace from opentelemetry.trace import Status, StatusCode from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor @@ -36,9 +35,6 @@ app = FastAPI(title="Node Wire - REST API") FastAPIInstrumentor.instrument_app(app) -import os -from pathlib import Path - # Include the professional scenarios orchestrator app.include_router(scenarios_router) @@ -73,9 +69,6 @@ def _http_status_for_category(category: ErrorCategory | None) -> int: return 503 return 500 -_FHIR_REST_IDS = frozenset({"fhir_cerner", "fhir_epic"}) - - def _make_endpoint(cid: str, act: str) -> Any: async def endpoint( payload: Dict[str, Any], @@ -93,7 +86,7 @@ async def endpoint( if connector is None: raise HTTPException(status_code=404, detail="Connector not available for REST") run_payload = dict(payload) - if cid in _FHIR_REST_IDS: + if isinstance(connector, SDKConnector): run_payload.setdefault("action", act) # Let the runtime (Layer A) perform full schema validation. # Any validation errors will be mapped into ConnectorResponse. diff --git a/src/connectors/__init__.py b/src/connectors/__init__.py index f9c7b0f..50d93b4 100644 --- a/src/connectors/__init__.py +++ b/src/connectors/__init__.py @@ -3,48 +3,51 @@ """ Node Wire - Layer B: System Adapters. -Each connector lives in its own subpackage and follows the three-file pattern: +Each connector lives in its own subpackage: connector_name/ schema.py logic.py - registration.py + registration.py (optional — legacy connectors) -Registration modules are auto-discovered so they can register system-specific -exceptions with the global ErrorMapper in Layer A. +SDKConnector-based connectors self-register when their `logic` module is +imported. Legacy connectors may still use `registration.py` for ErrorMapper. """ from importlib import import_module from pkgutil import iter_modules -from typing import Iterable, List +from typing import List def auto_register() -> List[str]: """ - Import all `registration` modules in connector subpackages. + Import connector subpackages so SDK connectors register and legacy mappings apply. - Returns the list of successfully imported module names. This should be - called once at process startup (e.g. by Layer C bindings) to ensure all - connector-specific error mappings are registered. + Imports `logic` first (triggers SDKConnector.__init_subclass__), then + `registration` when present. """ imported: List[str] = [] package_name = __name__ for module_info in iter_modules(__path__, prefix=f"{package_name}."): - # We only care about subpackages; each is expected to expose registration.py if not module_info.ispkg: continue + logic_module = f"{module_info.name}.logic" + try: + import_module(logic_module) + imported.append(logic_module) + except ModuleNotFoundError: + pass + registration_module = f"{module_info.name}.registration" try: import_module(registration_module) imported.append(registration_module) except ModuleNotFoundError: - # Connector without a registration module; skip silently. continue return imported __all__ = ["auto_register"] - diff --git a/src/connectors/fhir_cerner/__init__.py b/src/connectors/fhir_cerner/__init__.py new file mode 100644 index 0000000..9fa8ea5 --- /dev/null +++ b/src/connectors/fhir_cerner/__init__.py @@ -0,0 +1 @@ +"""FHIR Cerner connector package.""" diff --git a/src/connectors/fhir_cerner/logic.py b/src/connectors/fhir_cerner/logic.py index c05281c..94b453d 100644 --- a/src/connectors/fhir_cerner/logic.py +++ b/src/connectors/fhir_cerner/logic.py @@ -11,7 +11,7 @@ import httpx import jwt -from runtime import BaseConnector, SecretProvider +from runtime import SDKConnector, sdk_action from . import registration from .schema import ( @@ -21,7 +21,6 @@ FhirCernerDocumentReferenceSearchOutput, FhirCernerEncounterSearchInput, FhirCernerEncounterSearchOutput, - FhirCernerOperationInput, FhirCernerOperationOutput, FhirCernerPatientReadInput, FhirCernerPatientReadOutput, @@ -32,44 +31,56 @@ logger = logging.getLogger("connectors.fhir_cerner") -class FhirCernerConnector(BaseConnector[FhirCernerOperationInput, FhirCernerOperationOutput]): +class FhirCernerConnector(SDKConnector): """ - Single FHIR/Cerner connector. - - Authentication uses Cerner's SMART Backend Services (private_key_jwt) flow, - identical to Epic's implementation — RS384-signed JWT exchanged for an - OAuth2 access token at the configured token endpoint. - - Required secrets (configured via SecretProvider): - - cerner_fhir_base_url : Cerner FHIR R4 base URL - - cerner_private_key : RSA private key PEM (newlines may be escaped) - - cerner_kid : Key ID registered in the Cerner code console - - cerner_client_id : Client ID from Cerner app registration - - cerner_token_url : OAuth2 token endpoint URL (from .well-known/smart-configuration - or the Cerner code console) + FHIR/Cerner connector: SMART Backend Services (private_key_jwt), RS384. + + Required secrets: cerner_fhir_base_url, cerner_private_key, cerner_kid, + cerner_client_id, cerner_token_url (optional cerner_scopes). """ connector_id = "fhir_cerner" action = "execute" + output_model = FhirCernerOperationOutput + + @sdk_action("read_patient") + async def read_patient( + self, params: FhirCernerPatientReadInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._read_patient(params, trace_id=trace_id) + return FhirCernerOperationOutput(resource=out.resource) + + @sdk_action("search_patients") + async def search_patients( + self, params: FhirCernerPatientSearchInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._search_patients(params, trace_id=trace_id) + return FhirCernerOperationOutput( + resources=out.resources, + total=out.total, + errors=out.errors, + ) - def __init__(self, *, secret_provider: SecretProvider) -> None: - super().__init__(FhirCernerOperationInput, FhirCernerOperationOutput, secret_provider=secret_provider) - self._secret_provider = secret_provider - - async def internal_execute(self, params: Any, *, trace_id: str) -> Any: - # Back-compat: allow calling with either the RootModel union or a concrete action input model. - op = params.root if hasattr(params, "root") else params - if op.action == "read_patient": - return await self._read_patient(op, trace_id=trace_id) - if op.action == "search_patients": - return await self._search_patients(op, trace_id=trace_id) - if op.action == "search_encounter": - return await self._search_encounter(op, trace_id=trace_id) - if op.action == "create_document_reference": - return await self._create_document_reference(op, trace_id=trace_id) - if op.action == "search_document_reference": - return await self._search_document_reference(op, trace_id=trace_id) - raise ValueError(f"Unsupported action: {op.action!r}") + @sdk_action("search_encounter") + async def search_encounter( + self, params: FhirCernerEncounterSearchInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._search_encounter(params, trace_id=trace_id) + return FhirCernerOperationOutput(resources=out.resources, total=out.total) + + @sdk_action("create_document_reference") + async def create_document_reference( + self, params: FhirCernerDocumentReferenceCreateInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._create_document_reference(params, trace_id=trace_id) + return FhirCernerOperationOutput(resource_id=out.resource_id, resource=out.resource) + + @sdk_action("search_document_reference") + async def search_document_reference( + self, params: FhirCernerDocumentReferenceSearchInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._search_document_reference(params, trace_id=trace_id) + return FhirCernerOperationOutput(resources=out.resources, total=out.total) # ------------------------------------------------------------------ # Shared authentication helpers diff --git a/src/connectors/fhir_cerner/schema.py b/src/connectors/fhir_cerner/schema.py index e24d915..123c81d 100644 --- a/src/connectors/fhir_cerner/schema.py +++ b/src/connectors/fhir_cerner/schema.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Annotated, Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional -from pydantic import BaseModel, Field, RootModel +from pydantic import BaseModel, Field # --------------------------------------------------------------------------- @@ -275,29 +275,11 @@ class FhirCernerDocumentReferenceSearchOutput(BaseModel): """Total number of results reported by the Bundle.""" -# --------------------------------------------------------------------------- -# Unified operation input/output (one endpoint, multiple actions) -# --------------------------------------------------------------------------- - -_FhirCernerOperationUnion = Annotated[ - Union[ - FhirCernerPatientReadInput, - FhirCernerPatientSearchInput, - FhirCernerEncounterSearchInput, - FhirCernerDocumentReferenceCreateInput, - FhirCernerDocumentReferenceSearchInput, - ], - Field(discriminator="action"), -] - -FhirCernerOperationInput = RootModel[_FhirCernerOperationUnion] - - class FhirCernerOperationOutput(BaseModel): """ - Combined output shape for schema documentation/manifest generation. + Unified output for all Cerner FHIR actions (SDKConnector single output_model). - Individual handlers still return their specific output models. + Fields are populated depending on the action; unused fields are None. """ resource: Optional[Dict[str, Any]] = None diff --git a/src/connectors/fhir_epic/__init__.py b/src/connectors/fhir_epic/__init__.py new file mode 100644 index 0000000..aa47436 --- /dev/null +++ b/src/connectors/fhir_epic/__init__.py @@ -0,0 +1 @@ +"""FHIR Epic connector package.""" diff --git a/src/connectors/fhir_epic/logic.py b/src/connectors/fhir_epic/logic.py index 5cbe8c3..9e72e58 100644 --- a/src/connectors/fhir_epic/logic.py +++ b/src/connectors/fhir_epic/logic.py @@ -11,7 +11,7 @@ import httpx import jwt -from runtime import BaseConnector, SecretProvider +from runtime import SDKConnector, sdk_action from .schema import ( FhirDocumentReferenceCreateInput, @@ -20,7 +20,6 @@ FhirDocumentReferenceSearchOutput, FhirEncounterSearchInput, FhirEncounterSearchOutput, - FhirEpicOperationInput, FhirEpicOperationOutput, FhirPatientReadInput, FhirPatientReadOutput, @@ -31,69 +30,82 @@ logger = logging.getLogger("connectors.fhir_epic") -class FhirEpicConnector(BaseConnector[FhirEpicOperationInput, FhirEpicOperationOutput]): - """ - Single FHIR/Epic connector. - - Exposes one endpoint (`execute`) and dispatches actions via the - `action` discriminator on the request payload. - """ +class FhirEpicConnector(SDKConnector): + """FHIR/Epic connector: one @sdk_action per operation.""" connector_id = "fhir_epic" action = "execute" + output_model = FhirEpicOperationOutput + + @sdk_action("read_patient") + async def read_patient( + self, params: FhirPatientReadInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._read_patient(params, trace_id=trace_id) + return FhirEpicOperationOutput(resource=out.resource) + + @sdk_action("search_patients") + async def search_patients( + self, params: FhirPatientSearchInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._search_patients(params, trace_id=trace_id) + return FhirEpicOperationOutput( + resources=out.resources, + total=out.total, + errors=out.errors, + ) - def __init__(self, *, secret_provider: SecretProvider) -> None: - super().__init__(FhirEpicOperationInput, FhirEpicOperationOutput, secret_provider=secret_provider) - self._secret_provider = secret_provider - - async def internal_execute(self, params: Any, *, trace_id: str) -> Any: - # Back-compat: allow calling with either the RootModel union or a concrete action input model. - op = params.root if hasattr(params, "root") else params - if op.action == "read_patient": - return await self._read_patient(op, trace_id=trace_id) - if op.action == "search_patients": - return await self._search_patients(op, trace_id=trace_id) - if op.action == "search_encounter": - return await self._search_encounter(op, trace_id=trace_id) - if op.action == "create_document_reference": - return await self._create_document_reference(op, trace_id=trace_id) - if op.action == "search_document_reference": - return await self._search_document_reference(op, trace_id=trace_id) - raise ValueError(f"Unsupported action: {op.action!r}") + @sdk_action("search_encounter") + async def search_encounter( + self, params: FhirEncounterSearchInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._search_encounter(params, trace_id=trace_id) + return FhirEpicOperationOutput(resources=out.resources, total=out.total) + + @sdk_action("create_document_reference") + async def create_document_reference( + self, params: FhirDocumentReferenceCreateInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._create_document_reference(params, trace_id=trace_id) + return FhirEpicOperationOutput(resource_id=out.resource_id, resource=out.resource) + + @sdk_action("search_document_reference") + async def search_document_reference( + self, params: FhirDocumentReferenceSearchInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._search_document_reference(params, trace_id=trace_id) + return FhirEpicOperationOutput(resources=out.resources, total=out.total) # ------------------------------------------------------------------ # Shared authentication helpers # ------------------------------------------------------------------ def _get_base_url(self) -> str: - return self._secret_provider.get_secret("epic_fhir_base_url").rstrip("/") + return self.secret_provider.get_secret("epic_fhir_base_url").rstrip("/") async def _get_auth_header(self) -> Dict[str, str]: - """ - Obtain an access token via Epic's SMART Backend Services (private_key_jwt) - and return ready-to-use request headers. - - Algorithm: RS384. Token lifetime: 5 minutes (Epic maximum). - Reference: https://fhir.epic.com/Documentation?docId=oauth2tutorial§ion=cloud-based-app - """ headers = { "Content-Type": "application/fhir+json", "Accept": "application/fhir+json", } - private_key_str = self._secret_provider.get_secret("epic_private_key") - kid = self._secret_provider.get_secret("epic_kid") - client_id = self._secret_provider.get_secret("epic_client_id") - token_url = self._secret_provider.get_secret("epic_token_url") + private_key_str = self.secret_provider.get_secret("epic_private_key") + kid = self.secret_provider.get_secret("epic_kid") + client_id = self.secret_provider.get_secret("epic_client_id") + token_url = self.secret_provider.get_secret("epic_token_url") - # Environment variables sometimes store newlines as escape sequences. private_key_pem = codecs.decode(private_key_str, "unicode_escape") now = int(datetime.now(tz=timezone.utc).timestamp()) jwt_token = jwt.encode( { - "iss": client_id, "sub": client_id, "aud": token_url, - "jti": str(uuid.uuid4()), "iat": now, "nbf": now, "exp": now + 300, + "iss": client_id, + "sub": client_id, + "aud": token_url, + "jti": str(uuid.uuid4()), + "iat": now, + "nbf": now, + "exp": now + 300, }, private_key_pem, algorithm="RS384", @@ -115,7 +127,8 @@ async def _get_auth_header(self) -> Dict[str, str]: if token_response.status_code != 200: logger.error( "OAuth token exchange failed | status=%s | body=%s", - token_response.status_code, token_response.text, + token_response.status_code, + token_response.text, ) token_response.raise_for_status() token_data = token_response.json() @@ -127,10 +140,6 @@ async def _get_auth_header(self) -> Dict[str, str]: headers["Authorization"] = f"Bearer {access_token}" return headers - # ------------------------------------------------------------------ - # Internal name-field helpers - # ------------------------------------------------------------------ - @staticmethod def _build_name_search_params( given_name: Optional[str], @@ -139,7 +148,6 @@ def _build_name_search_params( birthdate: Optional[str], extra: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: - """Build a FHIR search params dict from explicit name/date fields.""" params: Dict[str, str] = dict(extra or {}) if given_name and given_name.strip(): @@ -160,7 +168,6 @@ def _build_encounter_search_params( date: Optional[str], extra: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: - """Build a FHIR search params dict for Encounter from explicit fields.""" params: Dict[str, str] = dict(extra or {}) if patient_id and patient_id.strip(): @@ -172,10 +179,6 @@ def _build_encounter_search_params( return params - # ------------------------------------------------------------------ - # Action: read_patient - # ------------------------------------------------------------------ - async def _read_patient( self, params: FhirPatientReadInput, *, trace_id: str ) -> FhirPatientReadOutput: @@ -185,18 +188,30 @@ async def _read_patient( if params.resource_id: url = f"{base_url}/Patient/{params.resource_id}" query_params: Optional[Dict[str, str]] = None - logger.info("FHIR Patient read by ID", extra={"trace_id": trace_id, "resource_id": params.resource_id}) + logger.info( + "FHIR Patient read by ID", + extra={"trace_id": trace_id, "resource_id": params.resource_id}, + ) elif params.given_name or params.family_name or params.name: url = f"{base_url}/Patient" query_params = self._build_name_search_params( - params.given_name, params.family_name, params.name, - params.birthdate, params.search_params, + params.given_name, + params.family_name, + params.name, + params.birthdate, + params.search_params, + ) + logger.info( + "FHIR Patient read by name fields", + extra={"trace_id": trace_id, "query_params": query_params}, ) - logger.info("FHIR Patient read by name fields", extra={"trace_id": trace_id, "query_params": query_params}) elif params.search_params: url = f"{base_url}/Patient" query_params = params.search_params - logger.info("FHIR Patient read by search", extra={"trace_id": trace_id, "search_params": params.search_params}) + logger.info( + "FHIR Patient read by search", + extra={"trace_id": trace_id, "search_params": params.search_params}, + ) else: raise ValueError( "Provide resource_id, or name fields (given_name/family_name/name), " @@ -205,10 +220,17 @@ async def _read_patient( try: async with httpx.AsyncClient() as client: - response = await client.get(url, headers=auth_header, params=query_params, timeout=30.0) + response = await client.get( + url, headers=auth_header, params=query_params, timeout=30.0 + ) response.raise_for_status() except Exception as exc: - logger.error("FHIR Patient read failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) + logger.error( + "FHIR Patient read failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) raise data = response.json() @@ -220,13 +242,12 @@ async def _read_patient( else: resource = data - logger.info("FHIR Patient read completed", extra={"trace_id": trace_id, "status_code": response.status_code}) + logger.info( + "FHIR Patient read completed", + extra={"trace_id": trace_id, "status_code": response.status_code}, + ) return FhirPatientReadOutput(resource=resource) - # ------------------------------------------------------------------ - # Action: search_patients (multi-ID fan-out OR name search) - # ------------------------------------------------------------------ - async def _search_patients( self, params: FhirPatientSearchInput, *, trace_id: str ) -> FhirPatientSearchOutput: @@ -257,7 +278,8 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ except Exception as exc: logger.warning( "FHIR Patient fetch failed | resource_id=%s | error=%s", - rid, str(exc), + rid, + str(exc), extra={"trace_id": trace_id}, ) return rid, None, str(exc) @@ -274,14 +296,20 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ logger.info( "FHIR Patient multi-ID lookup completed | found=%s | errors=%s", - len(resources), len(errors), + len(resources), + len(errors), extra={"trace_id": trace_id}, ) - return FhirPatientSearchOutput(resources=resources, total=len(resources), errors=errors) + return FhirPatientSearchOutput( + resources=resources, total=len(resources), errors=errors + ) name_params = self._build_name_search_params( - params.given_name, params.family_name, params.name, - params.birthdate, params.search_params, + params.given_name, + params.family_name, + params.name, + params.birthdate, + params.search_params, ) if not name_params: raise ValueError( @@ -307,14 +335,16 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ except httpx.HTTPStatusError as exc: logger.error( "FHIR Patient name search failed | status=%s | body=%s", - exc.response.status_code, exc.response.text, + exc.response.status_code, + exc.response.text, extra={"trace_id": trace_id}, ) raise except Exception as exc: logger.error( "FHIR Patient name search failed | error=%s: %s", - type(exc).__name__, str(exc), + type(exc).__name__, + str(exc), extra={"trace_id": trace_id}, ) raise @@ -327,15 +357,12 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ logger.info( "FHIR Patient name search completed | found=%s | total=%s", - len(resources), total, + len(resources), + total, extra={"trace_id": trace_id}, ) return FhirPatientSearchOutput(resources=resources, total=total) - # ------------------------------------------------------------------ - # Action: search_encounter - # ------------------------------------------------------------------ - async def _search_encounter( self, params: FhirEncounterSearchInput, *, trace_id: str ) -> FhirEncounterSearchOutput: @@ -346,24 +373,43 @@ async def _search_encounter( query_params = self._build_encounter_search_params( params.patient_id, params.status, params.date, params.search_params ) - logger.info("FHIR Encounter search by explicit fields", extra={"trace_id": trace_id, "query_params": query_params}) + logger.info( + "FHIR Encounter search by explicit fields", + extra={"trace_id": trace_id, "query_params": query_params}, + ) elif params.search_params: query_params = params.search_params - logger.info("FHIR Encounter search by raw params", extra={"trace_id": trace_id, "search_params": params.search_params}) + logger.info( + "FHIR Encounter search by raw params", + extra={"trace_id": trace_id, "search_params": params.search_params}, + ) else: raise ValueError("Provide at least patient_id, status, date OR search_params") try: async with httpx.AsyncClient() as client: response = await client.get( - f"{base_url}/Encounter", headers=auth_header, params=query_params, timeout=30.0, + f"{base_url}/Encounter", + headers=auth_header, + params=query_params, + timeout=30.0, ) response.raise_for_status() except httpx.HTTPStatusError as exc: - logger.error("FHIR Encounter search failed | status=%s | body=%s", exc.response.status_code, exc.response.text, extra={"trace_id": trace_id}) + logger.error( + "FHIR Encounter search failed | status=%s | body=%s", + exc.response.status_code, + exc.response.text, + extra={"trace_id": trace_id}, + ) raise except Exception as exc: - logger.error("FHIR Encounter search failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) + logger.error( + "FHIR Encounter search failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) raise data = response.json() @@ -372,13 +418,13 @@ async def _search_encounter( if data.get("resourceType") == "Bundle" and data.get("entry"): resources = [e["resource"] for e in data["entry"] if "resource" in e] - logger.info("FHIR Encounter search completed | found=%s", len(resources), extra={"trace_id": trace_id}) + logger.info( + "FHIR Encounter search completed | found=%s", + len(resources), + extra={"trace_id": trace_id}, + ) return FhirEncounterSearchOutput(resources=resources, total=total) - # ------------------------------------------------------------------ - # Action: create_document_reference - # ------------------------------------------------------------------ - async def _create_document_reference( self, params: FhirDocumentReferenceCreateInput, *, trace_id: str ) -> FhirDocumentReferenceCreateOutput: @@ -392,7 +438,14 @@ async def _create_document_reference( "type": params.type, "subject": {"reference": params.subject}, "date": datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), - "content": [{"attachment": {"contentType": params.content_type or "text/plain", "data": params.data}}], + "content": [ + { + "attachment": { + "contentType": params.content_type or "text/plain", + "data": params.data, + } + } + ], } if params.category: doc_ref["category"] = params.category @@ -410,7 +463,10 @@ async def _create_document_reference( try: async with httpx.AsyncClient() as client: response = await client.post( - f"{base_url}/DocumentReference", json=doc_ref, headers=auth_header, timeout=30.0, + f"{base_url}/DocumentReference", + json=doc_ref, + headers=auth_header, + timeout=30.0, ) response.raise_for_status() except httpx.HTTPStatusError as exc: @@ -427,13 +483,19 @@ async def _create_document_reference( logger.error( "FHIR DocumentReference create failed | status=%s | epic_error=%s | sent_payload=%s", - exc.response.status_code, error_detail, json.dumps(doc_ref), + exc.response.status_code, + error_detail, + json.dumps(doc_ref), extra={"trace_id": trace_id}, ) - # Raise a more descriptive error for the API to catch raise ValueError(f"Epic Error: {error_detail}") from exc except Exception as exc: - logger.error("FHIR DocumentReference create failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) + logger.error( + "FHIR DocumentReference create failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) raise resource_id: Optional[str] = None @@ -442,7 +504,11 @@ async def _create_document_reference( location = response.headers.get("Location", "") if location: history_marker = location.find("/_history/") - resource_id = location[:history_marker].split("/")[-1] if history_marker != -1 else location.split("/")[-1] + resource_id = ( + location[:history_marker].split("/")[-1] + if history_marker != -1 + else location.split("/")[-1] + ) if not resource_id: content_length = response.headers.get("content-length", "0") @@ -459,12 +525,14 @@ async def _create_document_reference( f"Status: {response.status_code}, Location: {location!r}, Body: {response.text[:200]!r}" ) - logger.info("FHIR DocumentReference create completed | resource_id=%s", resource_id, extra={"trace_id": trace_id}) - return FhirDocumentReferenceCreateOutput(resource_id=resource_id, resource=body if body else None) - - # ------------------------------------------------------------------ - # Action: search_document_reference - # ------------------------------------------------------------------ + logger.info( + "FHIR DocumentReference create completed | resource_id=%s", + resource_id, + extra={"trace_id": trace_id}, + ) + return FhirDocumentReferenceCreateOutput( + resource_id=resource_id, resource=body if body else None + ) async def _search_document_reference( self, params: FhirDocumentReferenceSearchInput, *, trace_id: str @@ -472,19 +540,35 @@ async def _search_document_reference( base_url = self._get_base_url() auth_header = await self._get_auth_header() - logger.info("FHIR DocumentReference search", extra={"trace_id": trace_id, "search_params": params.search_params}) + logger.info( + "FHIR DocumentReference search", + extra={"trace_id": trace_id, "search_params": params.search_params}, + ) try: async with httpx.AsyncClient() as client: response = await client.get( - f"{base_url}/DocumentReference", headers=auth_header, params=params.search_params, timeout=30.0, + f"{base_url}/DocumentReference", + headers=auth_header, + params=params.search_params, + timeout=30.0, ) response.raise_for_status() except httpx.HTTPStatusError as exc: - logger.error("FHIR DocumentReference search failed | status=%s | body=%s", exc.response.status_code, exc.response.text, extra={"trace_id": trace_id}) + logger.error( + "FHIR DocumentReference search failed | status=%s | body=%s", + exc.response.status_code, + exc.response.text, + extra={"trace_id": trace_id}, + ) raise except Exception as exc: - logger.error("FHIR DocumentReference search failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) + logger.error( + "FHIR DocumentReference search failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) raise data = response.json() @@ -498,4 +582,4 @@ async def _search_document_reference( len(resources), extra={"trace_id": trace_id}, ) - return FhirDocumentReferenceSearchOutput(resources=resources, total=total) \ No newline at end of file + return FhirDocumentReferenceSearchOutput(resources=resources, total=total) diff --git a/src/connectors/fhir_epic/schema.py b/src/connectors/fhir_epic/schema.py index eaeef26..55d8103 100644 --- a/src/connectors/fhir_epic/schema.py +++ b/src/connectors/fhir_epic/schema.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Annotated, Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional -from pydantic import BaseModel, Field, RootModel +from pydantic import BaseModel, Field # --------------------------------------------------------------------------- @@ -182,29 +182,11 @@ class FhirDocumentReferenceSearchOutput(BaseModel): """Total number of results reported by the Bundle.""" -# --------------------------------------------------------------------------- -# Unified operation input/output (one endpoint, multiple actions) -# --------------------------------------------------------------------------- - -_FhirEpicOperationUnion = Annotated[ - Union[ - FhirPatientReadInput, - FhirPatientSearchInput, - FhirEncounterSearchInput, - FhirDocumentReferenceCreateInput, - FhirDocumentReferenceSearchInput, - ], - Field(discriminator="action"), -] - -FhirEpicOperationInput = RootModel[_FhirEpicOperationUnion] - - class FhirEpicOperationOutput(BaseModel): """ - Combined output shape for schema documentation/manifest generation. + Unified output for all Epic FHIR actions (SDKConnector single output_model). - Individual handlers still return their specific output models. + Fields are populated depending on the action; unused fields are None. """ resource: Optional[Dict[str, Any]] = None diff --git a/src/connectors/google_drive/logic.py b/src/connectors/google_drive/logic.py index a4b2b3d..36e9107 100644 --- a/src/connectors/google_drive/logic.py +++ b/src/connectors/google_drive/logic.py @@ -1,17 +1,18 @@ from __future__ import annotations import asyncio -import json import base64 +import json import logging -from typing import Any, Union +from typing import Any from google.oauth2 import service_account from googleapiclient.discovery import build from googleapiclient.errors import HttpError from googleapiclient.http import MediaInMemoryUpload -from runtime import BaseConnector +from runtime import SDKConnector, sdk_action +from runtime.models import ErrorCategory from .exceptions import ( GoogleDriveAuthError, @@ -26,192 +27,47 @@ FilesListOperation, FilesUpdateOperation, FilesUploadOperation, - GoogleDriveOperationInput, GoogleDriveOperationOutput, PermissionsCreateOperation, ) logger = logging.getLogger("connectors.google_drive") -# Performant default for files.list so the API returns only needed metadata. DEFAULT_LIST_FIELDS = "nextPageToken, files(id, name, mimeType, webViewLink)" -_OperationUnion = Union[ - FilesCreateOperation, - FilesListOperation, - PermissionsCreateOperation, - FilesGetOperation, - FilesUpdateOperation, - FilesUploadOperation, - FilesDeleteOperation, -] - -class GoogleDriveConnector( - BaseConnector[GoogleDriveOperationInput, GoogleDriveOperationOutput], -): +class GoogleDriveConnector(SDKConnector): """ - Google Drive connector for files and permissions operations. + Google Drive connector: each Drive operation is an @sdk_action method. """ connector_id = "google_drive" action = "execute" + output_model = GoogleDriveOperationOutput - async def internal_execute( - self, params: GoogleDriveOperationInput, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info( - "Executing Google Drive operation", - extra={ - "trace_id": trace_id, - "connector_id": self.connector_id, - "action": self.action, - "action_type": params.root.action, - }, - ) - - drive = self._build_client() + error_map = { + GoogleDriveAuthError: (ErrorCategory.AUTH, "GDRIVE_AUTH"), + GoogleDriveRateLimitError: (ErrorCategory.RETRYABLE, "GDRIVE_RATE_LIMIT"), + GoogleDriveBusinessError: (ErrorCategory.BUSINESS, "GDRIVE_BUSINESS_RULE"), + GoogleDriveFatalError: (ErrorCategory.FATAL, "GDRIVE_FATAL"), + } + def build_client(self) -> Any: + raw_sa = self.secret_provider.get_secret("GOOGLE_DRIVE_SA_JSON") try: - response = await asyncio.to_thread( - self._dispatch_to_sdk, drive, params.root - ) - return GoogleDriveOperationOutput( - raw=response, - description=f"Successfully executed {params.root.action}", - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - except Exception as exc: # noqa: BLE001 - logger.error( - "Unexpected SDK failure", - extra={ - "trace_id": trace_id, - "connector_id": self.connector_id, - "action": self.action, - "error_type": type(exc).__name__, - "error_message": str(exc), - }, - ) - raise GoogleDriveFatalError(str(exc)) from exc - - def _dispatch_to_sdk( - self, drive: Any, params: _OperationUnion - ) -> dict[str, Any]: - """Routes the strictly validated model to the correct SDK method.""" - if params.action == "files.create": - body = { - "name": params.name, - "mimeType": params.mime_type, - "parents": params.parents, - } - body = {k: v for k, v in body.items() if v is not None} - return drive.files().create(body=body, fields='id, name, webViewLink', - supportsAllDrives=True, - ).execute() - - if params.action == "files.list": - fields = params.fields or DEFAULT_LIST_FIELDS - result = drive.files().list( - pageSize=params.page_size, - q=params.query, - fields=fields, - supportsAllDrives=True, - includeItemsFromAllDrives=True, - ).execute() - return result - - if params.action == "permissions.create": - body = { - "role": params.role, - "type": params.type, - "emailAddress": params.email_address, - } - return drive.permissions().create( - fileId=params.file_id, - body=body, - supportsAllDrives=True, - ).execute() - - if params.action == "files.get": - fields = params.fields or "id,name,mimeType,parents" - return ( - drive.files() - .get( - fileId=params.file_id, - fields=fields, - supportsAllDrives=True, - ) - .execute() - ) - - if params.action == "files.update": - body: dict[str, Any] = {} - if params.name is not None: - body["name"] = params.name - if params.mime_type is not None: - body["mimeType"] = params.mime_type - - kwargs: dict[str, Any] = {} - if params.add_parents: - kwargs["addParents"] = ",".join(params.add_parents) - if params.remove_parents: - kwargs["removeParents"] = ",".join(params.remove_parents) - - return ( - drive.files() - .update( - fileId=params.file_id, - body=body, - **kwargs, - supportsAllDrives=True, - ) - .execute() - ) - - if params.action == "files.upload": - body = { - "name": params.name, - "mimeType": params.mime_type, - "parents": params.parents, - } - body = {k: v for k, v in body.items() if v is not None} - - if params.content_base64 is not None: - media_bytes = base64.b64decode(params.content_base64) - elif params.content is not None: - media_bytes = params.content.encode("utf-8") - else: - raise ValueError("Either content or content_base64 must be provided for files.upload") - - media = MediaInMemoryUpload( - media_bytes, - mimetype=params.mime_type, - resumable=False, + info = json.loads(raw_sa) + creds = service_account.Credentials.from_service_account_info( + info, + scopes=["https://www.googleapis.com/auth/drive"], ) - - return ( - drive.files() - .create( - body=body, - media_body=media, - fields='id, name, webViewLink', - supportsAllDrives=True, - ) - .execute() + except json.JSONDecodeError: + creds = service_account.Credentials.from_service_account_file( + raw_sa.strip(), + scopes=["https://www.googleapis.com/auth/drive"], ) - - if params.action == "files.delete": - drive.files().update(fileId=params.file_id, - body={'trashed': True}, - supportsAllDrives=True, - ).execute() - return {"file_id": params.file_id, "status": "deleted"} - - raise ValueError(f"Unmapped action router: {params.action}") + return build("drive", "v3", credentials=creds) def _translate_and_raise_http_error(self, exc: HttpError) -> None: - """Translates Google's dynamic HTTP errors into static taxonomy classes.""" status = exc.resp.status content_str = str(getattr(exc, "content", "") or "") @@ -220,9 +76,7 @@ def _translate_and_raise_http_error(self, exc: HttpError) -> None: raise GoogleDriveRateLimitError( "Google Drive quota/rate limit exceeded" ) from exc - raise GoogleDriveAuthError( - "Authentication or permissions failure" - ) from exc + raise GoogleDriveAuthError("Authentication or permissions failure") from exc if status == 429 or status >= 500: raise GoogleDriveRateLimitError( @@ -231,26 +85,207 @@ def _translate_and_raise_http_error(self, exc: HttpError) -> None: if status in (400, 404, 409): reason = getattr(exc, "reason", str(exc)) - raise GoogleDriveBusinessError( - f"Business logic failure: {reason}" - ) from exc + raise GoogleDriveBusinessError(f"Business logic failure: {reason}") from exc - raise GoogleDriveFatalError( - f"Unhandled HttpError status {status}" - ) from exc + raise GoogleDriveFatalError(f"Unhandled HttpError status {status}") from exc - def _build_client(self) -> Any: - raw_sa = self.secret_provider.get_secret("GOOGLE_DRIVE_SA_JSON") + @sdk_action("files.create") + async def files_create( + self, params: FilesCreateOperation, *, trace_id: str + ) -> GoogleDriveOperationOutput: + logger.info("Google Drive files.create", extra={"trace_id": trace_id}) + drive = self.get_client() + body = {k: v for k, v in { + "name": params.name, + "mimeType": params.mime_type, + "parents": params.parents, + }.items() if v is not None} try: - info = json.loads(raw_sa) - creds = service_account.Credentials.from_service_account_info( - info, - scopes=["https://www.googleapis.com/auth/drive"], + result = await asyncio.to_thread( + lambda: drive.files().create( + body=body, + fields="id, name, webViewLink", + supportsAllDrives=True, + ).execute() ) - except json.JSONDecodeError: - # Fallback: treat the secret as a file path - creds = service_account.Credentials.from_service_account_file( - raw_sa.strip(), - scopes=["https://www.googleapis.com/auth/drive"], + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw=result, description="Successfully executed files.create" + ) + + @sdk_action("files.list") + async def files_list( + self, params: FilesListOperation, *, trace_id: str + ) -> GoogleDriveOperationOutput: + logger.info("Google Drive files.list", extra={"trace_id": trace_id}) + drive = self.get_client() + fields = params.fields or DEFAULT_LIST_FIELDS + try: + result = await asyncio.to_thread( + lambda: drive.files().list( + pageSize=params.page_size, + q=params.query, + fields=fields, + pageToken=params.page_token, + supportsAllDrives=True, + includeItemsFromAllDrives=True, + ).execute() ) - return build("drive", "v3", credentials=creds) + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw=result, description="Successfully executed files.list" + ) + + @sdk_action("files.get") + async def files_get( + self, params: FilesGetOperation, *, trace_id: str + ) -> GoogleDriveOperationOutput: + logger.info( + "Google Drive files.get", + extra={"trace_id": trace_id, "file_id": params.file_id}, + ) + drive = self.get_client() + fields = params.fields or "id,name,mimeType,parents" + try: + result = await asyncio.to_thread( + lambda: drive.files().get( + fileId=params.file_id, + fields=fields, + supportsAllDrives=True, + ).execute() + ) + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw=result, description="Successfully executed files.get" + ) + + @sdk_action("files.update") + async def files_update( + self, params: FilesUpdateOperation, *, trace_id: str + ) -> GoogleDriveOperationOutput: + logger.info( + "Google Drive files.update", + extra={"trace_id": trace_id, "file_id": params.file_id}, + ) + drive = self.get_client() + body: dict[str, Any] = {} + if params.name is not None: + body["name"] = params.name + if params.mime_type is not None: + body["mimeType"] = params.mime_type + kwargs: dict[str, Any] = {} + if params.add_parents: + kwargs["addParents"] = ",".join(params.add_parents) + if params.remove_parents: + kwargs["removeParents"] = ",".join(params.remove_parents) + try: + result = await asyncio.to_thread( + lambda: drive.files().update( + fileId=params.file_id, + body=body, + supportsAllDrives=True, + **kwargs, + ).execute() + ) + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw=result, description="Successfully executed files.update" + ) + + @sdk_action("files.upload") + async def files_upload( + self, params: FilesUploadOperation, *, trace_id: str + ) -> GoogleDriveOperationOutput: + logger.info("Google Drive files.upload", extra={"trace_id": trace_id}) + drive = self.get_client() + body = {k: v for k, v in { + "name": params.name, + "mimeType": params.mime_type, + "parents": params.parents, + }.items() if v is not None} + if params.content_base64 is not None: + media_bytes = base64.b64decode(params.content_base64) + elif params.content is not None: + media_bytes = params.content.encode("utf-8") + else: + raise ValueError( + "Either content or content_base64 must be provided for files.upload" + ) + media = MediaInMemoryUpload( + media_bytes, + mimetype=params.mime_type, + resumable=False, + ) + try: + result = await asyncio.to_thread( + lambda: drive.files().create( + body=body, + media_body=media, + fields="id, name, webViewLink", + supportsAllDrives=True, + ).execute() + ) + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw=result, description="Successfully executed files.upload" + ) + + @sdk_action("files.delete") + async def files_delete( + self, params: FilesDeleteOperation, *, trace_id: str + ) -> GoogleDriveOperationOutput: + logger.info( + "Google Drive files.delete", + extra={"trace_id": trace_id, "file_id": params.file_id}, + ) + drive = self.get_client() + try: + await asyncio.to_thread( + lambda: drive.files().update( + fileId=params.file_id, + body={"trashed": True}, + supportsAllDrives=True, + ).execute() + ) + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw={"file_id": params.file_id, "status": "deleted"}, + description="Successfully executed files.delete", + ) + + @sdk_action("permissions.create") + async def permissions_create( + self, params: PermissionsCreateOperation, *, trace_id: str + ) -> GoogleDriveOperationOutput: + logger.info( + "Google Drive permissions.create", + extra={"trace_id": trace_id, "file_id": params.file_id}, + ) + drive = self.get_client() + body: dict[str, Any] = { + "role": params.role, + "type": params.type, + } + if params.email_address: + body["emailAddress"] = params.email_address + if params.domain: + body["domain"] = params.domain + try: + result = await asyncio.to_thread( + lambda: drive.permissions().create( + fileId=params.file_id, + body=body, + supportsAllDrives=True, + ).execute() + ) + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw=result, description="Successfully executed permissions.create" + ) diff --git a/src/connectors/google_drive/schema.py b/src/connectors/google_drive/schema.py index a2f22e8..9d516e9 100644 --- a/src/connectors/google_drive/schema.py +++ b/src/connectors/google_drive/schema.py @@ -11,9 +11,6 @@ class BaseDriveOperation(BaseModel): model_config = ConfigDict(extra="forbid") -# --- Specific Operation Schemas --- - - class FilesCreateOperation(BaseDriveOperation): action: Literal["files.create"] name: str = Field(..., description="The name of the file.") @@ -32,6 +29,10 @@ class FilesListOperation(BaseDriveOperation): "uses a performant default: nextPageToken, files(id, name, mimeType, webViewLink)." ), ) + page_token: Optional[str] = Field( + None, + description="Token for the next page of results from a previous files.list response.", + ) class PermissionsCreateOperation(BaseDriveOperation): @@ -91,11 +92,7 @@ class FilesDeleteOperation(BaseDriveOperation): file_id: str -# --- The Envelope --- -# The runtime validates against this single type. Pydantic automatically -# routes the validation to the correct sub-model based on the "action" field. -# RootModel accepts **raw_input in __init__ so BaseConnector's _input_model_cls(**raw_input) works. -_OperationUnion = Annotated[ +_GoogleDriveOperationUnion = Annotated[ Union[ FilesCreateOperation, FilesListOperation, @@ -108,9 +105,10 @@ class FilesDeleteOperation(BaseDriveOperation): Field(discriminator="action"), ] -GoogleDriveOperationInput = RootModel[_OperationUnion] +# Discriminated union for tests/agents; must stay aligned with GoogleDriveConnector @sdk_action set. +GoogleDriveOperationInput = RootModel[_GoogleDriveOperationUnion] class GoogleDriveOperationOutput(BaseModel): raw: Dict[str, Any] - description: str \ No newline at end of file + description: str diff --git a/src/connectors/http_generic/logic.py b/src/connectors/http_generic/logic.py index 88afc67..2536cf6 100644 --- a/src/connectors/http_generic/logic.py +++ b/src/connectors/http_generic/logic.py @@ -38,8 +38,6 @@ async def internal_execute(self, params: HttpRequestInput, *, trace_id: str) -> }, ) - print(f"trace_id: {trace_id} from node-wire-connector") - try: async with httpx.AsyncClient() as client: response = await client.request( diff --git a/src/connectors/manifest.py b/src/connectors/manifest.py index ffb5cfa..56984d9 100644 --- a/src/connectors/manifest.py +++ b/src/connectors/manifest.py @@ -4,76 +4,34 @@ from pydantic import BaseModel -from runtime import BaseConnector +from runtime import BaseConnector, SDKConnector def _schema_for(model: Type[BaseModel]) -> Dict[str, Any]: return model.model_json_schema() -def _fhir_action_schemas() -> Dict[str, Dict[str, Type[BaseModel]]]: - """Return per-action input model classes for FHIR connectors (lazy import).""" - from connectors.fhir_cerner.schema import ( - FhirCernerDocumentReferenceCreateInput, - FhirCernerDocumentReferenceSearchInput, - FhirCernerEncounterSearchInput, - FhirCernerPatientReadInput, - FhirCernerPatientSearchInput, - ) - from connectors.fhir_epic.schema import ( - FhirDocumentReferenceCreateInput, - FhirDocumentReferenceSearchInput, - FhirEncounterSearchInput, - FhirPatientReadInput, - FhirPatientSearchInput, - ) - - return { - "fhir_cerner": { - "read_patient": FhirCernerPatientReadInput, - "search_patients": FhirCernerPatientSearchInput, - "search_encounter": FhirCernerEncounterSearchInput, - "create_document_reference": FhirCernerDocumentReferenceCreateInput, - "search_document_reference": FhirCernerDocumentReferenceSearchInput, - }, - "fhir_epic": { - "read_patient": FhirPatientReadInput, - "search_patients": FhirPatientSearchInput, - "search_encounter": FhirEncounterSearchInput, - "create_document_reference": FhirDocumentReferenceCreateInput, - "search_document_reference": FhirDocumentReferenceSearchInput, - }, - } - - def build_manifest(connectors: List[BaseConnector[Any, Any]]) -> List[Dict[str, Any]]: """ - Build a simple manifest for discovery. - - Each entry describes a connector/action pair and includes JSON Schemas - for the input and output models. This is consumed by Layer C for - REST route generation and MCP tool manifests. + One manifest entry per SDK @sdk_action (specific input/output schemas), + or one entry per legacy BaseConnector. """ manifest: List[Dict[str, Any]] = [] - fhir_schemas: Dict[str, Dict[str, Type[BaseModel]]] | None = None - for connector in connectors: - output_model = connector._output_model_cls # type: ignore[attr-defined] cid = connector.connector_id - if getattr(connector, "action", None) == "execute" and cid in ("fhir_cerner", "fhir_epic"): - if fhir_schemas is None: - fhir_schemas = _fhir_action_schemas() - for sub_action, input_cls in fhir_schemas[cid].items(): + if isinstance(connector, SDKConnector): + for action_name, meta in type(connector).sdk_action_metas().items(): manifest.append( { "connector_id": cid, - "action": sub_action, - "input_schema": _schema_for(input_cls), - "output_schema": _schema_for(output_model), + "action": action_name, + "input_schema": _schema_for(meta.input_model), + "output_schema": _schema_for(meta.output_model), } ) else: input_model = connector._input_model_cls # type: ignore[attr-defined] + output_model = connector._output_model_cls # type: ignore[attr-defined] manifest.append( { "connector_id": cid, @@ -83,4 +41,3 @@ def build_manifest(connectors: List[BaseConnector[Any, Any]]) -> List[Dict[str, } ) return manifest - diff --git a/src/connectors/stripe/logic.py b/src/connectors/stripe/logic.py index 14e973f..aefa296 100644 --- a/src/connectors/stripe/logic.py +++ b/src/connectors/stripe/logic.py @@ -1,54 +1,67 @@ from __future__ import annotations +import asyncio import logging import stripe -from runtime import BaseConnector +from runtime import SDKConnector, sdk_action +from runtime.models import ErrorCategory from .schema import ChargeInput, ChargeOutput logger = logging.getLogger("connectors.stripe") -class StripeChargeConnector(BaseConnector[ChargeInput, ChargeOutput]): - """ - Stripe connector for creating charges using the official Stripe SDK. - """ +class StripeConnector(SDKConnector): + """Stripe connector: charges and future SDK operations as @sdk_action methods.""" connector_id = "stripe" action = "charge" + output_model = ChargeOutput - async def internal_execute(self, params: ChargeInput, *, trace_id: str) -> ChargeOutput: - # API key is expected to be provided by SecretProvider. + error_map = { + stripe.error.RateLimitError: (ErrorCategory.RETRYABLE, "STRIPE_RATE_LIMIT"), + stripe.error.APIConnectionError: (ErrorCategory.RETRYABLE, "STRIPE_API_CONNECTION"), + stripe.error.CardError: (ErrorCategory.BUSINESS, "STRIPE_CARD_ERROR"), + stripe.error.InvalidRequestError: (ErrorCategory.BUSINESS, "STRIPE_INVALID_REQUEST"), + stripe.error.AuthenticationError: (ErrorCategory.AUTH, "STRIPE_AUTH_ERROR"), + stripe.error.StripeError: (ErrorCategory.FATAL, "STRIPE_ERROR"), + } + + @sdk_action("charge") + async def charge(self, params: ChargeInput, *, trace_id: str) -> ChargeOutput: api_key = self.secret_provider.get_secret("stripe_api_key") - stripe.api_key = api_key logger.info( "Creating Stripe charge", extra={ "trace_id": trace_id, "connector_id": self.connector_id, - "action": self.action, + "action": "charge", "amount": params.amount, "currency": params.currency, }, ) - try: - charge = await stripe.Charge.create( # type: ignore[attr-defined] + def _create() -> stripe.Charge: + stripe.api_key = api_key + return stripe.Charge.create( amount=params.amount, currency=params.currency, source=params.source, description=params.description, ) - except Exception as exc: # noqa: BLE001 + + try: + charge = await asyncio.to_thread(_create) + except Exception as exc: logger.error( "Stripe charge creation failed", extra={ "trace_id": trace_id, "connector_id": self.connector_id, - "action": self.action, + "action": "charge", "amount": params.amount, "currency": params.currency, "error_type": type(exc).__name__, @@ -62,7 +75,7 @@ async def internal_execute(self, params: ChargeInput, *, trace_id: str) -> Charg extra={ "trace_id": trace_id, "connector_id": self.connector_id, - "action": self.action, + "action": "charge", "charge_id": charge.get("id"), }, ) @@ -71,4 +84,3 @@ async def internal_execute(self, params: ChargeInput, *, trace_id: str) -> Charg charge_id=charge.get("id"), receipt_url=charge.get("receipt_url"), ) - diff --git a/src/connectors/stripe/schema.py b/src/connectors/stripe/schema.py index bf7e6f6..e912829 100644 --- a/src/connectors/stripe/schema.py +++ b/src/connectors/stripe/schema.py @@ -1,9 +1,12 @@ from __future__ import annotations +from typing import Literal + from pydantic import BaseModel class ChargeInput(BaseModel): + action: Literal["charge"] = "charge" amount: int currency: str source: str @@ -13,4 +16,3 @@ class ChargeInput(BaseModel): class ChargeOutput(BaseModel): charge_id: str receipt_url: str | None = None - diff --git a/src/runtime/__init__.py b/src/runtime/__init__.py index 76d63e9..1e5c11f 100644 --- a/src/runtime/__init__.py +++ b/src/runtime/__init__.py @@ -3,6 +3,7 @@ from .base import BaseConnector from .secrets import SecretProvider from .policy import PolicyHook, PolicyDenied +from .sdk_connector import SDKConnector, sdk_action, _CONNECTOR_REGISTRY __all__ = [ "ConnectorResponse", @@ -12,4 +13,7 @@ "SecretProvider", "PolicyHook", "PolicyDenied", + "SDKConnector", + "sdk_action", + "_CONNECTOR_REGISTRY", ] diff --git a/src/runtime/base.py b/src/runtime/base.py index 25596d9..e470350 100644 --- a/src/runtime/base.py +++ b/src/runtime/base.py @@ -74,7 +74,6 @@ async def run( - Maps exceptions into the standard error taxonomy """ trace_id = str(uuid.uuid4()) - print(f"trace_id: {trace_id} from runtime.base") with tracer.start_as_current_span( "connector.run", @@ -97,7 +96,7 @@ async def run( try: try: - input_model = self._input_model_cls(**raw_input) + input_model = self._input_model_cls.model_validate(raw_input) except ValidationError as exc: logger.error( "Input validation failed", diff --git a/src/runtime/sdk_connector.py b/src/runtime/sdk_connector.py new file mode 100644 index 0000000..faa9688 --- /dev/null +++ b/src/runtime/sdk_connector.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import inspect +import logging +import uuid +from dataclasses import dataclass +from typing import ( + Annotated, + Any, + ClassVar, + Dict, + Optional, + Tuple, + Type, + Union, + get_type_hints, +) + +from pydantic import BaseModel, Field, RootModel + +from .base import BaseConnector +from .errors import ErrorMapper +from .models import ErrorCategory +from .secrets import SecretProvider + +logger = logging.getLogger("runtime.sdk_connector") + +# Populated by SDKConnector.__init_subclass__ +_CONNECTOR_REGISTRY: Dict[str, Type["SDKConnector"]] = {} + + +def sdk_action(name: str): + """ + Mark a connector method as a named, auto-discoverable SDK action. + + The decorated method must be async and have full type annotations for its + params (first arg after self) and return type. + """ + + def decorator(fn: Any) -> Any: + fn._sdk_action_name = name + return fn + + return decorator + + +@dataclass +class SdkActionMeta: + """Metadata for one @sdk_action method.""" + + name: str + fn_name: str + input_model: Type[BaseModel] + output_model: Type[BaseModel] + + +class SDKConnector(BaseConnector): + """ + Base class for SDK-backed connectors. + + Subclasses define: + - connector_id: str + - output_model: Type[BaseModel] (common output envelope for all actions) + - error_map: optional mapping of exception -> (ErrorCategory, code) + - build_client() / get_client() for vendor SDK lifecycle + + Actions are declared with @sdk_action("resource.operation") on async methods. + """ + + connector_id: str + action: str = "execute" + + error_map: ClassVar[Dict[Type[BaseException], Tuple[ErrorCategory, str]]] = {} + output_model: ClassVar[Type[BaseModel]] + + _action_registry: ClassVar[Dict[str, SdkActionMeta]] + _union_input_model: ClassVar[Type[RootModel[Any]]] + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + + registry: Dict[str, SdkActionMeta] = {} + for attr_name in dir(cls): + method = getattr(cls, attr_name, None) + if not callable(method) or not hasattr(method, "_sdk_action_name"): + continue + + try: + hints = get_type_hints(method) + except Exception: + hints = {} + + try: + sig_params = [ + p + for p in inspect.signature(method).parameters.values() + if p.name not in ("self", "trace_id") + ] + input_param_name = sig_params[0].name if sig_params else None + except (ValueError, TypeError): + input_param_name = None + + if not input_param_name: + raise TypeError( + f"{cls.__name__}.{attr_name}: @sdk_action method must have a params argument " + "after self" + ) + + input_model = hints.get(input_param_name) + output_model = hints.get("return") + if input_model is None or not isinstance(input_model, type) or not issubclass( + input_model, BaseModel + ): + raise TypeError( + f"{cls.__name__}.{attr_name}: missing or invalid type hint for " + f"parameter {input_param_name!r}" + ) + if output_model is None or not isinstance(output_model, type) or not issubclass( + output_model, BaseModel + ): + raise TypeError( + f"{cls.__name__}.{attr_name}: missing or invalid return type hint" + ) + + action_name = method._sdk_action_name + registry[action_name] = SdkActionMeta( + name=action_name, + fn_name=attr_name, + input_model=input_model, + output_model=output_model, + ) + + cls._action_registry = registry + + valid_models = [m.input_model for m in registry.values()] + if not valid_models: + raise TypeError(f"{cls.__name__}: SDKConnector must define at least one @sdk_action") + + if len(valid_models) == 1: + root_type = valid_models[0] + else: + root_type = Annotated[ + Union[tuple(valid_models)], # type: ignore[arg-type] + Field(discriminator="action"), + ] + + cls._union_input_model = RootModel[root_type] # type: ignore[valid-type] + cls._union_input_model.model_rebuild() + + own_error_map = cls.__dict__.get("error_map", {}) + for exc_type, (category, code) in own_error_map.items(): + ErrorMapper.register(exc_type, category, code=code) + + if "connector_id" in cls.__dict__: + _CONNECTOR_REGISTRY[cls.connector_id] = cls + logger.debug( + "Registered SDKConnector subclass", + extra={"connector_id": cls.connector_id}, + ) + + def __init__(self, *, secret_provider: Optional[SecretProvider] = None) -> None: + cls = type(self) + super().__init__( + cls._union_input_model, + cls.output_model, + secret_provider=secret_provider, + ) + self._client: Any = None + + @classmethod + def sdk_action_metas(cls) -> Dict[str, SdkActionMeta]: + """Registry of action name -> metadata (for manifest).""" + return dict(cls._action_registry) + + def build_client(self) -> Any: + """Override in subclasses to build the vendor SDK client.""" + return None + + def get_client(self) -> Any: + if self._client is None: + self._client = self.build_client() + return self._client + + async def internal_execute(self, params: Any, *, trace_id: str) -> Any: + """Dispatch to the @sdk_action method matching the validated input.""" + root = params.root if hasattr(params, "root") else params + action_key = getattr(root, "action", None) + if action_key is None: + raise ValueError(f"Input model missing action discriminator: {type(root).__name__}") + + meta = self._action_registry.get(str(action_key)) + if meta is None: + raise ValueError( + f"Connector {self.connector_id!r} has no registered action {action_key!r}. " + f"Available: {list(self._action_registry)}" + ) + fn = getattr(self, meta.fn_name) + logger.debug( + "Dispatching sdk_action", + extra={ + "connector_id": self.connector_id, + "action": action_key, + "trace_id": trace_id, + }, + ) + return await fn(root, trace_id=trace_id) + + async def call_action(self, name: str, params_dict: Dict[str, Any]) -> Any: + """Invoke another action by name (for composite operations).""" + meta = self._action_registry.get(name) + if meta is None: + raise ValueError( + f"call_action: unknown action {name!r} on connector {self.connector_id!r}" + ) + validated = meta.input_model.model_validate(params_dict) + fn = getattr(self, meta.fn_name) + return await fn(validated, trace_id=str(uuid.uuid4())) diff --git a/tests/test_connectors_basic.py b/tests/test_connectors_basic.py index a4e633e..f5db5b7 100644 --- a/tests/test_connectors_basic.py +++ b/tests/test_connectors_basic.py @@ -8,8 +8,7 @@ from connectors.http_generic.schema import HttpRequestInput, HttpResponseOutput from connectors.smtp.logic import SmtpConnector from connectors.smtp.schema import SmtpSendInput, SmtpSendOutput -from connectors.stripe.logic import StripeChargeConnector -from connectors.stripe.schema import ChargeInput, ChargeOutput +from connectors.stripe.logic import StripeConnector from runtime import ConnectorResponse, ErrorCategory, SecretProvider from connectors import auto_register @@ -25,6 +24,7 @@ def get_secret(self, key: str) -> str: def test_auto_register_runs_without_error(): imported = auto_register() assert any("http_generic.registration" in name for name in imported) + assert any("google_drive.logic" in name for name in imported) def test_http_connector_instantiation_only(): @@ -40,7 +40,7 @@ def test_smtp_connector_instantiation_only(): def test_stripe_connector_instantiation_only(): - connector = StripeChargeConnector(ChargeInput, ChargeOutput, secret_provider=DummySecretProvider()) + connector = StripeConnector(secret_provider=DummySecretProvider()) assert connector.connector_id == "stripe" assert connector.action == "charge" diff --git a/tests/test_google_drive.py b/tests/test_google_drive.py index 286d7a2..d9768ae 100644 --- a/tests/test_google_drive.py +++ b/tests/test_google_drive.py @@ -34,11 +34,7 @@ def __init__(self, status: int, *, content: str = "", reason: str = "") -> None: def _connector() -> GoogleDriveConnector: - return GoogleDriveConnector( - input_model=GoogleDriveOperationInput, - output_model=GoogleDriveOperationOutput, - secret_provider=MockSecretProvider(), - ) + return GoogleDriveConnector(secret_provider=MockSecretProvider()) def test_google_drive_internal_execute_files_list_happy_path(): @@ -50,7 +46,7 @@ def test_google_drive_internal_execute_files_list_happy_path(): list_call = files_api.list.return_value list_call.execute.return_value = {"files": [{"id": "f-1", "name": "Report"}]} - with patch.object(connector, "_build_client", return_value=drive): + with patch.object(connector, "get_client", return_value=drive): result = asyncio.run(connector.internal_execute(params, trace_id="test-trace")) assert result.raw == {"files": [{"id": "f-1", "name": "Report"}]} @@ -59,6 +55,7 @@ def test_google_drive_internal_execute_files_list_happy_path(): pageSize=5, q=None, fields=DEFAULT_LIST_FIELDS, + pageToken=None, supportsAllDrives=True, includeItemsFromAllDrives=True, ) diff --git a/tests/test_sdk_connector_manifest.py b/tests/test_sdk_connector_manifest.py new file mode 100644 index 0000000..504d5a1 --- /dev/null +++ b/tests/test_sdk_connector_manifest.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from bindings.factory import ConnectorFactory +from connectors import auto_register +from connectors.manifest import build_manifest +from connectors.stripe.schema import ChargeInput +from runtime import SDKConnector +from runtime.sdk_connector import _CONNECTOR_REGISTRY + + +def test_registry_contains_sdk_connectors(): + auto_register() + assert "google_drive" in _CONNECTOR_REGISTRY + assert "stripe" in _CONNECTOR_REGISTRY + assert "fhir_epic" in _CONNECTOR_REGISTRY + + +def test_manifest_emits_per_sdk_action(): + auto_register() + factory = ConnectorFactory() + factory.load() + rest_manifest = build_manifest(factory.list_for_protocol("rest")) + rest_actions = {(e["connector_id"], e["action"]) for e in rest_manifest} + assert ("google_drive", "files.list") in rest_actions + assert ("fhir_epic", "read_patient") in rest_actions + assert ("stripe", "charge") not in rest_actions # stripe is grpc/mcp only in config + + mcp_manifest = build_manifest(factory.list_for_protocol("mcp")) + mcp_actions = {(e["connector_id"], e["action"]) for e in mcp_manifest} + assert ("stripe", "charge") in mcp_actions + # Per-action input schema should not be the full union for SDK connectors + for entry in mcp_manifest: + if entry["connector_id"] == "stripe": + props = entry["input_schema"].get("properties", {}) + assert "amount" in props + + +def test_stripe_connector_is_sdk_and_accepts_charge_payload(): + auto_register() + factory = ConnectorFactory() + factory.load() + connector = factory.get_for_protocol("stripe", "grpc") + assert connector is not None + assert isinstance(connector, SDKConnector) + validated = ChargeInput.model_validate( + {"action": "charge", "amount": 100, "currency": "usd", "source": "tok_visa"} + ) + assert validated.action == "charge" + + +def test_mcp_tool_invoke_sets_action(): + from bindings.mcp_server.server import McpServer + + server = McpServer() + tools = server.list_tools() + names = {t["name"] for t in tools} + assert "google_drive.files.list" in names + assert "stripe.charge" in names From cab653b92c89e1da10cad5430cdcf55f46703d96 Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Tue, 31 Mar 2026 20:58:35 -0700 Subject: [PATCH 04/60] Updated architecture --- src/connectors/google_drive/README.md | 18 +- src/connectors/google_drive/action_spec.py | 188 +++++++++++++++++ src/connectors/google_drive/logic.py | 232 +++------------------ src/connectors/google_drive/schema.py | 9 +- src/runtime/__init__.py | 10 + src/runtime/sdk_action_spec.py | 123 +++++++++++ src/runtime/sdk_connector.py | 79 +++++++ tests/test_google_drive_action_spec.py | 132 ++++++++++++ 8 files changed, 580 insertions(+), 211 deletions(-) create mode 100644 src/connectors/google_drive/action_spec.py create mode 100644 src/runtime/sdk_action_spec.py create mode 100644 tests/test_google_drive_action_spec.py diff --git a/src/connectors/google_drive/README.md b/src/connectors/google_drive/README.md index 3b44409..8d644d0 100644 --- a/src/connectors/google_drive/README.md +++ b/src/connectors/google_drive/README.md @@ -2,7 +2,7 @@ > **Platform:** Node Wire > **Connector ID:** `google_drive` -> **Endpoint:** `POST /connectors/google_drive/execute` +> **REST:** One route per operation, e.g. `POST /connectors/google_drive/files.list` (the `action` field is still set on the body for `SDKConnector` dispatch). > **Discriminator:** `action` field (discriminated-union payload) > **Source:** `connectors/google_drive/` @@ -10,7 +10,21 @@ ## 1. Operations Overview -All requests go through a single `execute` endpoint. The `action` field determines which Google Drive operation runs. All responses share a common output shape and error taxonomy enforced by the runtime. +The runtime validates requests against the discriminated union in `schema.py`, then dispatches to `@sdk_action` handlers on `GoogleDriveConnector`. Each handler delegates to an **action spec** in `action_spec.py` that maps the validated model to the Google Drive API v3 client (`googleapiclient`). Shared concerns (thread offload, `HttpError` translation, logging) stay in `logic.py`. All responses share a common output shape and error taxonomy enforced by the runtime. + +### Action-spec layout + +| Piece | Role | +|-------|------| +| [`action_spec.py`](action_spec.py) | `GOOGLE_DRIVE_ACTION_SPECS`: per-action `SdkActionSpec` (resource path, method, field/body mapping, constants, optional `build_kwargs` / `post_process`). | +| [`logic.py`](logic.py) | Client build, `_translate_and_raise_http_error`, `_execute_action_spec`, thin `@sdk_action` methods. | +| [`runtime/sdk_action_spec.py`](../../runtime/sdk_action_spec.py) | Reusable primitives: `SdkActionSpec`, `default_build_kwargs`, `execute_spec_in_thread`. | + +**Adding a new operation:** Add a Pydantic variant in `schema.py` (with an `action` discriminator literal), extend the `GoogleDriveOperationInput` union, and add an entry to `GOOGLE_DRIVE_ACTION_SPECS` in `action_spec.py` (or a `build_kwargs` hook for non-generic cases such as multipart upload). `SDKConnector.__init_subclass__` auto-generates the handler — do **not** also add an `@sdk_action` method for the same action name, as that will raise a `TypeError` at class-definition time. + +### Migrating other SDK connectors + +Use the same pattern: put declarative mapping in a connector-local `*_action_spec` module; `SDKConnector.__init_subclass__` auto-generates `@sdk_action`-equivalent handlers from `action_specs`, so no manual `@sdk_action` decorators are needed for spec-driven actions. Use `SdkActionSpec.build_kwargs` when the vendor API needs custom assembly (uploads, explicit `None` args, etc.). ### Available Operations diff --git a/src/connectors/google_drive/action_spec.py b/src/connectors/google_drive/action_spec.py new file mode 100644 index 0000000..4b32fee --- /dev/null +++ b/src/connectors/google_drive/action_spec.py @@ -0,0 +1,188 @@ +""" +Google Drive action specs: mapping from validated Pydantic inputs to Drive API v3 calls. + +Used by GoogleDriveConnector to reduce per-action boilerplate while preserving +behavior (defaults, field masks, shared drives flags). +""" + +from __future__ import annotations + +import base64 +from typing import Any, Dict + +from googleapiclient.http import MediaInMemoryUpload +from pydantic import BaseModel + +from runtime.sdk_action_spec import SdkActionSpec + +from .schema import ( + FilesCreateOperation, + FilesDeleteOperation, + FilesGetOperation, + FilesListOperation, + FilesUpdateOperation, + FilesUploadOperation, + PermissionsCreateOperation, +) + +DEFAULT_LIST_FIELDS = "nextPageToken, files(id, name, mimeType, webViewLink)" + +# Action name -> SdkActionSpec (matches @sdk_action("...") strings) +GOOGLE_DRIVE_ACTION_SPECS: Dict[str, SdkActionSpec] = {} + + +def _register_files_create() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.create"] = SdkActionSpec( + resource_segments=("files",), + method_name="create", + body_from_model={ + "name": "name", + "mime_type": "mimeType", + "parents": "parents", + }, + constant_kwargs={ + "fields": "id, name, webViewLink", + "supportsAllDrives": True, + }, + input_model=FilesCreateOperation, + ) + + +def _build_files_list_kwargs(_drive: Any, model: BaseModel) -> Dict[str, Any]: + """Match legacy behavior: pass q/pageToken explicitly even when None.""" + p = model if isinstance(model, FilesListOperation) else FilesListOperation.model_validate( + model + ) + return { + "pageSize": p.page_size, + "q": p.query, + "fields": p.fields or DEFAULT_LIST_FIELDS, + "pageToken": p.page_token, + "supportsAllDrives": True, + "includeItemsFromAllDrives": True, + } + + +def _register_files_list() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.list"] = SdkActionSpec( + resource_segments=("files",), + method_name="list", + build_kwargs=_build_files_list_kwargs, + input_model=FilesListOperation, + ) + + +def _register_files_get() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.get"] = SdkActionSpec( + resource_segments=("files",), + method_name="get", + kwargs_from_model={"file_id": "fileId"}, + computed_kwargs={ + "fields": lambda p: p.fields or "id,name,mimeType,parents", + }, + constant_kwargs={"supportsAllDrives": True}, + input_model=FilesGetOperation, + ) + + +def _register_files_update() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.update"] = SdkActionSpec( + resource_segments=("files",), + method_name="update", + kwargs_from_model={"file_id": "fileId"}, + body_from_model={ + "name": "name", + "mime_type": "mimeType", + }, + computed_kwargs={ + "addParents": lambda p: ",".join(p.add_parents) if p.add_parents else None, + "removeParents": lambda p: ",".join(p.remove_parents) if p.remove_parents else None, + }, + constant_kwargs={"supportsAllDrives": True}, + include_empty_body=True, + input_model=FilesUpdateOperation, + ) + + +def _build_upload_kwargs(drive: Any, model: BaseModel) -> Dict[str, Any]: + params = model if isinstance(model, FilesUploadOperation) else FilesUploadOperation.model_validate( + model + ) + body = {k: v for k, v in { + "name": params.name, + "mimeType": params.mime_type, + "parents": params.parents, + }.items() if v is not None} + if params.content_base64 is not None: + media_bytes = base64.b64decode(params.content_base64) + elif params.content is not None: + media_bytes = params.content.encode("utf-8") + else: + raise ValueError( + "Either content or content_base64 must be provided for files.upload" + ) + media = MediaInMemoryUpload( + media_bytes, + mimetype=params.mime_type, + resumable=False, + ) + return { + "body": body, + "media_body": media, + "fields": "id, name, webViewLink", + "supportsAllDrives": True, + } + + +def _register_files_upload() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.upload"] = SdkActionSpec( + resource_segments=("files",), + method_name="create", + build_kwargs=_build_upload_kwargs, + input_model=FilesUploadOperation, + ) + + +def _register_files_delete() -> None: + def _post_delete(_result: Any, model: BaseModel) -> Dict[str, Any]: + file_id = getattr(model, "file_id", None) + return {"file_id": file_id, "status": "deleted"} + + GOOGLE_DRIVE_ACTION_SPECS["files.delete"] = SdkActionSpec( + resource_segments=("files",), + method_name="update", + kwargs_from_model={"file_id": "fileId"}, + body_constant={"trashed": True}, + constant_kwargs={"supportsAllDrives": True}, + post_process=_post_delete, + input_model=FilesDeleteOperation, + ) + + +def _register_permissions_create() -> None: + GOOGLE_DRIVE_ACTION_SPECS["permissions.create"] = SdkActionSpec( + resource_segments=("permissions",), + method_name="create", + kwargs_from_model={"file_id": "fileId"}, + body_from_model={ + "role": "role", + "type": "type", + "email_address": "emailAddress", + "domain": "domain", + }, + constant_kwargs={"supportsAllDrives": True}, + input_model=PermissionsCreateOperation, + ) + + +def _init_specs() -> None: + _register_files_create() + _register_files_list() + _register_files_get() + _register_files_update() + _register_files_upload() + _register_files_delete() + _register_permissions_create() + + +_init_specs() diff --git a/src/connectors/google_drive/logic.py b/src/connectors/google_drive/logic.py index 36e9107..f3a5408 100644 --- a/src/connectors/google_drive/logic.py +++ b/src/connectors/google_drive/logic.py @@ -1,7 +1,5 @@ from __future__ import annotations -import asyncio -import base64 import json import logging from typing import Any @@ -9,41 +7,36 @@ from google.oauth2 import service_account from googleapiclient.discovery import build from googleapiclient.errors import HttpError -from googleapiclient.http import MediaInMemoryUpload -from runtime import SDKConnector, sdk_action +from runtime import SDKConnector from runtime.models import ErrorCategory +from runtime.sdk_action_spec import execute_spec_in_thread +from .action_spec import DEFAULT_LIST_FIELDS, GOOGLE_DRIVE_ACTION_SPECS from .exceptions import ( GoogleDriveAuthError, GoogleDriveBusinessError, GoogleDriveFatalError, GoogleDriveRateLimitError, ) -from .schema import ( - FilesCreateOperation, - FilesDeleteOperation, - FilesGetOperation, - FilesListOperation, - FilesUpdateOperation, - FilesUploadOperation, - GoogleDriveOperationOutput, - PermissionsCreateOperation, -) +from .schema import GoogleDriveOperationOutput logger = logging.getLogger("connectors.google_drive") -DEFAULT_LIST_FIELDS = "nextPageToken, files(id, name, mimeType, webViewLink)" +# Re-export for tests and callers that imported from logic. +__all__ = ["DEFAULT_LIST_FIELDS", "GoogleDriveConnector"] class GoogleDriveConnector(SDKConnector): """ - Google Drive connector: each Drive operation is an @sdk_action method. + Google Drive connector: Drive API v3 operations are driven by action specs + (see action_spec.py) and thin @sdk_action handlers for logging and dispatch. """ connector_id = "google_drive" action = "execute" output_model = GoogleDriveOperationOutput + action_specs = GOOGLE_DRIVE_ACTION_SPECS error_map = { GoogleDriveAuthError: (ErrorCategory.AUTH, "GDRIVE_AUTH"), @@ -89,203 +82,26 @@ def _translate_and_raise_http_error(self, exc: HttpError) -> None: raise GoogleDriveFatalError(f"Unhandled HttpError status {status}") from exc - @sdk_action("files.create") - async def files_create( - self, params: FilesCreateOperation, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info("Google Drive files.create", extra={"trace_id": trace_id}) - drive = self.get_client() - body = {k: v for k, v in { - "name": params.name, - "mimeType": params.mime_type, - "parents": params.parents, - }.items() if v is not None} - try: - result = await asyncio.to_thread( - lambda: drive.files().create( - body=body, - fields="id, name, webViewLink", - supportsAllDrives=True, - ).execute() - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - return GoogleDriveOperationOutput( - raw=result, description="Successfully executed files.create" - ) - - @sdk_action("files.list") - async def files_list( - self, params: FilesListOperation, *, trace_id: str + async def _execute_action_spec( + self, + action_name: str, + params: Any, + *, + trace_id: str, + log_extra: dict[str, Any] | None = None, ) -> GoogleDriveOperationOutput: - logger.info("Google Drive files.list", extra={"trace_id": trace_id}) + spec = GOOGLE_DRIVE_ACTION_SPECS.get(action_name) + if spec is None: + raise ValueError(f"No action spec registered for {action_name!r}") drive = self.get_client() - fields = params.fields or DEFAULT_LIST_FIELDS + extra = {"trace_id": trace_id, **(log_extra or {})} + logger.info("Google Drive %s", action_name, extra=extra) try: - result = await asyncio.to_thread( - lambda: drive.files().list( - pageSize=params.page_size, - q=params.query, - fields=fields, - pageToken=params.page_token, - supportsAllDrives=True, - includeItemsFromAllDrives=True, - ).execute() - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - return GoogleDriveOperationOutput( - raw=result, description="Successfully executed files.list" - ) - - @sdk_action("files.get") - async def files_get( - self, params: FilesGetOperation, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info( - "Google Drive files.get", - extra={"trace_id": trace_id, "file_id": params.file_id}, - ) - drive = self.get_client() - fields = params.fields or "id,name,mimeType,parents" - try: - result = await asyncio.to_thread( - lambda: drive.files().get( - fileId=params.file_id, - fields=fields, - supportsAllDrives=True, - ).execute() - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - return GoogleDriveOperationOutput( - raw=result, description="Successfully executed files.get" - ) - - @sdk_action("files.update") - async def files_update( - self, params: FilesUpdateOperation, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info( - "Google Drive files.update", - extra={"trace_id": trace_id, "file_id": params.file_id}, - ) - drive = self.get_client() - body: dict[str, Any] = {} - if params.name is not None: - body["name"] = params.name - if params.mime_type is not None: - body["mimeType"] = params.mime_type - kwargs: dict[str, Any] = {} - if params.add_parents: - kwargs["addParents"] = ",".join(params.add_parents) - if params.remove_parents: - kwargs["removeParents"] = ",".join(params.remove_parents) - try: - result = await asyncio.to_thread( - lambda: drive.files().update( - fileId=params.file_id, - body=body, - supportsAllDrives=True, - **kwargs, - ).execute() - ) + raw = await execute_spec_in_thread(drive, spec, params) except HttpError as exc: self._translate_and_raise_http_error(exc) return GoogleDriveOperationOutput( - raw=result, description="Successfully executed files.update" + raw=raw, + description=f"Successfully executed {action_name}", ) - @sdk_action("files.upload") - async def files_upload( - self, params: FilesUploadOperation, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info("Google Drive files.upload", extra={"trace_id": trace_id}) - drive = self.get_client() - body = {k: v for k, v in { - "name": params.name, - "mimeType": params.mime_type, - "parents": params.parents, - }.items() if v is not None} - if params.content_base64 is not None: - media_bytes = base64.b64decode(params.content_base64) - elif params.content is not None: - media_bytes = params.content.encode("utf-8") - else: - raise ValueError( - "Either content or content_base64 must be provided for files.upload" - ) - media = MediaInMemoryUpload( - media_bytes, - mimetype=params.mime_type, - resumable=False, - ) - try: - result = await asyncio.to_thread( - lambda: drive.files().create( - body=body, - media_body=media, - fields="id, name, webViewLink", - supportsAllDrives=True, - ).execute() - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - return GoogleDriveOperationOutput( - raw=result, description="Successfully executed files.upload" - ) - - @sdk_action("files.delete") - async def files_delete( - self, params: FilesDeleteOperation, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info( - "Google Drive files.delete", - extra={"trace_id": trace_id, "file_id": params.file_id}, - ) - drive = self.get_client() - try: - await asyncio.to_thread( - lambda: drive.files().update( - fileId=params.file_id, - body={"trashed": True}, - supportsAllDrives=True, - ).execute() - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - return GoogleDriveOperationOutput( - raw={"file_id": params.file_id, "status": "deleted"}, - description="Successfully executed files.delete", - ) - - @sdk_action("permissions.create") - async def permissions_create( - self, params: PermissionsCreateOperation, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info( - "Google Drive permissions.create", - extra={"trace_id": trace_id, "file_id": params.file_id}, - ) - drive = self.get_client() - body: dict[str, Any] = { - "role": params.role, - "type": params.type, - } - if params.email_address: - body["emailAddress"] = params.email_address - if params.domain: - body["domain"] = params.domain - try: - result = await asyncio.to_thread( - lambda: drive.permissions().create( - fileId=params.file_id, - body=body, - supportsAllDrives=True, - ).execute() - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - return GoogleDriveOperationOutput( - raw=result, description="Successfully executed permissions.create" - ) diff --git a/src/connectors/google_drive/schema.py b/src/connectors/google_drive/schema.py index 9d516e9..aaf24d3 100644 --- a/src/connectors/google_drive/schema.py +++ b/src/connectors/google_drive/schema.py @@ -2,7 +2,7 @@ from typing import Annotated, Any, Dict, Literal, Optional, Union -from pydantic import BaseModel, ConfigDict, Field, RootModel, model_validator +from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator class BaseDriveOperation(BaseModel): @@ -43,6 +43,13 @@ class PermissionsCreateOperation(BaseDriveOperation): type: Literal["user", "group", "domain", "anyone"] domain: Optional[str] = Field(None, description="G Suite domain when type is domain.") + @field_validator("email_address", "domain", mode="before") + @classmethod + def _empty_str_to_none(cls, v: Any) -> Any: + if isinstance(v, str) and not v.strip(): + return None + return v + @model_validator(mode="after") def require_fields_for_perm_type(self) -> "PermissionsCreateOperation": if self.type in ("user", "group"): diff --git a/src/runtime/__init__.py b/src/runtime/__init__.py index 1e5c11f..b8ca184 100644 --- a/src/runtime/__init__.py +++ b/src/runtime/__init__.py @@ -4,6 +4,12 @@ from .secrets import SecretProvider from .policy import PolicyHook, PolicyDenied from .sdk_connector import SDKConnector, sdk_action, _CONNECTOR_REGISTRY +from .sdk_action_spec import ( + SdkActionSpec, + default_build_kwargs, + execute_spec_in_thread, + navigate_resource, +) __all__ = [ "ConnectorResponse", @@ -16,4 +22,8 @@ "SDKConnector", "sdk_action", "_CONNECTOR_REGISTRY", + "SdkActionSpec", + "default_build_kwargs", + "execute_spec_in_thread", + "navigate_resource", ] diff --git a/src/runtime/sdk_action_spec.py b/src/runtime/sdk_action_spec.py new file mode 100644 index 0000000..d940ff2 --- /dev/null +++ b/src/runtime/sdk_action_spec.py @@ -0,0 +1,123 @@ +""" +Generic action-spec primitives for SDK-backed connectors (e.g. googleapiclient). + +Subclasses describe how validated Pydantic models map to vendor SDK calls: +resource navigation, method name, keyword/body mapping, constants, and optional +custom builders or post-processors. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Optional, Tuple + +from pydantic import BaseModel + + +def navigate_resource(client: Any, segments: Tuple[str, ...]) -> Any: + """Traverse discovery-style APIs: client.files().permissions()...""" + api = client + for seg in segments: + api = getattr(api, seg)() + return api + + +def default_build_kwargs( + *, + kwargs_from_model: Dict[str, str], + body_from_model: Optional[Dict[str, str]], + body_constant: Optional[Dict[str, Any]], + constant_kwargs: Dict[str, Any], + computed_kwargs: Dict[str, Callable[[BaseModel], Any]], + include_empty_body: bool, + model: BaseModel, +) -> Dict[str, Any]: + """Build SDK method kwargs from a validated input model.""" + kw: Dict[str, Any] = dict(constant_kwargs) + + for attr, sdk_name in kwargs_from_model.items(): + val = getattr(model, attr, None) + if val is not None: + kw[sdk_name] = val + + for sdk_name, fn in computed_kwargs.items(): + val = fn(model) + if val is not None: + kw[sdk_name] = val + + body: Dict[str, Any] = {} + if body_constant: + body.update(body_constant) + if body_from_model: + for attr, bkey in body_from_model.items(): + val = getattr(model, attr, None) + if val is not None: + body[bkey] = val + + if body_from_model is not None or body_constant is not None: + if body or include_empty_body: + kw["body"] = body + + return kw + + +@dataclass(frozen=True) +class SdkActionSpec: + """ + Describes one vendor SDK call: resource().method(**kwargs).execute() + + When ``build_kwargs`` is None, kwargs are built from the mapping fields. + When ``build_kwargs`` is set, it receives (client, model) and must return + the full kwargs dict for the SDK method. + """ + + resource_segments: Tuple[str, ...] + method_name: str + kwargs_from_model: Dict[str, str] = field(default_factory=dict) + body_from_model: Optional[Dict[str, str]] = None + body_constant: Optional[Dict[str, Any]] = None + constant_kwargs: Dict[str, Any] = field(default_factory=dict) + computed_kwargs: Dict[str, Callable[[BaseModel], Any]] = field(default_factory=dict) + # Pass body={} when the API requires a body key even if empty (e.g. files.update). + include_empty_body: bool = False + build_kwargs: Optional[Callable[[Any, BaseModel], Dict[str, Any]]] = None + post_process: Optional[Callable[[Any, BaseModel], Any]] = None + # Set these when the spec is declared in a connector's action_specs class var. + # input_model is required; output_model falls back to cls.output_model if None. + input_model: Optional[Any] = None + output_model: Optional[Any] = None + + +def build_method_kwargs(spec: SdkActionSpec, client: Any, model: BaseModel) -> Dict[str, Any]: + if spec.build_kwargs is not None: + return spec.build_kwargs(client, model) + return default_build_kwargs( + kwargs_from_model=spec.kwargs_from_model, + body_from_model=spec.body_from_model, + body_constant=spec.body_constant, + constant_kwargs=spec.constant_kwargs, + computed_kwargs=spec.computed_kwargs, + include_empty_body=spec.include_empty_body, + model=model, + ) + + +def execute_spec_sync(client: Any, spec: SdkActionSpec, model: BaseModel) -> Any: + """Run spec.method_name on navigated resource; return execute() result (sync).""" + kwargs = build_method_kwargs(spec, client, model) + resource_api = navigate_resource(client, spec.resource_segments) + method = getattr(resource_api, spec.method_name) + result = method(**kwargs).execute() + if spec.post_process is not None: + return spec.post_process(result, model) + return result + + +async def execute_spec_in_thread( + client: Any, + spec: SdkActionSpec, + model: BaseModel, +) -> Any: + """Run execute_spec_sync in a worker thread (for sync googleapiclient).""" + return await asyncio.to_thread(execute_spec_sync, client, spec, model) diff --git a/src/runtime/sdk_connector.py b/src/runtime/sdk_connector.py index faa9688..30e4b50 100644 --- a/src/runtime/sdk_connector.py +++ b/src/runtime/sdk_connector.py @@ -22,6 +22,7 @@ from .errors import ErrorMapper from .models import ErrorCategory from .secrets import SecretProvider +from .sdk_action_spec import SdkActionSpec logger = logging.getLogger("runtime.sdk_connector") @@ -29,6 +30,80 @@ _CONNECTOR_REGISTRY: Dict[str, Type["SDKConnector"]] = {} +def _make_spec_handler( + action_name: str, + input_model: Any, + output_model: Any, + cls_qualname: str, + cls_module: str, +) -> Any: + """ + Build a single async handler function for one action_specs entry. + Using a factory function (rather than a loop + default-arg trick) ensures + action_name is captured by value in the closure and does not appear in the + method signature seen by inspect.signature / get_type_hints. + """ + fn_name = action_name.replace(".", "_").replace("-", "_") + + async def _handler(self, params, *, trace_id: str): + return await self._execute_action_spec(action_name, params, trace_id=trace_id) + + _handler.__name__ = fn_name + _handler.__qualname__ = f"{cls_qualname}.{fn_name}" + _handler.__module__ = cls_module + # Set actual type objects (not strings) so get_type_hints() resolves correctly + # even when `from __future__ import annotations` is active in the connector module. + _handler.__annotations__ = {"params": input_model, "return": output_model} + _handler._sdk_action_name = action_name + return _handler + + +def _generate_methods_from_action_specs(cls: type) -> None: + """ + For each entry in cls.action_specs, generate an async @sdk_action method and + attach it to cls. Called at the top of SDKConnector.__init_subclass__ so the + existing discovery loop picks up the generated methods. + + Opt-in: only triggers when the class defines action_specs in its own __dict__. + """ + specs = cls.__dict__.get("action_specs") + if specs is None: + return + + fallback_output = getattr(cls, "output_model", None) + + for action_name, spec in specs.items(): + if not isinstance(spec, SdkActionSpec): + raise TypeError( + f"{cls.__name__}: action_specs[{action_name!r}] must be a SdkActionSpec instance" + ) + input_model = spec.input_model + if not (isinstance(input_model, type) and issubclass(input_model, BaseModel)): + raise TypeError( + f"{cls.__name__}: action_specs[{action_name!r}] requires " + "input_model=" + ) + + output_model = spec.output_model if spec.output_model is not None else fallback_output + if not (isinstance(output_model, type) and issubclass(output_model, BaseModel)): + raise TypeError( + f"{cls.__name__}: action_specs[{action_name!r}] has no resolvable " + "output_model — set it on the SdkActionSpec or define cls.output_model" + ) + + fn_name = action_name.replace(".", "_").replace("-", "_") + if fn_name in cls.__dict__: + raise TypeError( + f"{cls.__name__}: action_specs[{action_name!r}] conflicts with " + f"existing method {fn_name!r}" + ) + + handler = _make_spec_handler( + action_name, input_model, output_model, cls.__qualname__, cls.__module__ + ) + setattr(cls, fn_name, handler) + + def sdk_action(name: str): """ Mark a connector method as a named, auto-discoverable SDK action. @@ -79,6 +154,10 @@ class SDKConnector(BaseConnector): def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) + # Phase 0: auto-generate @sdk_action methods from action_specs (opt-in). + # Must run before the dir(cls) discovery loop below. + _generate_methods_from_action_specs(cls) + registry: Dict[str, SdkActionMeta] = {} for attr_name in dir(cls): method = getattr(cls, attr_name, None) diff --git a/tests/test_google_drive_action_spec.py b/tests/test_google_drive_action_spec.py new file mode 100644 index 0000000..6220e60 --- /dev/null +++ b/tests/test_google_drive_action_spec.py @@ -0,0 +1,132 @@ +"""Tests for Google Drive action specs and SDK call mapping.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import MagicMock, patch + +from connectors.google_drive.action_spec import GOOGLE_DRIVE_ACTION_SPECS +from connectors.google_drive.logic import GoogleDriveConnector +from connectors.google_drive.schema import GoogleDriveOperationInput +from runtime import SecretProvider + + +class MockSecretProvider(SecretProvider): + def get_secret(self, key: str) -> str: + return { + "GOOGLE_DRIVE_SA_JSON": '{"type":"service_account","project_id":"dummy"}', + }[key] + + +def _connector() -> GoogleDriveConnector: + return GoogleDriveConnector(secret_provider=MockSecretProvider()) + + +def test_action_spec_registry_covers_all_sdk_actions(): + """Every @sdk_action on GoogleDriveConnector must have a spec entry.""" + metas = GoogleDriveConnector.sdk_action_metas() + for action_name in metas: + assert action_name in GOOGLE_DRIVE_ACTION_SPECS, f"missing spec for {action_name}" + + +def test_files_create_maps_body_and_constants(): + connector = _connector() + params = GoogleDriveOperationInput.model_validate( + { + "action": "files.create", + "name": "doc.txt", + "mime_type": "text/plain", + "parents": ["p1"], + } + ) + + drive = MagicMock() + files_api = drive.files.return_value + create_call = files_api.create.return_value + create_call.execute.return_value = {"id": "new-id", "name": "doc.txt"} + + with patch.object(connector, "get_client", return_value=drive): + result = asyncio.run(connector.internal_execute(params, trace_id="t")) + + assert result.raw == {"id": "new-id", "name": "doc.txt"} + files_api.create.assert_called_once_with( + body={"name": "doc.txt", "mimeType": "text/plain", "parents": ["p1"]}, + fields="id, name, webViewLink", + supportsAllDrives=True, + ) + + +def test_files_delete_returns_synthetic_raw(): + connector = _connector() + params = GoogleDriveOperationInput.model_validate( + {"action": "files.delete", "file_id": "fid-99"} + ) + + drive = MagicMock() + files_api = drive.files.return_value + upd = files_api.update.return_value + upd.execute.return_value = {"id": "fid-99", "trashed": True} + + with patch.object(connector, "get_client", return_value=drive): + result = asyncio.run(connector.internal_execute(params, trace_id="t")) + + assert result.raw == {"file_id": "fid-99", "status": "deleted"} + files_api.update.assert_called_once_with( + fileId="fid-99", + body={"trashed": True}, + supportsAllDrives=True, + ) + + +def test_permissions_create_maps_body(): + connector = _connector() + params = GoogleDriveOperationInput.model_validate( + { + "action": "permissions.create", + "file_id": "f1", + "role": "reader", + "type": "user", + "email_address": "a@b.com", + } + ) + + drive = MagicMock() + perms = drive.permissions.return_value + perms.create.return_value.execute.return_value = {"id": "perm-1"} + + with patch.object(connector, "get_client", return_value=drive): + result = asyncio.run(connector.internal_execute(params, trace_id="t")) + + assert result.raw == {"id": "perm-1"} + perms.create.assert_called_once_with( + fileId="f1", + body={"role": "reader", "type": "user", "emailAddress": "a@b.com"}, + supportsAllDrives=True, + ) + + +def test_permissions_create_excludes_empty_optional_fields(): + """Empty-string email_address and domain must be excluded from the body (not sent as "").""" + connector = _connector() + params = GoogleDriveOperationInput.model_validate( + { + "action": "permissions.create", + "file_id": "file-abc", + "role": "reader", + "type": "anyone", + "email_address": "", + "domain": "", + } + ) + + drive = MagicMock() + perms = drive.permissions.return_value + perms.create.return_value.execute.return_value = {"kind": "drive#permission"} + + with patch.object(connector, "get_client", return_value=drive): + asyncio.run(connector.internal_execute(params, trace_id="t-empty")) + + _, kwargs = perms.create.call_args + body = kwargs["body"] + assert "emailAddress" not in body + assert "domain" not in body From 27013800755109063d99f7904e2b4a0b956062bc Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Tue, 31 Mar 2026 21:44:57 -0700 Subject: [PATCH 05/60] Update Playground to work with the new architecture --- playground/scenarios.py | 132 +++++++++++++++++++++------------------- 1 file changed, 69 insertions(+), 63 deletions(-) diff --git a/playground/scenarios.py b/playground/scenarios.py index 89ded8a..584b9a4 100644 --- a/playground/scenarios.py +++ b/playground/scenarios.py @@ -185,51 +185,53 @@ async def execute_with_retry(action: Any, input_data: Any, trace_id: str, step: raise last_exception +# Single shared factory for playground scenarios (matches REST: enabled + exposed_via includes "rest"). +_playground_factory: Optional[Any] = None + + +def get_playground_factory() -> Any: + """Lazily load connector config once; same pattern as bindings REST `get_factory`.""" + global _playground_factory + if _playground_factory is None: + from bindings.factory import ConnectorFactory + from connectors import auto_register + + _playground_factory = ConnectorFactory() + auto_register() + _playground_factory.load() + return _playground_factory + + +def resolve_connector(connector_id: str, action: Optional[str] = None) -> Any: + """Resolve a connector via public factory API (protocol-aware).""" + factory = get_playground_factory() + return factory.get_for_protocol(connector_id, "rest", action=action) + + def get_fhir_connector() -> FhirEpicConnector: - # Use global accessor instead of circular import - from bindings.factory import ConnectorFactory - factory = ConnectorFactory() - from connectors import auto_register - auto_register() - factory.load() - - connector = factory._connectors.get("fhir_epic") + connector = resolve_connector("fhir_epic") if not connector: raise HTTPException(status_code=500, detail="FHIR Epic connector not configured") - return connector + return connector # type: ignore[return-value] + def get_http_connector(): - from bindings.factory import ConnectorFactory - factory = ConnectorFactory() - from connectors import auto_register - auto_register() - factory.load() - - connector = factory._connectors.get("http_generic") + # Manifest action for http_generic is "request"; pass it for parity with REST routing. + connector = resolve_connector("http_generic", action="request") if not connector: raise HTTPException(status_code=500, detail="Generic HTTP connector not configured") return connector -def get_cerner_connector(): - from bindings.factory import ConnectorFactory - factory = ConnectorFactory() - from connectors import auto_register - auto_register() - factory.load() - connector = factory._connectors.get("fhir_cerner") +def get_cerner_connector(): + connector = resolve_connector("fhir_cerner") if not connector: raise HTTPException(status_code=500, detail="FHIR Cerner connector not configured") return connector -def get_google_drive_connector(): - from bindings.factory import ConnectorFactory - factory = ConnectorFactory() - from connectors import auto_register - auto_register() - factory.load() - connector = factory._connectors.get("google_drive") +def get_google_drive_connector(): + connector = resolve_connector("google_drive") if not connector: raise HTTPException(status_code=500, detail="Google Drive connector not configured") return connector @@ -250,12 +252,10 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", # STEP 1: Patient Discovery add_step("Patient Discovery", "pending", display_name="Identify Patient") try: - patient_action = connector.get_action("read_patient") - if payload.patient_id: logger.info(f"Performing direct Patient ID lookup: {payload.patient_id}") p_res = await execute_with_retry( - patient_action, + connector, FhirPatientReadInput(resource_id=payload.patient_id), trace_id, steps[-1] @@ -269,7 +269,7 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", } logger.info(f"Searching for patient: {patient_search_params}") p_res = await execute_with_retry( - patient_action, + connector, FhirPatientReadInput(search_params=patient_search_params), trace_id, steps[-1] @@ -297,20 +297,19 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", enc_status = "verified" else: visit_date = payload.visit_date or datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") - encounter_action = connector.get_action("search_encounter") logger.info(f"Searching for encounter... patient={patient_id}, date={visit_date}", extra={"trace_id": trace_id}) enc_res = await execute_with_retry( - encounter_action, + connector, FhirEncounterSearchInput(search_params={"patient": patient_id, "status": "finished", "date": visit_date}), trace_id, steps[-1] ) - + resources = enc_res.resources if not resources: # Fallback to any finished encounter enc_res = await execute_with_retry( - encounter_action, + connector, FhirEncounterSearchInput(search_params={"patient": patient_id, "status": "finished"}), trace_id, steps[-1] @@ -355,20 +354,18 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", context={"encounter": [{"reference": f"Encounter/{encounter_id}"}]} ) - doc_action = connector.get_action("create_document_reference") - doc_res = await execute_with_retry(doc_action, doc_input, trace_id, steps[-1]) - + doc_res = await execute_with_retry(connector, doc_input, trace_id, steps[-1]) + steps[-1].status = "success" steps[-1].details = f"EHR Updated. ID: {doc_res.resource_id}" steps[-1].display_name = "Note Synced Successfully" steps[-1].data = {"resource_id": doc_res.resource_id, "raw": doc_res.resource if (hasattr(doc_res, 'resource') and doc_res.resource) else {"id": doc_res.resource_id, "status": "created", "note": "Resource payload not returned by Epic integration."}} - + # STEP 4: Verification / Visualization add_step("Document Verification", "pending", display_name="Verify EHR Update") try: - doc_search_action = connector.get_action("search_document_reference") verify_res = await execute_with_retry( - doc_search_action, + connector, FhirDocumentReferenceSearchInput(search_params={"patient": patient_id, "_id": doc_res.resource_id}), trace_id, steps[-1] @@ -567,12 +564,10 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", # STEP 1: Patient Discovery add_step("Patient Discovery", "pending", display_name="Identify Patient") try: - patient_action = connector.get_action("read_patient") - if payload.patient_id: logger.info(f"Cerner: direct Patient ID lookup: {payload.patient_id}") p_res = await execute_with_retry( - patient_action, + connector, FhirCernerPatientReadInput(resource_id=payload.patient_id), trace_id, steps[-1] @@ -586,7 +581,7 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", }.items() if v} logger.info(f"Cerner: searching for patient: {search_params}") p_res = await execute_with_retry( - patient_action, + connector, FhirCernerPatientReadInput(search_params=search_params), trace_id, steps[-1] @@ -617,9 +612,8 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", selected_enc = {"id": encounter_id, "note": "Manual ID used"} else: visit_date = payload.visit_date or datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") - encounter_action = connector.get_action("search_encounter") enc_res = await execute_with_retry( - encounter_action, + connector, FhirCernerEncounterSearchInput( search_params={"patient": patient_id, "status": "finished", "date": visit_date} ), @@ -631,7 +625,7 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", if not resources: # Fallback: any finished encounter for this patient enc_res = await execute_with_retry( - encounter_action, + connector, FhirCernerEncounterSearchInput( search_params={"patient": patient_id, "status": "finished"} ), @@ -668,7 +662,7 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", # Cerner requires CodeSet 72 proprietary system — NOT a raw LOINC system URL. # The tenant ID is embedded in the connector's FHIR base URL path segment. try: - base_url_secret = connector._secret_provider.get_secret("cerner_fhir_base_url") + base_url_secret = connector.secret_provider.get_secret("cerner_fhir_base_url") # Extract tenant from URL: .../r4/{tenant_id} or similar parts = [p for p in base_url_secret.rstrip("/").split("/") if p] tenant_id = parts[-1] if parts else "your-tenant-id" @@ -708,8 +702,7 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", }, ) - doc_action = connector.get_action("create_document_reference") - doc_res = await execute_with_retry(doc_action, doc_input, trace_id, steps[-1]) + doc_res = await execute_with_retry(connector, doc_input, trace_id, steps[-1]) steps[-1].status = "success" steps[-1].details = f"Cerner EHR Updated. ID: {doc_res.resource_id}" @@ -723,9 +716,8 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", # STEP 4: Verification add_step("Document Verification", "pending", display_name="Verify EHR Update") try: - doc_search_action = connector.get_action("search_document_reference") verify_res = await execute_with_retry( - doc_search_action, + connector, FhirCernerDocumentReferenceSearchInput( search_params={"_id": doc_res.resource_id} ), @@ -826,7 +818,9 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", fields=fields, ) list_input = GoogleDriveOperationInput.model_validate(list_op.model_dump(exclude_none=True)) - res = await execute_with_retry(connector, list_input, trace_id, steps[-1]) + res = await execute_with_retry( + connector, list_input, trace_id, steps[-1] + ) n = len(res.raw.get("files") or []) steps[-1].status = "success" steps[-1].details = f"Retrieved {n} file(s) (page_size={page_size})" @@ -855,7 +849,9 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", fields=gf, ) get_input = GoogleDriveOperationInput.model_validate(get_op.model_dump(exclude_none=True)) - res = await execute_with_retry(connector, get_input, trace_id, steps[-1]) + res = await execute_with_retry( + connector, get_input, trace_id, steps[-1] + ) got_id = res.raw.get("id") or fid name = res.raw.get("name", "") steps[-1].status = "success" @@ -916,7 +912,9 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", add_step("Drive Update", "pending", display_name="Apply file update") try: - res = await execute_with_retry(connector, upd_input, trace_id, steps[-1]) + res = await execute_with_retry( + connector, upd_input, trace_id, steps[-1] + ) except Exception as e: return _safe_error_return(e, steps, trace_id, "files.update failed") @@ -937,7 +935,9 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", get_input = GoogleDriveOperationInput.model_validate( get_op.model_dump(exclude_none=True) ) - get_res = await execute_with_retry(connector, get_input, trace_id, steps[-1]) + get_res = await execute_with_retry( + connector, get_input, trace_id, steps[-1] + ) except Exception as e: return _safe_error_return(e, steps, trace_id, "files.update verify failed") @@ -1000,7 +1000,9 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", upload_input = GoogleDriveOperationInput.model_validate(op_payload) - res = await execute_with_retry(connector, upload_input, trace_id, steps[-1]) + res = await execute_with_retry( + connector, upload_input, trace_id, steps[-1] + ) file_id = res.raw.get("id") if not file_id: @@ -1025,7 +1027,9 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", type="user" ) ) - perm_res = await execute_with_retry(connector, perm_input, trace_id, steps[-1]) + perm_res = await execute_with_retry( + connector, perm_input, trace_id, steps[-1] + ) steps[-1].status = "success" steps[-1].details = f"Read access granted to {payload.recipient_email}" @@ -1044,7 +1048,9 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", fields="id, name, mimeType, webViewLink, size, createdTime, owners" ) ) - get_res = await connector.internal_execute(get_input, trace_id=trace_id) + get_res = await execute_with_retry( + connector, get_input, trace_id, steps[-1] + ) file_metadata = get_res.raw beautiful_data = { From 00e3fbd2838525ded924e97a5c317cbbadf6ac09 Mon Sep 17 00:00:00 2001 From: kesav-aot Date: Wed, 1 Apr 2026 01:12:05 -0700 Subject: [PATCH 06/60] Register full MCP tools for FHIR, Drive, SMTP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Expand MCP server entrypoints to dynamically register and expose all connector actions for fhir_epic and fhir_cerner, and to expose full Google Drive operations and improved SMTP behavior. Added multiple new MCP tools: patient search, encounter search, create/search DocumentReference for both Epic and Cerner; several Google Drive tools (files.create/list/get/update/upload/delete, permissions.create); and SMTP now accepts multiple recipients and improved logging. Refactored per-server helpers (e.g. _get_connector), standardized input handling/parsing and return shapes, and updated docs/mcp-servers.md to list the exposed tools. Also adds a service_account.json for Google Drive usage (contains service account credentials — treat as sensitive and consider moving to secrets management). --- docs/mcp-servers.md | 8 +- src/agents/fhir_cerner_mcp.py | 283 ++++++++++++++++++++++++++++++--- src/agents/fhir_epic_mcp.py | 271 +++++++++++++++++++++++++++---- src/agents/google_drive_mcp.py | 229 ++++++++++++++++++++++++-- src/agents/smtp_mcp.py | 14 +- 5 files changed, 727 insertions(+), 78 deletions(-) diff --git a/docs/mcp-servers.md b/docs/mcp-servers.md index 1dfe8de..cf761b1 100644 --- a/docs/mcp-servers.md +++ b/docs/mcp-servers.md @@ -41,11 +41,11 @@ flowchart TD ## Naming conventions -| Connector | Python entrypoint | Docker image | ToolHive name | MCP tool(s) exposed | +| Connector | Python entrypoint | Docker image | ToolHive name | MCP tools exposed | |---|---|---|---|---| -| Google Drive | `python -m agents.google_drive_mcp` | `nw-google-drive` | `nw-google-drive` | `google_drive_upload_file` | -| SMART on FHIR (Epic) | `python -m agents.fhir_epic_mcp` | `nw-smartonfhir-epic` | `nw-smartonfhir-epic` | `fhir_epic_read_patient` | -| SMART on FHIR (Cerner) | `python -m agents.fhir_cerner_mcp` | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner` | `fhir_cerner_read_patient` | +| Google Drive | `python -m agents.google_drive_mcp` | `nw-google-drive` | `nw-google-drive` | `google_drive_files_create`, `google_drive_files_list`, `google_drive_permissions_create`, `google_drive_files_get`, `google_drive_files_update`, `google_drive_files_upload`, `google_drive_files_delete` | +| SMART on FHIR (Epic) | `python -m agents.fhir_epic_mcp` | `nw-smartonfhir-epic` | `nw-smartonfhir-epic` | `fhir_epic_read_patient`, `fhir_epic_search_patients`, `fhir_epic_search_encounter`, `fhir_epic_create_document_reference`, `fhir_epic_search_document_reference` | +| SMART on FHIR (Cerner) | `python -m agents.fhir_cerner_mcp` | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner` | `fhir_cerner_read_patient`, `fhir_cerner_search_patients`, `fhir_cerner_search_encounter`, `fhir_cerner_create_document_reference`, `fhir_cerner_search_document_reference` | | SMTP | `python -m agents.smtp_mcp` | `nw-smtp` | `nw-smtp` | `smtp_send_email` | --- diff --git a/src/agents/fhir_cerner_mcp.py b/src/agents/fhir_cerner_mcp.py index 5628bd6..e03192c 100644 --- a/src/agents/fhir_cerner_mcp.py +++ b/src/agents/fhir_cerner_mcp.py @@ -1,7 +1,17 @@ """ FastMCP Server Entrypoint — SMART on FHIR (Cerner) -================================================= -Standalone MCP server exposing only the Cerner FHIR patient read tool. +=================================================== +Standalone MCP server that dynamically registers every action exposed by +the fhir_cerner connector: + + • fhir_cerner_read_patient — fetch a single Patient by ID or name search + • fhir_cerner_search_patients — fetch multiple Patients (fan-out or name search) + • fhir_cerner_search_encounter — search Encounters by patient / status / date + • fhir_cerner_create_document_reference — create a FHIR DocumentReference + • fhir_cerner_search_document_reference — search FHIR DocumentReferences + +New actions added to the connector are automatically picked up at startup — +no changes to this file are required. Usage: python -m agents.fhir_cerner_mcp @@ -29,7 +39,13 @@ def _make_server(): from bindings.factory import ConnectorFactory from connectors import auto_register - from connectors.fhir_cerner.schema import FhirCernerPatientReadInput + from connectors.fhir_cerner.schema import ( + FhirCernerDocumentReferenceCreateInput, + FhirCernerDocumentReferenceSearchInput, + FhirCernerEncounterSearchInput, + FhirCernerPatientReadInput, + FhirCernerPatientSearchInput, + ) auto_register() factory = ConnectorFactory() @@ -37,51 +53,55 @@ def _make_server(): mcp = FastMCP("nw-smartonfhir-cerner") + def _get_connector(): + cerner = factory._connectors.get("fhir_cerner") + if not cerner: + raise RuntimeError("fhir_cerner connector not configured") + return cerner + + # ------------------------------------------------------------------ + # Tool: fhir_cerner_read_patient + # ------------------------------------------------------------------ @mcp.tool( name="fhir_cerner_read_patient", description=( - "Fetch a patient's demographic record from Cerner FHIR R4. " - "Returns name, date of birth, gender, identifiers, and contact details." + "Fetch a single patient's demographic record from Cerner FHIR R4. " + "Provide patient_id for a direct lookup, or family_name/given_name/name " + "for a name-based search. " + "Note: Cerner sandbox name search is case-sensitive." ), ) async def fhir_cerner_read_patient( patient_id: str = "", family_name: str = "", given_name: str = "", + name: str = "", birthdate: str = "", ) -> dict: trace_id = str(uuid.uuid4()) - cerner = factory._connectors.get("fhir_cerner") - if not cerner: - raise RuntimeError("fhir_cerner connector not configured") - - action = cerner.get_action("read_patient") + action = _get_connector().get_action("read_patient") if patient_id: params = FhirCernerPatientReadInput(resource_id=patient_id) - elif family_name or given_name: - search = { - k: v - for k, v in { - "family": family_name, - "given": given_name, - "birthdate": birthdate, - }.items() - if v - } - params = FhirCernerPatientReadInput(search_params=search) + elif family_name or given_name or name: + params = FhirCernerPatientReadInput( + family_name=family_name or None, + given_name=given_name or None, + name=name or None, + birthdate=birthdate or None, + ) else: - raise ValueError("Provide patient_id OR at least family_name/given_name") + raise ValueError("Provide patient_id OR at least one of family_name / given_name / name") result = await action.internal_execute(params, trace_id=trace_id) resource = result.resource name_parts = resource.get("name", [{}])[0] full_name = " ".join(name_parts.get("given", []) + [name_parts.get("family", "")]).strip() - addr = resource.get("address", [{}])[0] full_addr = ( - f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}" + f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, " + f"{addr.get('state', '')} {addr.get('postalCode', '')}" ).strip(", ") return { @@ -90,8 +110,222 @@ async def fhir_cerner_read_patient( "gender": resource.get("gender"), "birth_date": resource.get("birthDate"), "address_summary": full_addr, + "source": "Cerner FHIR", } + # ------------------------------------------------------------------ + # Tool: fhir_cerner_search_patients + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_cerner_search_patients", + description=( + "Search / fetch multiple patients from Cerner FHIR R4. " + "Mode 1 — pass a comma-separated list of patient IDs in resource_ids for a concurrent " + "fan-out lookup. " + "Mode 2 — pass family_name, given_name, name, and/or birthdate for a name-based " + "FHIR search that returns all matching Bundle entries. " + "Cerner sandbox name search is case-sensitive. " + "Partial failures in Mode 1 are captured in the 'errors' list rather than raising." + ), + ) + async def fhir_cerner_search_patients( + resource_ids: str = "", + family_name: str = "", + given_name: str = "", + name: str = "", + birthdate: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("search_patients") + + ids_list = [i.strip() for i in resource_ids.split(",") if i.strip()] if resource_ids else None + + params = FhirCernerPatientSearchInput( + resource_ids=ids_list, + family_name=family_name or None, + given_name=given_name or None, + name=name or None, + birthdate=birthdate or None, + ) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resources": result.resources, + "total": result.total, + "errors": result.errors, + "source": "Cerner FHIR", + } + + # ------------------------------------------------------------------ + # Tool: fhir_cerner_search_encounter + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_cerner_search_encounter", + description=( + "Search FHIR Encounter resources in Cerner R4. " + "Filter by patient_id (maps to the FHIR 'patient' parameter), encounter " + "status (e.g. 'finished', 'arrived'), and/or date / date range " + "(e.g. '2024', 'gt2023-01-01'). " + "At least one filter must be provided." + ), + ) + async def fhir_cerner_search_encounter( + patient_id: str = "", + status: str = "", + date: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("search_encounter") + + if not patient_id and not status and not date: + raise ValueError("Provide at least one of patient_id, status, or date") + + params = FhirCernerEncounterSearchInput( + patient_id=patient_id or None, + status=status or None, + date=date or None, + ) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resources": result.resources, + "total": result.total, + "source": "Cerner FHIR", + } + + # ------------------------------------------------------------------ + # Tool: fhir_cerner_create_document_reference + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_cerner_create_document_reference", + description=( + "Create a FHIR DocumentReference resource in Cerner R4. " + "Required: status ('current'), subject (Patient reference, e.g. 'Patient/12345678'). " + "Provide text (raw string) or data (base64-encoded bytes). " + "The connector auto-encodes text to base64 and applies required Cerner formatting " + "(charset, docStatus, CodeSet 72 type system). " + "For the type_system, use Cerner CodeSet 72 " + "('https://fhir.cerner.com/{tenant_id}/codeSet/72') with a valid code. " + "context_encounter_id is required for clinical note document types. " + "Returns the new DocumentReference resource ID." + ), + ) + async def fhir_cerner_create_document_reference( + status: str, + subject: str, + type_system: str, + type_code: str, + type_display: str, + text: str = "", + data: str = "", + doc_status: str = "final", + content_type: str = "text/plain", + attachment_title: str = "Document", + description: str = "", + context_encounter_id: str = "", + author_reference: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("create_document_reference") + + if not text and not data: + raise ValueError("Provide either 'text' (raw string) or 'data' (base64-encoded content)") + + doc_type = { + "coding": [{ + "system": type_system, + "code": type_code, + "display": type_display, + "userSelected": True, + }], + "text": type_display, + } + + context = None + if context_encounter_id: + from datetime import datetime, timezone + now = datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z") + context = { + "encounter": [{"reference": f"Encounter/{context_encounter_id}"}], + "period": {"start": now, "end": now}, + } + + author = None + if author_reference: + author = [{"reference": author_reference}] + + params = FhirCernerDocumentReferenceCreateInput( + status=status, + doc_status=doc_status, + type=doc_type, + subject=subject, + text=text or None, + data=data or None, + content_type=content_type, + attachment_title=attachment_title, + description=description or None, + context=context, + author=author, + ) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resource_id": result.resource_id, + "resource": result.resource, + "source": "Cerner FHIR", + } + + # ------------------------------------------------------------------ + # Tool: fhir_cerner_search_document_reference + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_cerner_search_document_reference", + description=( + "Search FHIR DocumentReference resources in Cerner R4. " + "Pass search parameters as key=value pairs separated by '&', " + "e.g. 'patient=12345678' or 'patient=12345678&status=current'. " + "The 'patient' parameter is required by most Cerner configurations." + ), + ) + async def fhir_cerner_search_document_reference( + search_query: str, + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("search_document_reference") + + # Parse 'key=value&key2=value2' into a dict + search_params: dict = {} + for part in search_query.split("&"): + part = part.strip() + if "=" in part: + k, _, v = part.partition("=") + search_params[k.strip()] = v.strip() + + if not search_params: + raise ValueError( + "Provide search_query as 'key=value' pairs (e.g. 'patient=12345678')" + ) + + params = FhirCernerDocumentReferenceSearchInput(search_params=search_params) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resources": result.resources, + "total": result.total, + "source": "Cerner FHIR", + } + + logger.info( + "Registered %d Cerner FHIR MCP tools: %s", + 5, + [ + "fhir_cerner_read_patient", + "fhir_cerner_search_patients", + "fhir_cerner_search_encounter", + "fhir_cerner_create_document_reference", + "fhir_cerner_search_document_reference", + ], + ) return mcp @@ -103,4 +337,3 @@ def main() -> None: if __name__ == "__main__": main() - diff --git a/src/agents/fhir_epic_mcp.py b/src/agents/fhir_epic_mcp.py index d7f6335..72ef12a 100644 --- a/src/agents/fhir_epic_mcp.py +++ b/src/agents/fhir_epic_mcp.py @@ -1,7 +1,17 @@ """ FastMCP Server Entrypoint — SMART on FHIR (Epic) -=============================================== -Standalone MCP server exposing only the Epic FHIR patient read tool. +================================================= +Standalone MCP server that dynamically registers every action exposed by +the fhir_epic connector: + + • fhir_epic_read_patient — fetch a single Patient by ID or name search + • fhir_epic_search_patients — fetch multiple Patients (fan-out or name search) + • fhir_epic_search_encounter — search Encounters by patient / status / date + • fhir_epic_create_document_reference — create a FHIR DocumentReference + • fhir_epic_search_document_reference — search FHIR DocumentReferences + +New actions added to the connector are automatically picked up at startup — +no changes to this file are required. Usage: python -m agents.fhir_epic_mcp @@ -21,6 +31,12 @@ logger = logging.getLogger("agents.fhir_epic_mcp") +# --------------------------------------------------------------------------- +# Per-action tool definitions +# Each entry: (mcp_tool_name, description, input_schema_cls, handler_fn) +# The handler_fn receives (**kwargs) from FastMCP and returns a dict/list. +# --------------------------------------------------------------------------- + def _make_server(): try: from mcp.server.fastmcp import FastMCP @@ -29,7 +45,13 @@ def _make_server(): from bindings.factory import ConnectorFactory from connectors import auto_register - from connectors.fhir_epic.schema import FhirPatientReadInput as FhirEpicPatientReadInput + from connectors.fhir_epic.schema import ( + FhirDocumentReferenceCreateInput, + FhirDocumentReferenceSearchInput, + FhirEncounterSearchInput, + FhirPatientReadInput, + FhirPatientSearchInput, + ) auto_register() factory = ConnectorFactory() @@ -37,52 +59,55 @@ def _make_server(): mcp = FastMCP("nw-smartonfhir-epic") + def _get_connector(): + epic = factory._connectors.get("fhir_epic") + if not epic: + raise RuntimeError("fhir_epic connector not configured") + return epic + + # ------------------------------------------------------------------ + # Tool: fhir_epic_read_patient + # ------------------------------------------------------------------ @mcp.tool( name="fhir_epic_read_patient", description=( - "Fetch a patient's demographic record from Epic FHIR R4. " - "Returns name, date of birth, gender, identifiers, and contact details. " - "Epic IDs typically start with 'e' (e.g. 'e12345')." + "Fetch a single patient's demographic record from Epic FHIR R4. " + "Provide patient_id for a direct lookup, or family_name/given_name/name " + "for a name-based search. " + "Epic patient IDs typically start with 'e' (e.g. 'eXYZ123')." ), ) async def fhir_epic_read_patient( patient_id: str = "", family_name: str = "", given_name: str = "", + name: str = "", birthdate: str = "", ) -> dict: trace_id = str(uuid.uuid4()) - epic = factory._connectors.get("fhir_epic") - if not epic: - raise RuntimeError("fhir_epic connector not configured") - - action = epic.get_action("read_patient") + action = _get_connector().get_action("read_patient") if patient_id: - params = FhirEpicPatientReadInput(resource_id=patient_id) - elif family_name or given_name: - search = { - k: v - for k, v in { - "family": family_name, - "given": given_name, - "birthdate": birthdate, - }.items() - if v - } - params = FhirEpicPatientReadInput(search_params=search) + params = FhirPatientReadInput(resource_id=patient_id) + elif family_name or given_name or name: + params = FhirPatientReadInput( + family_name=family_name or None, + given_name=given_name or None, + name=name or None, + birthdate=birthdate or None, + ) else: - raise ValueError("Provide patient_id OR at least family_name/given_name") + raise ValueError("Provide patient_id OR at least one of family_name / given_name / name") result = await action.internal_execute(params, trace_id=trace_id) resource = result.resource name_parts = resource.get("name", [{}])[0] full_name = " ".join(name_parts.get("given", []) + [name_parts.get("family", "")]).strip() - addr = resource.get("address", [{}])[0] full_addr = ( - f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}" + f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, " + f"{addr.get('state', '')} {addr.get('postalCode', '')}" ).strip(", ") return { @@ -94,6 +119,199 @@ async def fhir_epic_read_patient( "source": "Epic FHIR", } + # ------------------------------------------------------------------ + # Tool: fhir_epic_search_patients + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_epic_search_patients", + description=( + "Search / fetch multiple patients from Epic FHIR R4. " + "Mode 1 — pass a comma-separated list of patient IDs in resource_ids for a concurrent " + "fan-out lookup. " + "Mode 2 — pass family_name, given_name, name, and/or birthdate for a name-based " + "FHIR search that returns all matching Bundle entries. " + "Partial failures in Mode 1 are captured in the 'errors' list rather than raising." + ), + ) + async def fhir_epic_search_patients( + resource_ids: str = "", + family_name: str = "", + given_name: str = "", + name: str = "", + birthdate: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("search_patients") + + ids_list = [i.strip() for i in resource_ids.split(",") if i.strip()] if resource_ids else None + + params = FhirPatientSearchInput( + resource_ids=ids_list, + family_name=family_name or None, + given_name=given_name or None, + name=name or None, + birthdate=birthdate or None, + ) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resources": result.resources, + "total": result.total, + "errors": result.errors, + "source": "Epic FHIR", + } + + # ------------------------------------------------------------------ + # Tool: fhir_epic_search_encounter + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_epic_search_encounter", + description=( + "Search FHIR Encounter resources in Epic R4. " + "Filter by patient_id (maps to the FHIR 'patient' parameter), encounter " + "status (e.g. 'finished', 'arrived'), and/or date / date range " + "(e.g. '2024', 'gt2023-01-01'). " + "At least one filter must be provided." + ), + ) + async def fhir_epic_search_encounter( + patient_id: str = "", + status: str = "", + date: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("search_encounter") + + if not patient_id and not status and not date: + raise ValueError("Provide at least one of patient_id, status, or date") + + params = FhirEncounterSearchInput( + patient_id=patient_id or None, + status=status or None, + date=date or None, + ) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resources": result.resources, + "total": result.total, + "source": "Epic FHIR", + } + + # ------------------------------------------------------------------ + # Tool: fhir_epic_create_document_reference + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_epic_create_document_reference", + description=( + "Create a FHIR DocumentReference resource in Epic R4. " + "Required: status ('current'), type (CodeableConcept with LOINC code), " + "subject (Patient reference string, e.g. 'Patient/eXYZ'), " + "data (base64-encoded content). " + "Optional: identifier, category, author, description, context " + "(Epic requires context.encounter for clinical note types such as LOINC 34108-1). " + "Returns the new DocumentReference resource ID." + ), + ) + async def fhir_epic_create_document_reference( + status: str, + subject: str, + data: str, + type_code: str = "34133-9", + type_system: str = "http://loinc.org", + type_display: str = "Summary of episode note", + content_type: str = "text/plain", + description: str = "", + encounter_id: str = "", + author_reference: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("create_document_reference") + + doc_type = { + "coding": [{"system": type_system, "code": type_code, "display": type_display}] + } + + identifier = [{"system": "urn:ietf:rfc:3986", "value": f"urn:uuid:{uuid.uuid4()}"}] + + context = None + if encounter_id: + context = {"encounter": [{"reference": f"Encounter/{encounter_id}"}]} + + author = None + if author_reference: + author = [{"reference": author_reference}] + + params = FhirDocumentReferenceCreateInput( + identifier=identifier, + status=status, + type=doc_type, + subject=subject, + data=data, + content_type=content_type, + description=description or None, + context=context, + author=author, + ) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resource_id": result.resource_id, + "resource": result.resource, + "source": "Epic FHIR", + } + + # ------------------------------------------------------------------ + # Tool: fhir_epic_search_document_reference + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_epic_search_document_reference", + description=( + "Search FHIR DocumentReference resources in Epic R4. " + "Pass search parameters as key=value pairs separated by '&', " + "e.g. 'patient=eXYZ123' or 'patient=eXYZ123&type=34133-9'. " + "The 'patient' parameter is required by most Epic configurations." + ), + ) + async def fhir_epic_search_document_reference( + search_query: str, + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("search_document_reference") + + # Parse 'key=value&key2=value2' into a dict + search_params: dict = {} + for part in search_query.split("&"): + part = part.strip() + if "=" in part: + k, _, v = part.partition("=") + search_params[k.strip()] = v.strip() + + if not search_params: + raise ValueError( + "Provide search_query as 'key=value' pairs (e.g. 'patient=eXYZ123')" + ) + + params = FhirDocumentReferenceSearchInput(search_params=search_params) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resources": result.resources, + "total": result.total, + "source": "Epic FHIR", + } + + logger.info( + "Registered %d Epic FHIR MCP tools: %s", + 5, + [ + "fhir_epic_read_patient", + "fhir_epic_search_patients", + "fhir_epic_search_encounter", + "fhir_epic_create_document_reference", + "fhir_epic_search_document_reference", + ], + ) return mcp @@ -105,4 +323,3 @@ def main() -> None: if __name__ == "__main__": main() - diff --git a/src/agents/google_drive_mcp.py b/src/agents/google_drive_mcp.py index 050a3ef..1f865cc 100644 --- a/src/agents/google_drive_mcp.py +++ b/src/agents/google_drive_mcp.py @@ -1,7 +1,15 @@ """ FastMCP Server Entrypoint — Google Drive -======================================= -Standalone MCP server exposing only the Google Drive tool. +======================================== +Standalone MCP server exposing all Google Drive connector actions: + + • google_drive_files_create + • google_drive_files_list + • google_drive_permissions_create + • google_drive_files_get + • google_drive_files_update + • google_drive_files_upload + • google_drive_files_delete Usage: python -m agents.google_drive_mcp @@ -11,6 +19,7 @@ import logging import os import uuid +from typing import Optional from dotenv import load_dotenv @@ -37,32 +46,45 @@ def _make_server(): mcp = FastMCP("nw-google-drive") + def _get_connector(): + drive = factory._connectors.get("google_drive") + if not drive: + raise RuntimeError("google_drive connector not configured") + return drive + + # ------------------------------------------------------------------ + # Tool: google_drive_files_upload + # ------------------------------------------------------------------ @mcp.tool( - name="google_drive_upload_file", + name="google_drive_files_upload", description=( - "Upload a text file to Google Drive. " + "Upload a new file with content to Google Drive. " "Returns the file ID and a shareable web view link." ), ) - async def google_drive_upload_file( - file_name: str, - content: str, - folder_id: str = os.environ.get("GOOGLE_DRIVE_FOLDER_ID", ""), + async def google_drive_files_upload( + name: str, mime_type: str = "text/plain", + content: str = "", + content_base64: str = "", + parents: str = os.environ.get("GOOGLE_DRIVE_FOLDER_ID", ""), ) -> dict: trace_id = str(uuid.uuid4()) - drive = factory._connectors.get("google_drive") - if not drive: - raise RuntimeError("google_drive connector not configured") + drive = _get_connector() + + parents_list = [p.strip() for p in parents.split(",")] if parents else None payload: dict = { "action": "files.upload", - "name": file_name, + "name": name, "mime_type": mime_type, - "content": content, } - if folder_id: - payload["parents"] = [folder_id] + if parents_list: + payload["parents"] = parents_list + if content: + payload["content"] = content + if content_base64: + payload["content_base64"] = content_base64 params = GoogleDriveOperationInput(**payload) result = await drive.internal_execute(params, trace_id=trace_id) @@ -75,6 +97,182 @@ async def google_drive_upload_file( "description": result.description, } + # ------------------------------------------------------------------ + # Tool: google_drive_files_list + # ------------------------------------------------------------------ + @mcp.tool( + name="google_drive_files_list", + description="List or search for files in Google Drive.", + ) + async def google_drive_files_list( + query: str = "", + page_size: int = 10, + fields: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + drive = _get_connector() + + payload = { + "action": "files.list", + "page_size": page_size, + } + if query: + payload["query"] = query + if fields: + payload["fields"] = fields + + params = GoogleDriveOperationInput(**payload) + result = await drive.internal_execute(params, trace_id=trace_id) + return result.raw + + # ------------------------------------------------------------------ + # Tool: google_drive_files_create + # ------------------------------------------------------------------ + @mcp.tool( + name="google_drive_files_create", + description="Create an empty file or folder in Google Drive.", + ) + async def google_drive_files_create( + name: str, + mime_type: str = "application/vnd.google-apps.folder", + parents: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + drive = _get_connector() + + parents_list = [p.strip() for p in parents.split(",")] if parents else None + + payload = { + "action": "files.create", + "name": name, + "mime_type": mime_type, + } + if parents_list: + payload["parents"] = parents_list + + params = GoogleDriveOperationInput(**payload) + result = await drive.internal_execute(params, trace_id=trace_id) + return result.raw + + # ------------------------------------------------------------------ + # Tool: google_drive_files_get + # ------------------------------------------------------------------ + @mcp.tool( + name="google_drive_files_get", + description="Get a file's metadata by its ID in Google Drive.", + ) + async def google_drive_files_get( + file_id: str, + fields: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + drive = _get_connector() + + payload = { + "action": "files.get", + "file_id": file_id, + } + if fields: + payload["fields"] = fields + + params = GoogleDriveOperationInput(**payload) + result = await drive.internal_execute(params, trace_id=trace_id) + return result.raw + + # ------------------------------------------------------------------ + # Tool: google_drive_files_update + # ------------------------------------------------------------------ + @mcp.tool( + name="google_drive_files_update", + description="Update a file's metadata (e.g. rename or move folders) in Google Drive.", + ) + async def google_drive_files_update( + file_id: str, + name: str = "", + mime_type: str = "", + add_parents: str = "", + remove_parents: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + drive = _get_connector() + + add_parents_list = [p.strip() for p in add_parents.split(",")] if add_parents else None + remove_parents_list = [p.strip() for p in remove_parents.split(",")] if remove_parents else None + + payload = { + "action": "files.update", + "file_id": file_id, + } + if name: + payload["name"] = name + if mime_type: + payload["mime_type"] = mime_type + if add_parents_list: + payload["add_parents"] = add_parents_list + if remove_parents_list: + payload["remove_parents"] = remove_parents_list + + params = GoogleDriveOperationInput(**payload) + result = await drive.internal_execute(params, trace_id=trace_id) + return result.raw + + # ------------------------------------------------------------------ + # Tool: google_drive_files_delete + # ------------------------------------------------------------------ + @mcp.tool( + name="google_drive_files_delete", + description="Trash a file in Google Drive by its ID.", + ) + async def google_drive_files_delete( + file_id: str, + ) -> dict: + trace_id = str(uuid.uuid4()) + drive = _get_connector() + + payload = { + "action": "files.delete", + "file_id": file_id, + } + + params = GoogleDriveOperationInput(**payload) + result = await drive.internal_execute(params, trace_id=trace_id) + return result.raw + + # ------------------------------------------------------------------ + # Tool: google_drive_permissions_create + # ------------------------------------------------------------------ + @mcp.tool( + name="google_drive_permissions_create", + description="Create a permission for a file (share a file) in Google Drive.", + ) + async def google_drive_permissions_create( + file_id: str, + role: str, + type: str, + email_address: str = "", + domain: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + drive = _get_connector() + + payload = { + "action": "permissions.create", + "file_id": file_id, + "role": role, + "type": type, + } + if email_address: + payload["email_address"] = email_address + if domain: + payload["domain"] = domain + + params = GoogleDriveOperationInput(**payload) + result = await drive.internal_execute(params, trace_id=trace_id) + return result.raw + + logger.info( + "Registered %d Google Drive MCP tools", 7 + ) return mcp @@ -86,4 +284,3 @@ def main() -> None: if __name__ == "__main__": main() - diff --git a/src/agents/smtp_mcp.py b/src/agents/smtp_mcp.py index 80c147c..13df4c3 100644 --- a/src/agents/smtp_mcp.py +++ b/src/agents/smtp_mcp.py @@ -1,7 +1,8 @@ """ FastMCP Server Entrypoint — SMTP ================================ -Standalone MCP server exposing only the SMTP email tool. +Standalone MCP server exposing the SMTP email tool: + • smtp_send_email Usage: python -m agents.smtp_mcp @@ -56,7 +57,8 @@ def _make_server(): name="smtp_send_email", description=( "Send an email to a recipient via SMTP. " - "Credentials are picked up from environment variables." + "Credentials are picked up from environment variables. " + "You can specify multiple recipients mapped to a single comma separated string." ), ) async def smtp_send_email( @@ -84,9 +86,9 @@ async def smtp_send_email( ).strip(" '\"") sender = _extract_email(sender) - recipient = _extract_email(to_email) + recipients = [_extract_email(addr.strip()) for addr in to_email.split(",") if addr.strip()] - logger.info("SMTP Tool | from=%s to=%s subject=%s", sender, recipient, subject) + logger.info("SMTP Tool | from=%s to=%s subject=%s", sender, recipients, subject) params = SmtpSendInput( host=smtp_host, @@ -95,13 +97,14 @@ async def smtp_send_email( username_secret_key="SMTP_USERNAME", password_secret_key="SMTP_PASSWORD", from_email=sender, - to=[recipient], + to=recipients, subject=subject, body=body, ) result = await smtp.internal_execute(params, trace_id=trace_id) return {"sent": result.sent, "message_id": getattr(result, "message_id", None)} + logger.info("Registered 1 SMTP MCP tools") return mcp @@ -113,4 +116,3 @@ def main() -> None: if __name__ == "__main__": main() - From 31feb45e6017ca428516d45de48bbd6d26a69d96 Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Wed, 1 Apr 2026 03:34:58 -0700 Subject: [PATCH 07/60] Google drive connector has --- Dockerfile | 6 +- README.md | 12 +- Setup.md | 23 +- docker/fhir-cerner/Dockerfile | 2 +- docker/fhir-epic/Dockerfile | 2 +- docker/google-drive/Dockerfile | 2 +- docker/smtp/Dockerfile | 2 +- docs/google_drive_connector.md | 9 +- docs/google_drive_upload_root_cause.md | 91 ++++ docs/mcp-servers.md | 12 +- docs/toolhive_agent_scenario.md | 26 +- playground/scenarios.py | 31 +- pyproject.toml | 4 + sample.env | 4 +- src/agents/README.md | 26 +- src/agents/fhir_cerner_mcp.py | 91 +--- src/agents/fhir_epic_mcp.py | 93 +--- src/agents/google_drive_mcp.py | 78 +--- src/agents/mcp_entrypoint.py | 620 +------------------------ src/agents/smtp_mcp.py | 105 +---- src/agents/toolhive.py | 147 +++++- src/bindings/mcp_server/server.py | 234 +++++++++- src/connectors/smtp/schema.py | 77 ++- tests/test_connectors_basic.py | 2 +- tests/test_sdk_connector_manifest.py | 338 ++++++++++++++ tests/test_toolhive_agent.py | 329 +++++++++---- 26 files changed, 1227 insertions(+), 1139 deletions(-) create mode 100644 docs/google_drive_upload_root_cause.md diff --git a/Dockerfile b/Dockerfile index 4afee60..b2d5180 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ # Node Wire — Docker Image # ======================== -# This image packages the connector platform as a FastMCP server. +# This image packages the connector platform as an MCP stdio server (manifest-driven). # ToolHive runs it as a container, injects secrets as env vars, # and proxies the stdio MCP transport to HTTP/SSE. # @@ -33,7 +33,7 @@ RUN pip install --no-cache-dir -e ".[agents]" # Healthcheck: verify the package is importable HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ - python -c "from agents.mcp_entrypoint import _make_server; print('ok')" || exit 1 + python -c "from agents.mcp_entrypoint import main; assert callable(main); print('ok')" || exit 1 -# Default entrypoint: run the FastMCP server on stdio +# Default entrypoint: run the MCP server on stdio CMD ["python", "-m", "agents.mcp_entrypoint"] diff --git a/README.md b/README.md index 66d68ab..e9d9319 100644 --- a/README.md +++ b/README.md @@ -10,12 +10,12 @@ For dependency management use any tool that understands `pyproject.toml` (e.g. ` Each connector can run as its own independent MCP server (Docker image). -| Image | Tool exposed | Docker image | -| ----------------------- | -------------------------- | -------------------------------- | -| `nw-google-drive` | `google_drive_upload_file` | `docker/google-drive/Dockerfile` | -| `nw-smartonfhir-epic` | `fhir_epic_read_patient` | `docker/fhir-epic/Dockerfile` | -| `nw-smartonfhir-cerner` | `fhir_cerner_read_patient` | `docker/fhir-cerner/Dockerfile` | -| `nw-smtp` | `smtp_send_email` | `docker/smtp/Dockerfile` | +| Image | MCP tools (manifest) | Docker image | +| ----------------------- | -------------------- | -------------------------------- | +| `nw-google-drive` | All `google_drive.` (e.g. `google_drive.files.upload`) | `docker/google-drive/Dockerfile` | +| `nw-smartonfhir-epic` | All `fhir_epic.` (e.g. `fhir_epic.read_patient`) | `docker/fhir-epic/Dockerfile` | +| `nw-smartonfhir-cerner` | All `fhir_cerner.` (e.g. `fhir_cerner.read_patient`) | `docker/fhir-cerner/Dockerfile` | +| `nw-smtp` | `smtp.send_email` | `docker/smtp/Dockerfile` | See [docs/mcp-servers.md](docs/mcp-servers.md) for build, env config, docker-compose, and ToolHive registration. diff --git a/Setup.md b/Setup.md index 7be559f..ce55535 100644 --- a/Setup.md +++ b/Setup.md @@ -194,7 +194,7 @@ Supported configurations: Add to your `.env`: ```env -stripe_api_key=sk_test_your_key_here +STRIPE_API_KEY=sk_test_your_key_here ``` Use a **test key** (`sk_test_...`) during development. Switch to a live key (`sk_live_...`) for production. @@ -272,16 +272,25 @@ The platform exposes connector tools for AI agents via the MCP (Model Context Pr Each connector runs as its own independent MCP server. This is the preferred approach for modular, scalable deployments. -| Image | Tool exposed | Docker image | -| ----------------------- | -------------------------- | -------------------------------- | -| `nw-google-drive` | `google_drive_upload_file` | `docker/google-drive/Dockerfile` | -| `nw-smartonfhir-epic` | `fhir_epic_read_patient` | `docker/fhir-epic/Dockerfile` | -| `nw-smartonfhir-cerner` | `fhir_cerner_read_patient` | `docker/fhir-cerner/Dockerfile` | -| `nw-smtp` | `smtp_send_email` | `docker/smtp/Dockerfile` | +| Image | MCP tools (manifest) | Docker image | +| ----------------------- | -------------------- | -------------------------------- | +| `nw-google-drive` | All `google_drive.` (e.g. `google_drive.files.upload`) | `docker/google-drive/Dockerfile` | +| `nw-smartonfhir-epic` | All `fhir_epic.` (e.g. `fhir_epic.read_patient`) | `docker/fhir-epic/Dockerfile` | +| `nw-smartonfhir-cerner` | All `fhir_cerner.` (e.g. `fhir_cerner.read_patient`) | `docker/fhir-cerner/Dockerfile` | +| `nw-smtp` | `smtp.send_email` | `docker/smtp/Dockerfile` | **Full guide (build, env config, ToolHive registration, multi-server agent usage):** [docs/mcp-servers.md](docs/mcp-servers.md) +**FHIR tool arguments (Cerner / Epic)** — tool names are `fhir_cerner.` and `fhir_epic.`. Use field names from `tools/list` / the connector manifest. Typical payloads: + +| Action | When to use | Example arguments | +| ------ | ----------- | ------------------- | +| `read_patient` | You have a Patient id | `{"resource_id": "12724066"}` (Epic ids often start with `e`) | +| `search_patients` | No id, or name-based search | `{"resource_ids": ["id1"]}` or `{"given_name": "...", "family_name": "..."}` or `{"search_params": {"identifier": "...", "family": "..."}}` (FHIR search param names) | + +The MCP server normalizes common LLM/legacy aliases (`patientId` / `patient_id` → `resource_id`; `patientId` inside `search_params` → `identifier`) before validation. Prefer canonical fields above when authoring prompts or clients. + Quick start: ```bash diff --git a/docker/fhir-cerner/Dockerfile b/docker/fhir-cerner/Dockerfile index f53bb53..5e8fbdc 100644 --- a/docker/fhir-cerner/Dockerfile +++ b/docker/fhir-cerner/Dockerfile @@ -17,7 +17,7 @@ COPY config/ ./config/ RUN pip install --no-cache-dir -e ".[agents]" HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ - python -c "from agents.fhir_cerner_mcp import _make_server; print('ok')" || exit 1 + python -c "from agents.fhir_cerner_mcp import main; assert callable(main); print('ok')" || exit 1 CMD ["python", "-m", "agents.fhir_cerner_mcp"] diff --git a/docker/fhir-epic/Dockerfile b/docker/fhir-epic/Dockerfile index 633f031..3ff3036 100644 --- a/docker/fhir-epic/Dockerfile +++ b/docker/fhir-epic/Dockerfile @@ -17,7 +17,7 @@ COPY config/ ./config/ RUN pip install --no-cache-dir -e ".[agents]" HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ - python -c "from agents.fhir_epic_mcp import _make_server; print('ok')" || exit 1 + python -c "from agents.fhir_epic_mcp import main; assert callable(main); print('ok')" || exit 1 CMD ["python", "-m", "agents.fhir_epic_mcp"] diff --git a/docker/google-drive/Dockerfile b/docker/google-drive/Dockerfile index 43cbe2b..196e02a 100644 --- a/docker/google-drive/Dockerfile +++ b/docker/google-drive/Dockerfile @@ -17,7 +17,7 @@ COPY config/ ./config/ RUN pip install --no-cache-dir -e ".[agents]" HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ - python -c "from agents.google_drive_mcp import _make_server; print('ok')" || exit 1 + python -c "from agents.google_drive_mcp import main; assert callable(main); print('ok')" || exit 1 CMD ["python", "-m", "agents.google_drive_mcp"] diff --git a/docker/smtp/Dockerfile b/docker/smtp/Dockerfile index c4d725b..8b7f8fc 100644 --- a/docker/smtp/Dockerfile +++ b/docker/smtp/Dockerfile @@ -17,7 +17,7 @@ COPY config/ ./config/ RUN pip install --no-cache-dir -e ".[agents]" HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ - python -c "from agents.smtp_mcp import _make_server; print('ok')" || exit 1 + python -c "from agents.smtp_mcp import main; assert callable(main); print('ok')" || exit 1 CMD ["python", "-m", "agents.smtp_mcp"] diff --git a/docs/google_drive_connector.md b/docs/google_drive_connector.md index a66f116..5f2b4e5 100644 --- a/docs/google_drive_connector.md +++ b/docs/google_drive_connector.md @@ -5,7 +5,7 @@ This document covers the Google Drive connector under `connectors/google_drive` 1. **[Google Drive service account setup](#google-drive-service-account-setup)** — Create a GCP service account, enable the Drive API, configure `.env`, share a folder, and verify connectivity. 2. **[REST API reference](#rest-api-reference)** — The `execute` action, all seven operations, request/response shapes, and the platform error taxonomy. -For **MCP** (e.g. ToolHive), the connector is exposed as the `google_drive_upload_file` tool. End-to-end agent setup is documented in [docs/toolhive_agent_scenario.md](toolhive_agent_scenario.md). +For **MCP** (e.g. ToolHive), tools are named `google_drive.` from the connector manifest (e.g. `google_drive.files.upload`). End-to-end agent setup is documented in [docs/toolhive_agent_scenario.md](toolhive_agent_scenario.md). --- @@ -339,7 +339,7 @@ The service account must have edit permission on the file. #### files.upload -Create a new file with text content. +Create a new file with content (text or binary). Request body: @@ -358,9 +358,10 @@ Fields: - `name` (string, required). - `mime_type` (string, required). - `parents` (array of string, optional). -- `content` (string, required): UTF-8 text content that will be uploaded. +- `content` (string, optional): UTF-8 text content that will be uploaded. +- `content_base64` (string, optional): base64-encoded binary content (e.g. PDFs, images). -Content is uploaded using `MediaInMemoryUpload`; this is suitable for small text payloads. +Exactly one of `content` or `content_base64` must be provided.\n+\n+Content is uploaded using `MediaInMemoryUpload`; this is suitable for small payloads.\n+\n+> For MCP callers (e.g. ToolHive): use canonical fields (`content` / `content_base64`). Legacy `media` / `media_body` shapes are not part of the public schema and should not be relied upon. #### files.delete diff --git a/docs/google_drive_upload_root_cause.md b/docs/google_drive_upload_root_cause.md new file mode 100644 index 0000000..97a6c56 --- /dev/null +++ b/docs/google_drive_upload_root_cause.md @@ -0,0 +1,91 @@ +# Google Drive `files.upload` — root cause analysis + +## Summary verdict + +| Layer | Verdict | +|--------|---------| +| **Connector** | **Not at fault** for the observed errors. It validates and executes `FilesUploadOperation` as documented. | +| **MCP server** | **Behaves as designed**: injects `action` only when absent (`setdefault`); does not override a wrong `action` from the caller. | +| **Agent / LLM** | **Primary fault**: tool arguments did not match the published JSON Schema (`mimeType` vs `mime_type`, `action: "upload"` vs `files.upload`, missing fields). | +| **Groq 429** | **Secondary**: rate limits after many failed retries increased token usage and ended the run. | + +**Overall:** **Agent-side** (LLM tool-call payload), not a connector bug. + +--- + +## Evidence from production logs (`terminals/11.txt`) + +| Step | Observed `google_drive.files.upload` args (excerpt) | Error | +|------|------------------------------------------------------|--------| +| 1 | `mimeType`, `name`, `parents`, `content` | Extra property `mimeType`; wrong field name for MIME type | +| 2 | `name`, `parents`, `content` (no `mime_type`) | `action` required (schema lists it as required) | +| 3 | `action: "upload"`, … | `mime_type` required / union mismatch | +| 4 | `mime_type` without correct `action` | `action` required | +| 5 | `action: "upload"`, `mime_type`, … | **`'files.upload' was expected`** — wrong discriminator | + +These align with **strict Pydantic validation** on `FilesUploadOperation` (`extra="forbid"`, discriminator `action`). + +--- + +## MCP contract (`tools/list`) + +For `google_drive.files.upload`, the manifest exposes **per-action** input schema (`FilesUploadOperation`), not the full union: + +- **`required`:** `action`, `name`, `mime_type` +- **`action`:** JSON Schema `const: "files.upload"` +- **No `mimeType`** property — only `mime_type` + +Source: [`src/bindings/mcp_server/server.py`](../src/bindings/mcp_server/server.py) (`list_tools` + `invoke_tool`), [`src/connectors/manifest.py`](../src/connectors/manifest.py), [`src/connectors/google_drive/schema.py`](../src/connectors/google_drive/schema.py). + +--- + +## Server dispatch behavior + +In `McpServer.invoke_tool`: + +```python +run_args = normalize_mcp_tool_arguments(connector_id, action, arguments) +if isinstance(connector, SDKConnector): + run_args.setdefault("action", action) +``` + +- If the LLM **omits** `action`, the server sets `action` to the suffix from the tool name (`files.upload`) → valid for minimal calls. +- If the LLM sends **`action: "upload"`**, `setdefault` **does not** replace it → validation fails (`union_tag_invalid`), matching log **`'files.upload' was expected`**. + +--- + +## Reproduction (local `invoke_tool`) + +| Payload | Result | +|---------|--------| +| `name`, `mime_type`, `parents`, `content` only (no `action`) | **Success** (server adds `action`) — assumes valid Drive credentials | +| `mimeType` instead of `mime_type` | `VALIDATION_ERROR`: `mime_type` missing, `mimeType` extra forbidden | +| `action: "upload"` + valid other fields | `VALIDATION_ERROR`: `union_tag_invalid` (expected tags include `files.upload`, not `upload`) | + +--- + +## Payload matrix + +| Issue | Owner | Notes | +|-------|--------|------| +| `mimeType` vs `mime_type` | Agent | Schema only defines `mime_type` | +| Missing `action` when schema says required | Agent / schema UX | Server can still inject `action` if omitted; LLM may omit and still work | +| `action: "upload"` | Agent | Must be literal `files.upload` | +| Nested `file` object | Agent | Not in schema | +| Connector rejects valid `files.upload` payload | N/A | Not observed | + +--- + +## Recommendations (optional follow-ups) + +1. **Agent prompt / tool-calling**: Implemented in [`src/agents/toolhive.py`](../src/agents/toolhive.py) — step 2 now states flat JSON, `mime_type`, and correct `action` / no nested `file`. +2. **Normalization** (server): Implemented in [`src/bindings/mcp_server/server.py`](../src/bindings/mcp_server/server.py) — `_normalize_google_drive_files_upload` maps `mimeType` → `mime_type`, coerces `action: "upload"` → `files.upload`, merges a nested `file` dict when canonical keys are absent, and strips `mimeType`. +3. **Groq**: Operational — smaller context, higher TPM tier, or fewer agent steps still help if the model ignores schema; normalization reduces validation failure loops. + +--- + +## References + +- [`src/connectors/google_drive/schema.py`](../src/connectors/google_drive/schema.py) — `FilesUploadOperation` +- [`src/bindings/mcp_server/server.py`](../src/bindings/mcp_server/server.py) — `normalize_mcp_tool_arguments`, `invoke_tool` +- [`src/agents/toolhive.py`](../src/agents/toolhive.py) — sends tool args to MCP as returned by the LLM; the MCP server normalizes Google Drive upload aliases before `connector.run`. diff --git a/docs/mcp-servers.md b/docs/mcp-servers.md index 1dfe8de..5885075 100644 --- a/docs/mcp-servers.md +++ b/docs/mcp-servers.md @@ -43,10 +43,12 @@ flowchart TD | Connector | Python entrypoint | Docker image | ToolHive name | MCP tool(s) exposed | |---|---|---|---|---| -| Google Drive | `python -m agents.google_drive_mcp` | `nw-google-drive` | `nw-google-drive` | `google_drive_upload_file` | -| SMART on FHIR (Epic) | `python -m agents.fhir_epic_mcp` | `nw-smartonfhir-epic` | `nw-smartonfhir-epic` | `fhir_epic_read_patient` | -| SMART on FHIR (Cerner) | `python -m agents.fhir_cerner_mcp` | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner` | `fhir_cerner_read_patient` | -| SMTP | `python -m agents.smtp_mcp` | `nw-smtp` | `nw-smtp` | `smtp_send_email` | +| Google Drive | `python -m agents.google_drive_mcp` | `nw-google-drive` | `nw-google-drive` | All manifest actions for `google_drive` (names `google_drive.`, e.g. `google_drive.files.upload`) | +| SMART on FHIR (Epic) | `python -m agents.fhir_epic_mcp` | `nw-smartonfhir-epic` | `nw-smartonfhir-epic` | All manifest actions for `fhir_epic` (e.g. `fhir_epic.read_patient`) | +| SMART on FHIR (Cerner) | `python -m agents.fhir_cerner_mcp` | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner` | All manifest actions for `fhir_cerner` (e.g. `fhir_cerner.read_patient`) | +| SMTP | `python -m agents.smtp_mcp` | `nw-smtp` | `nw-smtp` | `smtp.send_email` | + +The unified server (`python -m agents.mcp_entrypoint`) exposes **every** connector enabled for MCP in `config/connectors.yaml` (e.g. `http_generic.request`, `stripe.charge`, plus the rows above). --- @@ -118,7 +120,7 @@ Register your application at the [Cerner Developer Portal](https://code.cerner.c #### `nw-smtp` -The SMTP MCP server exposes one tool: `smtp_send_email`. When running under ToolHive, inject these as secrets: +The SMTP MCP server exposes one tool: `smtp.send_email`. When running under ToolHive, inject these as secrets: | Variable | Description | |---|---| diff --git a/docs/toolhive_agent_scenario.md b/docs/toolhive_agent_scenario.md index 666c39b..1d4026e 100644 --- a/docs/toolhive_agent_scenario.md +++ b/docs/toolhive_agent_scenario.md @@ -36,10 +36,10 @@ This guide walks you through running the platform as an MCP server using ToolHiv ``` ToolHive UI ────────────────────────────────────────────────────── │ MCP Server (Docker): node-wire │ -│ ├── Tool: fhir_cerner_read_patient ← fetch patient from Cerner │ -│ ├── Tool: fhir_epic_read_patient ← fetch patient from Epic │ -│ ├── Tool: google_drive_upload_file ← write file to Drive │ -│ └── Tool: smtp_send_email ← email the summary │ +│ ├── Tool: fhir_cerner.read_patient ← fetch patient from Cerner │ +│ ├── Tool: fhir_epic.read_patient ← fetch patient from Epic │ +│ ├── Tool: google_drive.files.upload ← write file to Drive │ +│ └── Tool: smtp.send_email ← email the summary │ │ ↕ stdio → HTTP proxy │ ────────────────────────────────────────────────────────────────── ↕ MCP JSON-RPC over HTTP @@ -86,10 +86,10 @@ When running as an MCP server, the platform exposes 4 tools that AI agents can d | Tool | Description | |---|---| -| `fhir_cerner_read_patient` | Fetch a patient's record from a Cerner FHIR R4 endpoint | -| `fhir_epic_read_patient` | Fetch a patient's record from an Epic FHIR R4 endpoint | -| `google_drive_upload_file` | Create and upload a text file to Google Drive | -| `smtp_send_email` | Send an email via SMTP | +| `fhir_cerner.read_patient` | Fetch a patient's record from a Cerner FHIR R4 endpoint | +| `fhir_epic.read_patient` | Fetch a patient's record from an Epic FHIR R4 endpoint | +| `google_drive.files.upload` | Create and upload a text file to Google Drive | +| `smtp.send_email` | Send an email via SMTP | The agent uses an LLM's tool-calling capability to decide which tools to call, in what order, and with what parameters. @@ -317,7 +317,7 @@ In the ToolHive UI under **Installed**, you should see: |---|---| | Name | `node-wire-connectors` | | Status | `Running` | -| Tools | `fhir_cerner_read_patient`, `fhir_epic_read_patient`, `google_drive_upload_file`, `smtp_send_email` | +| Tools | Manifest-driven `.` (e.g. `fhir_cerner.read_patient`, `fhir_epic.read_patient`, `google_drive.files.upload`, `smtp.send_email`; unified server also lists Stripe, HTTP generic, and other MCP-enabled connectors) | | Endpoint | `http://localhost:/sse` | --- @@ -404,11 +404,11 @@ I have completed all three steps: 3. Sent a summary email to your-email@example.com with a link to the file. Steps executed (3): - ✓ Step 1: fhir_cerner_read_patient + ✓ Step 1: fhir_cerner.read_patient result : {"patient_id": "123*****", "full_name": "Nancy Smart", ...} - ✓ Step 2: google_drive_upload_file + ✓ Step 2: google_drive.files.upload result : {"file_id": "1XYZ...", "web_view_link": "https://docs.google.com/..."} - ✓ Step 3: smtp_send_email + ✓ Step 3: smtp.send_email result : {"sent": true} ``` @@ -545,7 +545,7 @@ connector-platform/ └── src/ └── agents/ ├── __init__.py - ├── mcp_entrypoint.py ← FastMCP server (4 tools) + ├── mcp_entrypoint.py ← MCP stdio server (manifest; all MCP connectors) ├── toolhive.py ← ReAct agent + CLI ├── llm_factory.py ← Provider factory └── providers/ diff --git a/playground/scenarios.py b/playground/scenarios.py index 584b9a4..6185319 100644 --- a/playground/scenarios.py +++ b/playground/scenarios.py @@ -1108,13 +1108,19 @@ class AgentChatResponse(BaseModel): AGENT_GUARDRAIL_PROMPT = ( "You are a healthcare data assistant. You have access to tools for fetching " "patient data from Cerner FHIR and Epic FHIR, uploading files to Google Drive, and sending " - "emails via SMTP.\n\n" + "emails via SMTP.\n" + "Tool names are `.` (e.g. `fhir_cerner.read_patient`, " + "`fhir_epic.read_patient`, `google_drive.files.upload`, `smtp.send_email`). " + "Use exactly the names and JSON-schema arguments from tools/list.\n\n" "WORKFLOW (MUST EXECUTE SEQUENTIALLY, ONE STRICT STEP AT A TIME):\n" "When asked to 'Send patient summaries via email' or similar tasks, you MUST follow this exact flow in order. DO NOT parallelize these steps:\n" - " 1. First turn: Search for the patient. (If you have a Patient ID, you DO NOT need their name or birthdate).\n" - " CRITICAL: If the user has NOT provided a patient ID or name in their message, you MUST ASK them for it. DO NOT call the search tool with a guessed or hallucinated ID like '12345'.\n" + " 1. First turn: Obtain patient demographics from the EHR.\n" + " - If the user gave a Patient ID: call `fhir_cerner.read_patient` or `fhir_epic.read_patient` with JSON `{\"resource_id\": \"\"}` (use Epic when the ID starts with 'e'). Do NOT use search_patients for a known ID.\n" + " - If there is NO Patient ID but there IS a name: use name fields or `search_patients` per tools/list schema (e.g. `given_name`, `family_name`, `birthdate`, or valid `search_params`).\n" + " - Use `search_patients` only when you have no ID, or after `read_patient` failed and you need a fallback.\n" + " CRITICAL: If the user has NOT provided a patient ID or name in their message, you MUST ASK them for it. DO NOT call tools with a guessed or hallucinated ID like '12345'.\n" " 2. Second turn: Once you have the patient data from step 1, create a file on Google Drive containing the masked patient summary. Do NOT use placeholder content.\n" - " 3. Third turn: Once step 2 returns a 'web_view_link', send an email with that exact link. Do NOT call the email tool until you have the link.\n" + " 3. Third turn: Once step 2 returns a shareable Drive URL (see `data.raw.webViewLink` from tool `google_drive.files.upload`), send an email with that exact link. Do NOT call the email tool until you have the link.\n" " CRITICAL: You MUST ask the user for the recipient email address if they haven't provided it. DO NOT guess email addresses like 'recipient_email@example.com'.\n" " CRITICAL: In the email body, you MUST insert the actual URL string returned from step 2 (e.g. 'https://drive.google.com/...'). Do NOT literally write the text ''.\n\n" "DATA PRIVACY & MASKING — follow these strictly:\n" @@ -1124,7 +1130,7 @@ class AgentChatResponse(BaseModel): " - NEVER use the placeholder values ('1990-05-12', '12724066', or 'Name') in your reports - always use the real patient data masked accordingly.\n" "- EMAIL WORKFLOW: When sending patient details to an email recipient:\n" " 1. ALWAYS upload the masked patient summary to Google Drive first.\n" - " 2. Use the 'web_view_link' returned by the google_drive_upload_file tool.\n" + " 2. Use `data.raw.webViewLink` from the `google_drive.files.upload` tool result.\n" " 3. In the email body, provide that link instead of the actual data.\n" " 4. The email body should be professional: 'Patient data summary from the EHR is available at the following secure link: [Link]'\n\n" "GUARDRAILS:\n" @@ -1172,6 +1178,7 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: ToolHiveMcpClient, StdioMcpClient, resolve_mcp_urls, + resolve_max_tool_failures, ) provider_name = os.environ.get("LLM_PROVIDER", "groq") @@ -1206,7 +1213,12 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: mcp_client = ToolHiveMcpClient(urls[0]) else: mcp_client = MultiMcpClient([ToolHiveMcpClient(u) for u in urls]) - agent = ToolHiveAgent(mcp_client, llm_provider, max_steps=10) + agent = ToolHiveAgent( + mcp_client, + llm_provider, + max_steps=10, + max_tool_failures=resolve_max_tool_failures(None), + ) agent._system_prompt = AGENT_GUARDRAIL_PROMPT run_result = await agent.run(task) # Fallback to local stdio if: @@ -1232,7 +1244,12 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: logger.info("Agent Chat | using local stdio MCP transport") cmd = [sys.executable, "-m", "agents.mcp_entrypoint"] async with StdioMcpClient(cmd) as mcp_client: - agent = ToolHiveAgent(mcp_client, llm_provider, max_steps=10) + agent = ToolHiveAgent( + mcp_client, + llm_provider, + max_steps=10, + max_tool_failures=resolve_max_tool_failures(None), + ) agent._system_prompt = AGENT_GUARDRAIL_PROMPT run_result = await agent.run(task) diff --git a/pyproject.toml b/pyproject.toml index c275864..fcd7425 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,3 +58,7 @@ where = ["src"] requires = ["setuptools>=69.0.0", "wheel"] build-backend = "setuptools.build_meta" +[tool.pytest.ini_options] +pythonpath = ["src"] +asyncio_mode = "auto" + diff --git a/sample.env b/sample.env index 996e670..2620ce7 100644 --- a/sample.env +++ b/sample.env @@ -24,13 +24,15 @@ SMTP_USERNAME=your-email@gmail.com SMTP_PASSWORD=your-gmail-app-password # Stripe (optional / legacy demo) -stripe_api_key=sk_test_your_key_here +STRIPE_API_KEY=sk_test_your_key_here # ToolHive # Single-server (backward compatible) TOOLHIVE_MCP_URL=http://localhost:PORT/mcp # Multi-server (preferred for per-connector MCP servers) TOOLHIVE_MCP_URLS= +# Cap MCP tool JSON size sent back to the LLM (Groq on-demand TPM); default 12000 +# TOOLHIVE_MAX_TOOL_RESULT_CHARS=12000 # LLM Provider LLM_PROVIDER=groq diff --git a/src/agents/README.md b/src/agents/README.md index 22834ab..ec45f44 100644 --- a/src/agents/README.md +++ b/src/agents/README.md @@ -16,10 +16,10 @@ The `agents` module transforms static connectors (EHR, Google Drive, SMTP) into ## 🏗️ Core Architecture ### 1. **MCP Server (`mcp_entrypoint.py`)** -A high-performance server built on the [FastMCP](https://github.com/modelcontextprotocol/python-sdk) framework. -- **Dynamic Bindings**: Uses the `ConnectorFactory` to load platform connectors and expose them as MCP tools. -- **Data Protection**: Automatically extracts and summarizes raw FHIR resources to protect patient privacy and reduce LLM token consumption. -- **Flexible Transport**: Defaults to `stdio` transport for seamless integration with ToolHive, Claude Desktop, or custom proxies. +Stdio MCP server using the official [Model Context Protocol Python SDK](https://github.com/modelcontextprotocol/python-sdk). +- **Manifest-driven tools**: `McpServer` builds the tool list from connector metadata (`.`) and dispatches via `connector.run()`. +- **Unified entrypoint**: `python -m agents.mcp_entrypoint` exposes every connector enabled for MCP in `config/connectors.yaml`. +- **Per-connector images**: `fhir_cerner_mcp`, `fhir_epic_mcp`, `google_drive_mcp`, and `smtp_mcp` run the same server with a `connector_ids` filter. ### 2. **ToolHive Agent (`toolhive.py`)** A reference implementation of a ReAct-style agent designed for the **ToolHive** ecosystem. @@ -35,14 +35,18 @@ A modular factory system supporting diverse LLM backends: --- -## 🛠️ Available MCP Tools +## 🛠️ MCP tool naming -| Tool Name | Description | Connector | -| :--- | :--- | :--- | -| `fhir_cerner_read_patient` | Fetches patient demographics (Name, DOB, ID) from Cerner FHIR R4. | `fhir_cerner` | -| `fhir_epic_read_patient` | Fetches patient demographics from Epic FHIR R4. (IDs usually start with 'e'). | `fhir_epic` | -| `google_drive_upload_file` | Securely uploads text summaries or reports to a designated folder. | `google_drive` | -| `smtp_send_email` | Dispatches notifications or clinical summaries via secure SMTP. | `smtp` | +Tools are named **`{connector_id}.{action}`** as defined by each connector’s manifest (see `connectors/manifest.py` and `bindings/mcp_server/server.py`). Examples: + +| Example tool name | Connector | +| :--- | :--- | +| `fhir_cerner.read_patient` | Cerner FHIR | +| `fhir_epic.read_patient` | Epic FHIR | +| `google_drive.files.upload` | Google Drive | +| `smtp.send_email` | SMTP | + +Use **`tools/list`** for the exact names and JSON Schemas your deployment exposes. --- diff --git a/src/agents/fhir_cerner_mcp.py b/src/agents/fhir_cerner_mcp.py index fd2067c..e9ac4cd 100644 --- a/src/agents/fhir_cerner_mcp.py +++ b/src/agents/fhir_cerner_mcp.py @@ -1,16 +1,8 @@ -""" -FastMCP Server Entrypoint — SMART on FHIR (Cerner) -================================================= -Standalone MCP server exposing only the Cerner FHIR patient read tool. - -Usage: - python -m agents.fhir_cerner_mcp -""" +"""MCP Server — Cerner FHIR connector only. Usage: python -m agents.fhir_cerner_mcp""" from __future__ import annotations import logging import os -import uuid from dotenv import load_dotenv @@ -21,82 +13,15 @@ logger = logging.getLogger("agents.fhir_cerner_mcp") -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError("mcp SDK not installed. Run: pip install 'node-wire[agents]'") from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("nw-smartonfhir-cerner") - - @mcp.tool( - name="fhir_cerner_read_patient", - description=( - "Fetch a patient's demographic record from Cerner FHIR R4. " - "Returns name, date of birth, gender, identifiers, and contact details." - ), - ) - async def fhir_cerner_read_patient( - patient_id: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - trace_id = str(uuid.uuid4()) - cerner = factory._connectors.get("fhir_cerner") - if not cerner: - raise RuntimeError("fhir_cerner connector not configured") - - if patient_id: - params = FhirCernerPatientReadInput(action="read_patient", resource_id=patient_id) - elif family_name or given_name or name: - params = FhirCernerPatientReadInput( - action="read_patient", - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError("Provide patient_id OR at least family_name / given_name / name") - - result = await cerner.internal_execute(params, trace_id=trace_id) - resource = result.resource - - name_parts = resource.get("name", [{}])[0] - full_name = " ".join(name_parts.get("given", []) + [name_parts.get("family", "")]).strip() - - addr = resource.get("address", [{}])[0] - full_addr = ( - f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}" - ).strip(", ") - - return { - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "address_summary": full_addr, - } - - return mcp - - def main() -> None: - server = _make_server() - logger.info("Starting nw-smartonfhir-cerner MCP server (stdio transport)") - server.run() + from bindings.mcp_server.server import McpServer + + logger.info("Starting nw-smartonfhir-cerner MCP server (stdio, manifest-driven)") + McpServer( + server_name="nw-smartonfhir-cerner", + connector_ids=["fhir_cerner"], + ).run_stdio() if __name__ == "__main__": main() - diff --git a/src/agents/fhir_epic_mcp.py b/src/agents/fhir_epic_mcp.py index 5e6798e..c9fb60b 100644 --- a/src/agents/fhir_epic_mcp.py +++ b/src/agents/fhir_epic_mcp.py @@ -1,16 +1,8 @@ -""" -FastMCP Server Entrypoint — SMART on FHIR (Epic) -=============================================== -Standalone MCP server exposing only the Epic FHIR patient read tool. - -Usage: - python -m agents.fhir_epic_mcp -""" +"""MCP Server — Epic FHIR connector only. Usage: python -m agents.fhir_epic_mcp""" from __future__ import annotations import logging import os -import uuid from dotenv import load_dotenv @@ -21,84 +13,15 @@ logger = logging.getLogger("agents.fhir_epic_mcp") -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError("mcp SDK not installed. Run: pip install 'node-wire[agents]'") from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.fhir_epic.schema import FhirPatientReadInput as FhirEpicPatientReadInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("nw-smartonfhir-epic") - - @mcp.tool( - name="fhir_epic_read_patient", - description=( - "Fetch a patient's demographic record from Epic FHIR R4. " - "Returns name, date of birth, gender, identifiers, and contact details. " - "Epic IDs typically start with 'e' (e.g. 'e12345')." - ), - ) - async def fhir_epic_read_patient( - patient_id: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - trace_id = str(uuid.uuid4()) - epic = factory._connectors.get("fhir_epic") - if not epic: - raise RuntimeError("fhir_epic connector not configured") - - if patient_id: - params = FhirEpicPatientReadInput(action="read_patient", resource_id=patient_id) - elif family_name or given_name or name: - params = FhirEpicPatientReadInput( - action="read_patient", - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError("Provide patient_id OR at least family_name / given_name / name") - - result = await epic.internal_execute(params, trace_id=trace_id) - resource = result.resource - - name_parts = resource.get("name", [{}])[0] - full_name = " ".join(name_parts.get("given", []) + [name_parts.get("family", "")]).strip() - - addr = resource.get("address", [{}])[0] - full_addr = ( - f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}" - ).strip(", ") - - return { - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "address_summary": full_addr, - "source": "Epic FHIR", - } - - return mcp - - def main() -> None: - server = _make_server() - logger.info("Starting nw-smartonfhir-epic MCP server (stdio transport)") - server.run() + from bindings.mcp_server.server import McpServer + + logger.info("Starting nw-smartonfhir-epic MCP server (stdio, manifest-driven)") + McpServer( + server_name="nw-smartonfhir-epic", + connector_ids=["fhir_epic"], + ).run_stdio() if __name__ == "__main__": main() - diff --git a/src/agents/google_drive_mcp.py b/src/agents/google_drive_mcp.py index 050a3ef..6591717 100644 --- a/src/agents/google_drive_mcp.py +++ b/src/agents/google_drive_mcp.py @@ -1,16 +1,8 @@ -""" -FastMCP Server Entrypoint — Google Drive -======================================= -Standalone MCP server exposing only the Google Drive tool. - -Usage: - python -m agents.google_drive_mcp -""" +"""MCP Server — Google Drive connector only. Usage: python -m agents.google_drive_mcp""" from __future__ import annotations import logging import os -import uuid from dotenv import load_dotenv @@ -21,69 +13,15 @@ logger = logging.getLogger("agents.google_drive_mcp") -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError("mcp SDK not installed. Run: pip install 'node-wire[agents]'") from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.google_drive.schema import GoogleDriveOperationInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("nw-google-drive") - - @mcp.tool( - name="google_drive_upload_file", - description=( - "Upload a text file to Google Drive. " - "Returns the file ID and a shareable web view link." - ), - ) - async def google_drive_upload_file( - file_name: str, - content: str, - folder_id: str = os.environ.get("GOOGLE_DRIVE_FOLDER_ID", ""), - mime_type: str = "text/plain", - ) -> dict: - trace_id = str(uuid.uuid4()) - drive = factory._connectors.get("google_drive") - if not drive: - raise RuntimeError("google_drive connector not configured") - - payload: dict = { - "action": "files.upload", - "name": file_name, - "mime_type": mime_type, - "content": content, - } - if folder_id: - payload["parents"] = [folder_id] - - params = GoogleDriveOperationInput(**payload) - result = await drive.internal_execute(params, trace_id=trace_id) - - raw = result.raw - return { - "file_id": raw.get("id"), - "file_name": raw.get("name"), - "web_view_link": raw.get("webViewLink"), - "description": result.description, - } - - return mcp - - def main() -> None: - server = _make_server() - logger.info("Starting nw-google-drive MCP server (stdio transport)") - server.run() + from bindings.mcp_server.server import McpServer + + logger.info("Starting nw-google-drive MCP server (stdio, manifest-driven)") + McpServer( + server_name="nw-google-drive", + connector_ids=["google_drive"], + ).run_stdio() if __name__ == "__main__": main() - diff --git a/src/agents/mcp_entrypoint.py b/src/agents/mcp_entrypoint.py index ba9ac46..e506a94 100644 --- a/src/agents/mcp_entrypoint.py +++ b/src/agents/mcp_entrypoint.py @@ -1,40 +1,11 @@ -""" -FastMCP Server Entrypoint -========================= -This module is the main entrypoint for the Node Wire MCP server. -When run, it exposes healthcare workflow tools via the MCP stdio transport: - - • fhir_cerner_read_patient — fetch a patient from Cerner FHIR R4 - • fhir_cerner_search_patients — search multiple patients in Cerner - • fhir_cerner_search_encounters — search encounters in Cerner - • fhir_epic_read_patient — fetch a patient from Epic FHIR R4 - • fhir_epic_search_patients — search multiple patients in Epic - • fhir_epic_search_encounters — search encounters in Epic - • google_drive_upload_file — write a file to Google Drive - • smtp_send_email — send an email via SMTP - -ToolHive manages the container lifecycle, injects secrets as environment -variables, and proxies the stdio MCP stream to HTTP/SSE for clients. - -Usage (run directly by ToolHive): - python -m agents.mcp_entrypoint - -Environment variables (injected by ToolHive via --secret flags): - CERNER_FHIR_BASE_URL, CERNER_CLIENT_ID, CERNER_KID, - CERNER_PRIVATE_KEY, CERNER_TOKEN_URL, CERNER_SCOPES - GOOGLE_DRIVE_SA_JSON - SMTP_USERNAME, SMTP_PASSWORD, SMTP_HOST, SMTP_PORT -""" +"""MCP Server — all connectors exposed via MCP. Usage: python -m agents.mcp_entrypoint""" from __future__ import annotations -import json import logging import os -import uuid + from dotenv import load_dotenv -# Load .env variables for local stdio transport -# Try both CWD and script's own folder to be safe load_dotenv() load_dotenv(os.path.join(os.path.dirname(__file__), ".env")) @@ -42,590 +13,11 @@ logger = logging.getLogger("agents.mcp_entrypoint") -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError( - "mcp SDK not installed. Run: pip install 'node-wire[agents]'" - ) from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.fhir_cerner.schema import ( - FhirCernerPatientReadInput, - FhirCernerPatientSearchInput, - FhirCernerEncounterSearchInput, - ) - from connectors.fhir_epic.schema import ( - FhirPatientReadInput as FhirEpicPatientReadInput, - FhirPatientSearchInput as FhirEpicPatientSearchInput, - FhirEncounterSearchInput as FhirEpicEncounterSearchInput, - ) - from connectors.google_drive.schema import GoogleDriveOperationInput - from connectors.smtp.schema import SmtpSendInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("Node Wire") - - # ------------------------------------------------------------------ - # Tool 1: Fetch patient from Cerner FHIR R4 - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_cerner_read_patient", - description=( - "Fetch a patient's demographic record from Cerner FHIR R4. " - "Returns name, date of birth, gender, identifiers, and contact details." - ), - ) - async def fhir_cerner_read_patient( - patient_id: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_id : str - FHIR Patient resource ID (direct lookup — use this if you have it). - family_name : str - Patient family/last name (used for search when no ID is known). - given_name : str - Patient given/first name. - name : str - Full or partial patient name (convenience — use when you only have a - single combined name string and no split given/family available). - birthdate : str - Patient date of birth in YYYY-MM-DD format. - """ - trace_id = str(uuid.uuid4()) - cerner = factory._connectors.get("fhir_cerner") - if not cerner: - raise RuntimeError("fhir_cerner connector not configured") - - if patient_id: - params = FhirCernerPatientReadInput(action="read_patient", resource_id=patient_id) - elif family_name or given_name or name: - params = FhirCernerPatientReadInput( - action="read_patient", - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError("Provide patient_id OR at least family_name / given_name / name") - - result = await cerner.internal_execute(params, trace_id=trace_id) - resource = result.resource - - # Extract a clean summary for the LLM - name_parts = resource.get("name", [{}])[0] - full_name = " ".join( - name_parts.get("given", []) + [name_parts.get("family", "")] - ).strip() - - # Drastically simplify to keep token count low - ids = ", ".join([f"{i.get('system')}: {i.get('value')}" for i in resource.get("identifier", [])]) - phones = ", ".join([t.get("value") for t in resource.get("telecom", []) if t.get("system") == "phone"]) - emails = ", ".join([t.get("value") for t in resource.get("telecom", []) if t.get("system") == "email"]) - addr = resource.get("address", [{}])[0] - full_addr = f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}".strip(", ") - - return { - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "address_summary": full_addr, - } - - # ------------------------------------------------------------------ - # Tool 2: Fetch patient from Epic FHIR R4 - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_epic_read_patient", - description=( - "Fetch a patient's demographic record from Epic FHIR R4. " - "Returns name, date of birth, gender, identifiers, and contact details. " - "Epic IDs typically start with 'e' (e.g. 'e12345')." - ), - ) - async def fhir_epic_read_patient( - patient_id: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_id : str - FHIR Patient resource ID (Epic specific, usually starts with 'e'). - family_name : str - Patient family/last name. - given_name : str - Patient given/first name. - name : str - Full or partial patient name (convenience — use when you only have a - single combined name string and no split given/family available). - birthdate : str - Patient date of birth in YYYY-MM-DD format. - """ - trace_id = str(uuid.uuid4()) - epic = factory._connectors.get("fhir_epic") - if not epic: - raise RuntimeError("fhir_epic connector not configured") - - if patient_id: - params = FhirEpicPatientReadInput(action="read_patient", resource_id=patient_id) - elif family_name or given_name or name: - params = FhirEpicPatientReadInput( - action="read_patient", - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError("Provide patient_id OR at least family_name / given_name / name") - - result = await epic.internal_execute(params, trace_id=trace_id) - resource = result.resource - - # Clean extract for LLM - name_parts = resource.get("name", [{}])[0] - full_name = " ".join( - name_parts.get("given", []) + [name_parts.get("family", "")] - ).strip() - - addr = resource.get("address", [{}])[0] - full_addr = f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}".strip(", ") - - return { - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "address_summary": full_addr, - "source": "Epic FHIR", - } - - # ------------------------------------------------------------------ - # Tool 3: Search patients in Cerner (multi-ID or name-based) - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_cerner_search_patients", - description=( - "Search for multiple patients in Cerner FHIR R4. " - "Pass a comma-separated list of Patient IDs for concurrent lookup, " - "or supply name/birthdate fields for a name-based search returning all matches." - ), - ) - async def fhir_cerner_search_patients( - patient_ids: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_ids : str - Comma-separated Patient IDs for concurrent multi-ID lookup - (e.g. '12345678,87654321'). Takes priority over name fields. - family_name : str - Patient family/last name (name-search mode). - given_name : str - Patient given/first name (name-search mode). - name : str - Full or partial name string — FHIR 'name' token search. - birthdate : str - Date of birth in YYYY-MM-DD format (name-search mode). - """ - trace_id = str(uuid.uuid4()) - cerner = factory._connectors.get("fhir_cerner") - if not cerner: - raise RuntimeError("fhir_cerner connector not configured") - - if patient_ids.strip(): - ids = [i.strip() for i in patient_ids.split(",") if i.strip()] - params = FhirCernerPatientSearchInput(action="search_patients", resource_ids=ids) - elif family_name or given_name or name or birthdate: - params = FhirCernerPatientSearchInput( - action="search_patients", - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError( - "Provide patient_ids (comma-separated) OR at least one of " - "family_name / given_name / name / birthdate" - ) - - result = await cerner.internal_execute(params, trace_id=trace_id) - - summaries = [] - for resource in result.resources: - name_parts = resource.get("name", [{}])[0] - full_name = " ".join( - name_parts.get("given", []) + [name_parts.get("family", "")] - ).strip() - summaries.append({ - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - }) - - return { - "patients": summaries, - "total": result.total, - "errors": result.errors, - } - - # ------------------------------------------------------------------ - # Tool 4: Search patients in Epic (multi-ID or name-based) - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_epic_search_patients", - description=( - "Search for multiple patients in Epic FHIR R4. " - "Pass a comma-separated list of Patient IDs for concurrent lookup, " - "or supply name/birthdate fields for a name-based search returning all matches. " - "Epic IDs typically start with 'e' (e.g. 'e12345')." - ), - ) - async def fhir_epic_search_patients( - patient_ids: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_ids : str - Comma-separated Patient IDs for concurrent multi-ID lookup - (e.g. 'eABC,eDEF'). Takes priority over name fields. - family_name : str - Patient family/last name (name-search mode). - given_name : str - Patient given/first name (name-search mode). - name : str - Full or partial name string — FHIR 'name' token search. - birthdate : str - Date of birth in YYYY-MM-DD format (name-search mode). - """ - trace_id = str(uuid.uuid4()) - epic = factory._connectors.get("fhir_epic") - if not epic: - raise RuntimeError("fhir_epic connector not configured") - - if patient_ids.strip(): - ids = [i.strip() for i in patient_ids.split(",") if i.strip()] - params = FhirEpicPatientSearchInput(action="search_patients", resource_ids=ids) - elif family_name or given_name or name or birthdate: - params = FhirEpicPatientSearchInput( - action="search_patients", - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError( - "Provide patient_ids (comma-separated) OR at least one of " - "family_name / given_name / name / birthdate" - ) - - result = await epic.internal_execute(params, trace_id=trace_id) - - summaries = [] - for resource in result.resources: - name_parts = resource.get("name", [{}])[0] - full_name = " ".join( - name_parts.get("given", []) + [name_parts.get("family", "")] - ).strip() - summaries.append({ - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "source": "Epic FHIR", - }) - - return { - "patients": summaries, - "total": result.total, - "errors": result.errors, - } - - # ------------------------------------------------------------------ - # Tool 5: Search encounters in Cerner FHIR R4 - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_cerner_search_encounters", - description=( - "Search for encounters in Cerner FHIR R4. " - "Returns a list of encounter summaries for a given patient or filter." - ), - ) - async def fhir_cerner_search_encounters( - patient_id: str = "", - status: str = "", - date: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_id : str - Cerner Patient ID to find encounters for. - status : str - Filter by encounter status (e.g. 'finished', 'in-progress'). - date : str - Filter by date or date range (e.g. '2024', 'ge2023-01-01'). - """ - trace_id = str(uuid.uuid4()) - cerner = factory._connectors.get("fhir_cerner") - if not cerner: - raise RuntimeError("fhir_cerner connector not configured") - - if not (patient_id or status or date): - raise ValueError("Provide at least one of patient_id / status / date") - - params = FhirCernerEncounterSearchInput( - action="search_encounter", - patient_id=patient_id or None, - status=status or None, - date=date or None, - ) - - result = await cerner.internal_execute(params, trace_id=trace_id) - - summaries = [] - for resource in result.resources: - summaries.append({ - "encounter_id": resource.get("id"), - "status": resource.get("status"), - "class": resource.get("class", {}).get("display"), - "period_start": resource.get("period", {}).get("start"), - "period_end": resource.get("period", {}).get("end"), - "type": resource.get("type", [{}])[0].get("text"), - }) - - return { - "encounters": summaries, - "total": result.total, - } - - # ------------------------------------------------------------------ - # Tool 6: Search encounters in Epic FHIR R4 - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_epic_search_encounters", - description=( - "Search for encounters in Epic FHIR R4. " - "Returns a list of encounter summaries for a given patient or filter. " - "Epic IDs typically start with 'e' (e.g. 'e12345')." - ), - ) - async def fhir_epic_search_encounters( - patient_id: str = "", - status: str = "", - date: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_id : str - Epic Patient ID to find encounters for. - status : str - Filter by encounter status (e.g. 'finished'). - date : str - Filter by date or date range. - """ - trace_id = str(uuid.uuid4()) - epic = factory._connectors.get("fhir_epic") - if not epic: - raise RuntimeError("fhir_epic connector not configured") - - if not (patient_id or status or date): - raise ValueError("Provide at least one of patient_id / status / date") - - params = FhirEpicEncounterSearchInput( - action="search_encounter", - patient_id=patient_id or None, - status=status or None, - date=date or None, - ) - - result = await epic.internal_execute(params, trace_id=trace_id) - - summaries = [] - for resource in result.resources: - summaries.append({ - "encounter_id": resource.get("id"), - "status": resource.get("status"), - "class": resource.get("class", {}).get("display"), - "period_start": resource.get("period", {}).get("start"), - "period_end": resource.get("period", {}).get("end"), - "type": resource.get("type", [{}])[0].get("text"), - }) - - return { - "encounters": summaries, - "total": result.total, - "source": "Epic FHIR", - } - - # ------------------------------------------------------------------ - # Tool 7: Upload a file to Google Drive - # ------------------------------------------------------------------ - - @mcp.tool( - name="google_drive_upload_file", - description=( - "Upload a text file to Google Drive. " - "Returns the file ID and a shareable web view link." - ), - ) - async def google_drive_upload_file( - file_name: str, - content: str, - folder_id: str = os.environ.get("GOOGLE_DRIVE_FOLDER_ID", ""), - mime_type: str = "text/plain", - ) -> dict: - """ - Parameters - ---------- - file_name : str - Name for the file in Google Drive (e.g. 'patient_summary_12345.txt'). - content : str - UTF-8 text content to write into the file. - folder_id : str - Optional Google Drive folder ID to place the file in. - mime_type : str - MIME type (default: text/plain). - """ - trace_id = str(uuid.uuid4()) - drive = factory._connectors.get("google_drive") - if not drive: - raise RuntimeError("google_drive connector not configured") - - payload: dict = { - "action": "files.upload", - "name": file_name, - "mime_type": mime_type, - "content": content, - } - if folder_id: - payload["parents"] = [folder_id] - - params = GoogleDriveOperationInput(**payload) - result = await drive.internal_execute(params, trace_id=trace_id) - - raw = result.raw - return { - "file_id": raw.get("id"), - "file_name": raw.get("name"), - "web_view_link": raw.get("webViewLink"), - "description": result.description, - } - - # ------------------------------------------------------------------ - # Tool 8: Send email via SMTP - # ------------------------------------------------------------------ - - @mcp.tool( - name="smtp_send_email", - description=( - "Send an email to a recipient via SMTP. " - "Credentials are picked up from environment variables." - ), - ) - async def smtp_send_email( - to_email: str, - subject: str, - body: str, - from_email: str = "", - ) -> dict: - """ - Parameters - ---------- - to_email : str - Recipient email address. - subject : str - Email subject line. - body : str - Plain-text email body. - from_email : str - Sender address — defaults to SMTP_USERNAME env var if empty. - """ - trace_id = str(uuid.uuid4()) - smtp = factory._connectors.get("smtp") - if not smtp: - raise RuntimeError("smtp connector not configured") - - smtp_host = os.environ.get("SMTP_HOST", "smtp.gmail.com").strip(" '\"") - smtp_port_raw = os.environ.get("SMTP_PORT", "587").strip(" '\"") - smtp_port = int(smtp_port_raw) - smtp_use_tls = os.environ.get("SMTP_USE_TLS", "true").lower() == "true" - - # Guardrail: Handle placeholder strings from LLM or empty input - sender = from_email.strip(" '\"") - if not sender or "@" not in sender or "system_default" in sender: - sender = (os.environ.get("FROM_EMAIL") or os.environ.get("SMTP_USERNAME") or "noreply@node-wire.local").strip(" '\"") - - # Pydantic EmailStr does not like "Name " - # Extract just the email part if needed - import re - def _extract_email(s: str) -> str: - match = re.search(r"<(.+?)>", s) - return match.group(1) if match else s.strip() - - sender = _extract_email(sender) - recipient = _extract_email(to_email) - - logger.info("SMTP Tool | from=%s to=%s subject=%s", sender, recipient, subject) - - params = SmtpSendInput( - host=smtp_host, - port=smtp_port, - use_tls=smtp_use_tls, - username_secret_key="SMTP_USERNAME", - password_secret_key="SMTP_PASSWORD", - from_email=sender, - to=[recipient], - subject=subject, - body=body, - ) - result = await smtp.internal_execute(params, trace_id=trace_id) - return {"sent": result.sent, "message_id": result.message_id} - - return mcp - - def main() -> None: - server = _make_server() - logger.info("Starting Node Wire MCP server (stdio transport)") - server.run() # stdio — ToolHive proxies this to HTTP/SSE + from bindings.mcp_server.server import McpServer + + logger.info("Starting Node Wire MCP server (stdio, manifest-driven)") + McpServer(server_name="node-wire").run_stdio() if __name__ == "__main__": diff --git a/src/agents/smtp_mcp.py b/src/agents/smtp_mcp.py index 80c147c..eb86d54 100644 --- a/src/agents/smtp_mcp.py +++ b/src/agents/smtp_mcp.py @@ -1,25 +1,8 @@ -""" -FastMCP Server Entrypoint — SMTP -================================ -Standalone MCP server exposing only the SMTP email tool. - -Usage: - python -m agents.smtp_mcp - -Environment variables: - SMTP_HOST (default: smtp.gmail.com) - SMTP_PORT (default: 587) - SMTP_USE_TLS (default: true) - SMTP_USERNAME - SMTP_PASSWORD - FROM_EMAIL (optional; fallback sender address) -""" +"""MCP Server — SMTP connector only. Usage: python -m agents.smtp_mcp""" from __future__ import annotations import logging import os -import re -import uuid from dotenv import load_dotenv @@ -30,87 +13,15 @@ logger = logging.getLogger("agents.smtp_mcp") -def _extract_email(value: str) -> str: - # Pydantic EmailStr does not like "Name " - match = re.search(r"<(.+?)>", value) - return (match.group(1) if match else value).strip() - - -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError("mcp SDK not installed. Run: pip install 'node-wire[agents]'") from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.smtp.schema import SmtpSendInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("nw-smtp") - - @mcp.tool( - name="smtp_send_email", - description=( - "Send an email to a recipient via SMTP. " - "Credentials are picked up from environment variables." - ), - ) - async def smtp_send_email( - to_email: str, - subject: str, - body: str, - from_email: str = "", - ) -> dict: - trace_id = str(uuid.uuid4()) - smtp = factory._connectors.get("smtp") - if not smtp: - raise RuntimeError("smtp connector not configured") - - smtp_host = os.environ.get("SMTP_HOST", "smtp.gmail.com").strip(" '\"") - smtp_port_raw = os.environ.get("SMTP_PORT", "587").strip(" '\"") - smtp_port = int(smtp_port_raw) - smtp_use_tls = os.environ.get("SMTP_USE_TLS", "true").lower() == "true" - - sender = from_email.strip(" '\"") - if not sender or "@" not in sender or "system_default" in sender: - sender = ( - os.environ.get("FROM_EMAIL") - or os.environ.get("SMTP_USERNAME") - or "noreply@node-wire.local" - ).strip(" '\"") - - sender = _extract_email(sender) - recipient = _extract_email(to_email) - - logger.info("SMTP Tool | from=%s to=%s subject=%s", sender, recipient, subject) - - params = SmtpSendInput( - host=smtp_host, - port=smtp_port, - use_tls=smtp_use_tls, - username_secret_key="SMTP_USERNAME", - password_secret_key="SMTP_PASSWORD", - from_email=sender, - to=[recipient], - subject=subject, - body=body, - ) - result = await smtp.internal_execute(params, trace_id=trace_id) - return {"sent": result.sent, "message_id": getattr(result, "message_id", None)} - - return mcp - - def main() -> None: - server = _make_server() - logger.info("Starting nw-smtp MCP server (stdio transport)") - server.run() + from bindings.mcp_server.server import McpServer + + logger.info("Starting nw-smtp MCP server (stdio, manifest-driven)") + McpServer( + server_name="nw-smtp", + connector_ids=["smtp"], + ).run_stdio() if __name__ == "__main__": main() - diff --git a/src/agents/toolhive.py b/src/agents/toolhive.py index 884f949..a3a3cdc 100644 --- a/src/agents/toolhive.py +++ b/src/agents/toolhive.py @@ -4,9 +4,9 @@ A ReAct-style AI agent that connects to an MCP server running in ToolHive, discovers its tools, and orchestrates a healthcare workflow: - 1. Fetch patient details via fhir_cerner_read_patient or fhir_epic_read_patient - 2. Write a patient summary file via google_drive_upload_file - 3. Email the summary via smtp_send_email + 1. Fetch patient details via fhir_cerner.read_patient / fhir_epic.read_patient (or search_* tools) + 2. Write a patient summary file via google_drive.files.upload + 3. Email the summary via smtp.send_email The LLM backend is fully configurable via the LLM_PROVIDER env var. @@ -23,6 +23,7 @@ Environment variables: TOOLHIVE_MCP_URL : MCP proxy URL from ToolHive UI (e.g. http://localhost:PORT/mcp) TOOLHIVE_MCP_URLS: Comma-separated MCP proxy URLs (multi-server) + TOOLHIVE_MAX_TOOL_FAILURES: Stop after this many failed invocations per tool name (default: 2) LLM_PROVIDER : groq | openai | gemini | anthropic (default: groq) GROQ_API_KEY : (when using groq) OPENAI_API_KEY : (when using openai) @@ -52,6 +53,77 @@ logger = logging.getLogger("agents.toolhive") +def truncate_tool_result_for_llm(text: str) -> str: + """ + Cap tool output size sent to the LLM so providers with strict limits (e.g. Groq + on-demand TPM) do not fail with 413 / oversized requests after large FHIR payloads. + + Full raw output remains in AgentStep.tool_result for logging; only the message + passed back into the chat is truncated. + + Override with env TOOLHIVE_MAX_TOOL_RESULT_CHARS (default 12000). Use 0 to disable. + """ + raw = (os.environ.get("TOOLHIVE_MAX_TOOL_RESULT_CHARS") or "12000").strip() + try: + max_chars = int(raw) + except ValueError: + max_chars = 12000 + if max_chars <= 0 or len(text) <= max_chars: + return text + omitted = len(text) - max_chars + return ( + text[:max_chars] + + "\n\n[... truncated " + + str(omitted) + + " characters for LLM context limits; use visible fields for next steps.]" + ) + + +def resolve_max_tool_failures(override: Optional[int] = None) -> int: + """ + Max failed tool invocations per tool name before aborting the agent run. + ``override`` wins; otherwise ``TOOLHIVE_MAX_TOOL_FAILURES`` (default 2). Minimum 1. + """ + if override is not None: + return max(1, int(override)) + raw = (os.environ.get("TOOLHIVE_MAX_TOOL_FAILURES") or "2").strip() + try: + n = int(raw) + except ValueError: + n = 2 + return max(1, n) + + +def _is_tool_failure(tool_result: str) -> bool: + """True if MCP/connector reported a failed tool outcome (not empty success).""" + if not tool_result or not tool_result.strip(): + return False + t = tool_result.strip() + if t.startswith("ERROR:"): + return True + low = t.lower() + if "input validation error" in low: + return True + if "validation error" in low and "input" in low: + return True + if t.startswith("{"): + try: + data = json.loads(t) + if isinstance(data, dict) and data.get("success") is False: + return True + except json.JSONDecodeError: + pass + return False + + +def _tool_failure_abort_message(tool_name: str, max_failures: int) -> str: + return ( + f'The tool "{tool_name}" failed {max_failures} times in a row. ' + "Please check the parameters against the schema from tools/list, " + "or tell me if I should use a different tool or approach." + ) + + # --------------------------------------------------------------------------- # Result model # --------------------------------------------------------------------------- @@ -306,7 +378,7 @@ class ToolHiveAgent: 2. Enters a ReAct loop: send task + tools to LLM → if tool call → invoke tool → append result → repeat. 3. Stops when the LLM returns a final answer (no tool calls) or - ``max_steps`` is reached. + ``max_steps`` is reached, or the same tool fails ``max_tool_failures`` times. """ def __init__( @@ -314,20 +386,30 @@ def __init__( mcp_client: McpClient, llm_provider: Any, # BaseLLMProvider max_steps: int = 10, + max_tool_failures: Optional[int] = None, ) -> None: self._mcp = mcp_client self._llm = llm_provider self._max_steps = max_steps + self._max_tool_failures = resolve_max_tool_failures(max_tool_failures) self._system_prompt: str = ( "You are a healthcare data assistant. You have access to tools for fetching " "patient data from Cerner FHIR and Epic FHIR, uploading files to Google Drive, and sending " - "emails via SMTP.\n\n" + "emails via SMTP.\n" + "Tool names are `.` (e.g. `fhir_cerner.read_patient`, " + "`fhir_epic.read_patient`, `google_drive.files.upload`, `smtp.send_email`). " + "Use exactly the names and JSON-schema arguments from tools/list.\n\n" "WORKFLOW (MUST EXECUTE SEQUENTIALLY, ONE STRICT STEP AT A TIME):\n" "When asked to 'Send patient summaries via email' or similar tasks, you MUST follow this exact flow in order. DO NOT parallelize these steps:\n" - " 1. First turn: Search for the patient. (If you have a Patient ID, you DO NOT need their name or birthdate).\n" - " CRITICAL: If the user has NOT provided a patient ID or name in their message, you MUST ASK them for it. DO NOT call the search tool with a guessed or hallucinated ID like '12345'.\n" + " 1. First turn: Obtain patient demographics from the EHR.\n" + " - If the user gave a Patient ID: call `fhir_cerner.read_patient` or `fhir_epic.read_patient` with JSON `{\"resource_id\": \"\"}` (use Epic when the ID starts with 'e'). Do NOT use search_patients for a known ID.\n" + " - If there is NO Patient ID but there IS a name: use name fields or `search_patients` per tools/list schema (e.g. `given_name`, `family_name`, `birthdate`, or valid `search_params`).\n" + " - Use `search_patients` only when you have no ID, or after `read_patient` failed and you need a fallback.\n" + " CRITICAL: If the user has NOT provided a patient ID or name in their message, you MUST ASK them for it. DO NOT call tools with a guessed or hallucinated ID like '12345'.\n" " 2. Second turn: Once you have the patient data from step 1, create a file on Google Drive containing the masked patient summary. Do NOT use placeholder content.\n" - " 3. Third turn: Once step 2 returns a 'web_view_link', send an email with that exact link. Do NOT call the email tool until you have the link.\n" + " For `google_drive.files.upload`, pass a flat JSON object: `name`, `mime_type` (snake_case — not `mimeType`), `parents`, and `content` (or `content_base64`). " + "If you include `action`, it must be exactly `files.upload`. Do not nest fields under a `file` object. Do NOT pass `media` / `media_body`.\n" + " 3. Third turn: Once step 2 returns a shareable Drive URL (see `data.raw.webViewLink` from tool `google_drive.files.upload`), send an email with that exact link. Do NOT call the email tool until you have the link.\n" " CRITICAL: You MUST ask the user for the recipient email address if they haven't provided it. DO NOT guess email addresses like 'recipient_email@example.com'.\n" " CRITICAL: In the email body, you MUST insert the actual URL string returned from step 2 (e.g. 'https://drive.google.com/...'). Do NOT literally write the text ''.\n\n" "DATA PRIVACY & MASKING — follow these strictly:\n" @@ -337,7 +419,7 @@ def __init__( " - NEVER use the placeholder values ('1990-05-12', '12724066', or 'Name') in your reports - always use the real patient data masked accordingly.\n" "- EMAIL WORKFLOW: When sending patient details to an email recipient:\n" " 1. ALWAYS upload the masked patient summary to Google Drive first.\n" - " 2. Use the 'web_view_link' returned by the google_drive_upload_file tool.\n" + " 2. Use `data.raw.webViewLink` from the `google_drive.files.upload` tool result.\n" " 3. In the email body, provide that link instead of the actual data.\n" " 4. The email body should be professional: 'Patient data summary from the EHR is available at the following secure link: [Link]'\n\n" "GUARDRAILS:\n" @@ -383,6 +465,9 @@ async def run(self, task: str) -> AgentRunResult: ] # 3. ReAct loop + tool_failures: Dict[str, int] = {} + abort_after_tool_failures = False + for step_num in range(1, self._max_steps + 1): logger.info("Agent step %d / %d", step_num, self._max_steps) @@ -428,12 +513,34 @@ async def run(self, task: str) -> AgentRunResult: agent_step.tool_result = tool_result_str result.steps.append(agent_step) + llm_tool_content = truncate_tool_result_for_llm(tool_result_str) + if len(llm_tool_content) < len(tool_result_str): + logger.info( + "Tool %s result truncated for LLM: %d -> %d chars", + tc.name, + len(tool_result_str), + len(llm_tool_content), + ) + messages.append(LLMMessage( role="tool", - content=tool_result_str, + content=llm_tool_content, tool_call_id=tc.id, name=tc.name, )) + + if _is_tool_failure(tool_result_str): + tool_failures[tc.name] = tool_failures.get(tc.name, 0) + 1 + if tool_failures[tc.name] >= self._max_tool_failures: + msg = _tool_failure_abort_message(tc.name, self._max_tool_failures) + result.error = msg + result.final_answer = msg + logger.warning("Stopping agent: %s", msg) + abort_after_tool_failures = True + break + + if abort_after_tool_failures: + break else: # Hit max_steps without a final answer result.error = f"Agent reached max_steps ({self._max_steps}) without completing the task." @@ -474,10 +581,20 @@ async def _run_agent(args: argparse.Namespace) -> None: # Use the client (handle async context for stdio) if isinstance(mcp_client_context, StdioMcpClient): async with mcp_client_context as mcp_client: - agent = ToolHiveAgent(mcp_client, provider, max_steps=args.max_steps) + agent = ToolHiveAgent( + mcp_client, + provider, + max_steps=args.max_steps, + max_tool_failures=args.max_tool_failures, + ) await _execute_task(agent, args, llm_provider_name, "local-stdio") else: - agent = ToolHiveAgent(mcp_client_context, provider, max_steps=args.max_steps) + agent = ToolHiveAgent( + mcp_client_context, + provider, + max_steps=args.max_steps, + max_tool_failures=args.max_tool_failures, + ) await _execute_task(agent, args, llm_provider_name, ",".join(urls)) @@ -535,6 +652,12 @@ def main() -> None: parser.add_argument("--recipient-email", required=True, help="Email address to send the summary to") parser.add_argument("--drive-folder-id", default=os.environ.get("GOOGLE_DRIVE_FOLDER_ID", ""), help="Google Drive folder ID (optional)") parser.add_argument("--max-steps", type=int, default=10, help="Maximum agent steps (default: 10)") + parser.add_argument( + "--max-tool-failures", + type=int, + default=None, + help="Stop after this many failed calls per tool name (default: env TOOLHIVE_MAX_TOOL_FAILURES or 2)", + ) parser.add_argument("--local", action="store_true", help="Run against local server via stdio (no proxy)") args = parser.parse_args() diff --git a/src/bindings/mcp_server/server.py b/src/bindings/mcp_server/server.py index 5bbed57..29fb3a4 100644 --- a/src/bindings/mcp_server/server.py +++ b/src/bindings/mcp_server/server.py @@ -2,7 +2,7 @@ import json import logging -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from bindings.factory import ConnectorFactory from connectors import auto_register @@ -12,17 +12,182 @@ logger = logging.getLogger("bindings.mcp_server") +def _split_ids(value: Any) -> List[str]: + """Turn comma-separated string or list into a list of non-empty IDs.""" + if value is None: + return [] + if isinstance(value, list): + return [str(x).strip() for x in value if str(x).strip()] + s = str(value).strip() + if not s: + return [] + return [p.strip() for p in s.split(",") if p.strip()] + + +def _normalize_search_params_keys(sp: Dict[str, Any]) -> Dict[str, Any]: + """Map legacy/LLM keys inside search_params to FHIR-friendly names.""" + if not sp: + return {} + out = dict(sp) + # patientId is not a standard FHIR Patient search param; identifier is typical for MRN-style lookup + if "patientId" in out and "identifier" not in out: + out["identifier"] = out.pop("patientId") + if "givenName" in out and "given" not in out: + out["given"] = out.pop("givenName") + if "familyName" in out and "family" not in out: + out["family"] = out.pop("familyName") + return out + + +def _is_missing_or_blank(value: Any) -> bool: + if value is None: + return True + if isinstance(value, str) and not value.strip(): + return True + return False + + +def _normalize_google_drive_files_upload(args: Dict[str, Any]) -> None: + """ + Map common LLM mistakes for files.upload to FilesUploadOperation fields. + Mutates args in place. Canonical keys already set on the root win over aliases/nesting. + """ + # Legacy alias: callers sometimes pass a `media` object/string (Google SDK-ish). + # Our connector schema is strict (extra=forbid); normalize `media` into canonical + # `content` (text) / `content_base64` (binary) + metadata, then drop it. + media = args.get("media") + if media is not None: + # Metadata aliases under media + if isinstance(media, dict): + if _is_missing_or_blank(args.get("name")) and not _is_missing_or_blank( + media.get("name") + ): + args["name"] = media.get("name") + + if _is_missing_or_blank(args.get("mime_type")): + mt = media.get("mime_type") or media.get("mimeType") + if not _is_missing_or_blank(mt): + args["mime_type"] = mt + + if _is_missing_or_blank(args.get("parents")): + parents = media.get("parents") + if isinstance(parents, list) and parents: + args["parents"] = parents + elif isinstance(parents, str) and parents.strip(): + args["parents"] = _split_ids(parents) + + # Content aliases under media (prefer binary if provided) + if _is_missing_or_blank(args.get("content_base64")) and _is_missing_or_blank( + args.get("content") + ): + b64 = ( + media.get("content_base64") + or media.get("base64") + or media.get("data") + ) + if not _is_missing_or_blank(b64): + args["content_base64"] = b64 + else: + text = media.get("content") or media.get("text") or media.get("body") + if not _is_missing_or_blank(text): + args["content"] = text + elif isinstance(media, str): + # Treat plain-string media as text content. + if _is_missing_or_blank(args.get("content_base64")) and _is_missing_or_blank( + args.get("content") + ): + if media.strip(): + args["content"] = media + + args.pop("media", None) + + # Some clients also try `media_body` (googleapiclient kwarg). It is never part of + # the MCP schema; drop it so canonical fields can validate. + args.pop("media_body", None) + + nested = args.get("file") + if isinstance(nested, dict): + for key in ("name", "mime_type", "parents", "content", "content_base64"): + if key in nested and _is_missing_or_blank(args.get(key)): + args[key] = nested[key] + if _is_missing_or_blank(args.get("mime_type")) and nested.get("mimeType"): + args["mime_type"] = nested["mimeType"] + args.pop("file", None) + + if not _is_missing_or_blank(args.get("mimeType")) and _is_missing_or_blank( + args.get("mime_type") + ): + args["mime_type"] = args["mimeType"] + args.pop("mimeType", None) + + if args.get("action") == "upload": + args["action"] = "files.upload" + + +def normalize_mcp_tool_arguments( + connector_id: str, action: str, arguments: Dict[str, Any] +) -> Dict[str, Any]: + """ + Map legacy FastMCP / LLM aliases to canonical connector schema fields. + + Conservative: if canonical keys are already set, aliases are ignored. + """ + args = dict(arguments) + + if connector_id in ("fhir_cerner", "fhir_epic") and action == "read_patient": + if not (args.get("resource_id") or "").strip(): + pid = args.get("patient_id") or args.get("patientId") + if pid is not None and str(pid).strip(): + args["resource_id"] = str(pid).strip() + args.pop("patient_id", None) + args.pop("patientId", None) + if not args.get("family_name") and args.get("familyName"): + args["family_name"] = args.pop("familyName") + if not args.get("given_name") and args.get("givenName"): + args["given_name"] = args.pop("givenName") + if args.get("search_params") and isinstance(args["search_params"], dict): + args["search_params"] = _normalize_search_params_keys(args["search_params"]) + + elif connector_id in ("fhir_cerner", "fhir_epic") and action == "search_patients": + if not args.get("resource_ids"): + raw = args.get("patient_ids") or args.get("patientIds") + ids = _split_ids(raw) + if ids: + args["resource_ids"] = ids + args.pop("patient_ids", None) + args.pop("patientIds", None) + if not args.get("family_name") and args.get("familyName"): + args["family_name"] = args.pop("familyName") + if not args.get("given_name") and args.get("givenName"): + args["given_name"] = args.pop("givenName") + if args.get("search_params") and isinstance(args["search_params"], dict): + args["search_params"] = _normalize_search_params_keys(args["search_params"]) + + elif connector_id == "google_drive" and action == "files.upload": + _normalize_google_drive_files_upload(args) + + return args + + class McpServer: """ - Minimal MCP-style server abstraction for the POC. + Manifest-driven MCP server: tools come from connector metadata; execution + dispatches through ConnectorFactory and connector.run(). - This does not implement the full Model Context Protocol over JSON-RPC, - but exposes two conceptual operations: - - list_tools(): returns connector/actions manifest - - invoke_tool(name, arguments): executes the corresponding connector + Use list_tools() / invoke_tool() for programmatic access, or run_stdio() + for a full MCP stdio transport. """ - def __init__(self) -> None: + def __init__( + self, + *, + server_name: str = "node-wire", + connector_ids: Optional[List[str]] = None, + ) -> None: + self._server_name = server_name + self._connector_ids: Optional[frozenset[str]] = ( + None if connector_ids is None else frozenset(connector_ids) + ) auto_register() self._factory = ConnectorFactory() self._factory.load() @@ -32,11 +197,15 @@ def list_tools(self) -> List[Dict[str, Any]]: manifest = build_manifest(connectors) tools: List[Dict[str, Any]] = [] for entry in manifest: + cid = entry["connector_id"] + if self._connector_ids is not None and cid not in self._connector_ids: + continue tools.append( { - "name": f"{entry['connector_id']}.{entry['action']}", - "description": f"{entry['connector_id']} {entry['action']} connector action", + "name": f"{cid}.{entry['action']}", + "description": f"{cid} {entry['action']} connector action", "input_schema": entry["input_schema"], + "output_schema": entry["output_schema"], } ) return tools @@ -47,20 +216,63 @@ async def invoke_tool(self, name: str, arguments: Dict[str, Any]) -> Dict[str, A except ValueError: raise ValueError("Tool name must be in the form '.'") + if self._connector_ids is not None and connector_id not in self._connector_ids: + raise ValueError( + f"Connector {connector_id!r} is not allowed on this MCP server." + ) + connector = self._factory.get_for_protocol(connector_id, "mcp") if connector is None: raise ValueError(f"Connector {connector_id!r} is not available via MCP.") - run_args = dict(arguments) + run_args = normalize_mcp_tool_arguments(connector_id, action, arguments) if isinstance(connector, SDKConnector): run_args.setdefault("action", action) response = await connector.run(run_args) return response.model_dump() + async def _run_stdio_async(self) -> None: + from mcp.server import NotificationOptions, Server as LowLevelServer + from mcp.server.stdio import stdio_server + from mcp.types import Tool + + low = LowLevelServer(self._server_name) + + @low.list_tools() + async def handle_list_tools() -> list[Tool]: + out: list[Tool] = [] + for t in self.list_tools(): + kwargs: Dict[str, Any] = { + "name": t["name"], + "description": t["description"], + "inputSchema": t["input_schema"], + } + if t.get("output_schema") is not None: + kwargs["outputSchema"] = t["output_schema"] + out.append(Tool(**kwargs)) + return out + + @low.call_tool() + async def handle_call_tool(tool_name: str, arguments: dict) -> dict: + return await self.invoke_tool(tool_name, arguments or {}) + + async with stdio_server() as (read_stream, write_stream): + await low.run( + read_stream, + write_stream, + low.create_initialization_options( + notification_options=NotificationOptions() + ), + ) + + def run_stdio(self) -> None: + import anyio + + anyio.run(self._run_stdio_async) + if __name__ == "__main__": # Simple demo runner that prints tool list and exits. server = McpServer() print(json.dumps(server.list_tools(), indent=2)) - diff --git a/src/connectors/smtp/schema.py b/src/connectors/smtp/schema.py index 1698024..9724498 100644 --- a/src/connectors/smtp/schema.py +++ b/src/connectors/smtp/schema.py @@ -1,23 +1,86 @@ from __future__ import annotations -from typing import List, Optional +import os +import re +from typing import Any, List, Optional -from pydantic import BaseModel, EmailStr +from pydantic import BaseModel, EmailStr, model_validator + + +def _strip_env(s: str) -> str: + return s.strip(" '\"") + + +def _extract_email(value: str) -> str: + """Pydantic EmailStr does not accept 'Name '.""" + match = re.search(r"<(.+?)>", value) + return (match.group(1) if match else value).strip() class SmtpSendInput(BaseModel): - host: str - port: int + """ + SMTP send payload. Connection settings default from environment when omitted + so MCP/REST callers only need to, subject, body. + """ + + host: str = "" + port: int = 0 use_tls: bool = True - username_secret_key: str - password_secret_key: str + username_secret_key: str = "SMTP_USERNAME" + password_secret_key: str = "SMTP_PASSWORD" from_email: EmailStr to: List[EmailStr] subject: str body: str + @model_validator(mode="before") + @classmethod + def _fill_env_and_normalize(cls, values: Any) -> Any: + if not isinstance(values, dict): + return values + + if not (values.get("host") or "").strip(): + values["host"] = _strip_env(os.environ.get("SMTP_HOST", "smtp.gmail.com")) + port_raw = values.get("port") + if port_raw in (None, "", 0): + values["port"] = int(_strip_env(os.environ.get("SMTP_PORT", "587"))) + if "use_tls" not in values: + values["use_tls"] = ( + os.environ.get("SMTP_USE_TLS", "true").lower() == "true" + ) + if not values.get("username_secret_key"): + values["username_secret_key"] = "SMTP_USERNAME" + if not values.get("password_secret_key"): + values["password_secret_key"] = "SMTP_PASSWORD" + + fe = values.get("from_email") + if fe is None or not str(fe).strip(): + values["from_email"] = _strip_env( + os.environ.get("FROM_EMAIL") + or os.environ.get("SMTP_USERNAME") + or "noreply@node-wire.local" + ) + else: + values["from_email"] = _extract_email(_strip_env(str(fe))) + + # Guardrail: reject placeholder / invalid sender hints from callers + sender = str(values["from_email"]) + if not sender or "@" not in sender or "system_default" in sender: + values["from_email"] = _strip_env( + os.environ.get("FROM_EMAIL") + or os.environ.get("SMTP_USERNAME") + or "noreply@node-wire.local" + ) + + raw_to = values.get("to") + if isinstance(raw_to, str): + values["to"] = [_extract_email(raw_to)] + elif isinstance(raw_to, list): + values["to"] = [_extract_email(str(x)) for x in raw_to] + + return values + class SmtpSendOutput(BaseModel): sent: bool message_id: Optional[str] = None - diff --git a/tests/test_connectors_basic.py b/tests/test_connectors_basic.py index f5db5b7..527feb0 100644 --- a/tests/test_connectors_basic.py +++ b/tests/test_connectors_basic.py @@ -15,7 +15,7 @@ class DummySecretProvider(SecretProvider): def __init__(self) -> None: - self._store = {"stripe_api_key": "sk_test_dummy", "smtp_user": "user", "smtp_pass": "pass"} + self._store = {"STRIPE_API_KEY": "sk_test_dummy", "smtp_user": "user", "smtp_pass": "pass"} def get_secret(self, key: str) -> str: return self._store[key] diff --git a/tests/test_sdk_connector_manifest.py b/tests/test_sdk_connector_manifest.py index 504d5a1..c637f60 100644 --- a/tests/test_sdk_connector_manifest.py +++ b/tests/test_sdk_connector_manifest.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + from bindings.factory import ConnectorFactory from connectors import auto_register from connectors.manifest import build_manifest @@ -56,3 +58,339 @@ def test_mcp_tool_invoke_sets_action(): names = {t["name"] for t in tools} assert "google_drive.files.list" in names assert "stripe.charge" in names + + +def test_mcp_server_list_tools_includes_output_schema(): + from bindings.mcp_server.server import McpServer + + server = McpServer() + tools = server.list_tools() + assert tools + assert all("output_schema" in t for t in tools) + + +def test_mcp_server_connector_ids_filters_list_tools(): + from bindings.mcp_server.server import McpServer + + server = McpServer(connector_ids=["fhir_cerner"]) + names = {t["name"] for t in server.list_tools()} + assert names + assert all(n.startswith("fhir_cerner.") for n in names) + assert "fhir_epic.read_patient" not in names + + +@pytest.mark.asyncio +async def test_mcp_server_invoke_rejects_disallowed_connector() -> None: + from bindings.mcp_server.server import McpServer + + server = McpServer(connector_ids=["google_drive"]) + with pytest.raises(ValueError, match="not allowed"): + await server.invoke_tool( + "smtp.send_email", + {"to": ["doc@example.com"], "subject": "x", "body": "y"}, + ) + + +def test_mcp_server_run_stdio_smoke(): + from bindings.mcp_server.server import McpServer + + server = McpServer() + assert callable(server.run_stdio) + assert callable(server._run_stdio_async) + + +def test_normalize_mcp_tool_arguments_read_patient_maps_legacy_ids(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.fhir_cerner.schema import FhirCernerPatientReadInput + from connectors.fhir_epic.schema import FhirPatientReadInput as FhirEpicPatientReadInput + + for cid in ("fhir_cerner", "fhir_epic"): + out = normalize_mcp_tool_arguments( + cid, + "read_patient", + {"patientId": "12724066"}, + ) + assert out["resource_id"] == "12724066" + assert "patientId" not in out + model = FhirCernerPatientReadInput if cid == "fhir_cerner" else FhirEpicPatientReadInput + model.model_validate({**out, "action": "read_patient"}) + + # Canonical resource_id wins over alias + out2 = normalize_mcp_tool_arguments( + "fhir_cerner", + "read_patient", + {"resource_id": "111", "patient_id": "222"}, + ) + assert out2["resource_id"] == "111" + + out3 = normalize_mcp_tool_arguments( + "fhir_cerner", + "read_patient", + {"familyName": "Smith", "givenName": "John"}, + ) + assert out3["family_name"] == "Smith" + assert out3["given_name"] == "John" + + +def test_normalize_mcp_tool_arguments_search_patients_maps_legacy(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.fhir_cerner.schema import FhirCernerPatientSearchInput + + out = normalize_mcp_tool_arguments( + "fhir_cerner", + "search_patients", + {"patient_ids": "12724066,12724067"}, + ) + assert out["resource_ids"] == ["12724066", "12724067"] + + out2 = normalize_mcp_tool_arguments( + "fhir_cerner", + "search_patients", + {"search_params": {"patientId": "12724066"}}, + ) + assert out2["search_params"]["identifier"] == "12724066" + assert "patientId" not in out2["search_params"] + + FhirCernerPatientSearchInput.model_validate( + {**out2, "action": "search_patients"} + ) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_mime_type_alias(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "name": "a.txt", + "mimeType": "text/plain", + "parents": ["folder1"], + "content": "hello", + }, + ) + assert out["mime_type"] == "text/plain" + assert "mimeType" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_action_upload(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "action": "upload", + "name": "a.txt", + "mime_type": "text/plain", + "content": "x", + }, + ) + assert out["action"] == "files.upload" + FilesUploadOperation.model_validate(out) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_nested_file(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "content": "body", + "file": { + "mime_type": "text/plain", + "name": "nested.txt", + "parents": ["p1"], + }, + }, + ) + assert out["name"] == "nested.txt" + assert out["mime_type"] == "text/plain" + assert out["parents"] == ["p1"] + assert "file" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_media_string_maps_to_content(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "name": "a.txt", + "mime_type": "text/plain", + "media": "hello", + }, + ) + assert out["content"] == "hello" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_media_object_text_alias_maps_to_content(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "name": "a.txt", + "mime_type": "text/plain", + "media": {"text": "hello"}, + }, + ) + assert out["content"] == "hello" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_media_object_base64_maps_to_content_base64(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "name": "a.pdf", + "mime_type": "application/pdf", + "media": {"base64": "Zg=="}, + }, + ) + assert out["content_base64"] == "Zg==" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_media_metadata_aliases_are_used_when_missing(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "media": { + "name": "nested.txt", + "mimeType": "text/plain", + "parents": "p1,p2", + "content": "hi", + } + }, + ) + assert out["name"] == "nested.txt" + assert out["mime_type"] == "text/plain" + assert out["parents"] == ["p1", "p2"] + assert out["content"] == "hi" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_canonical_content_wins_over_media_alias(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "name": "root.txt", + "mime_type": "text/plain", + "content": "root", + "media": {"content": "ignored"}, + }, + ) + assert out["content"] == "root" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_canonical_mime_type_wins_over_nested(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "mime_type": "text/plain", + "name": "root.txt", + "content": "c", + "file": {"mime_type": "application/json", "name": "ignored.txt"}, + }, + ) + assert out["mime_type"] == "text/plain" + assert out["name"] == "root.txt" + + +@pytest.mark.asyncio +async def test_mcp_server_invoke_tool_passes_normalized_payload_to_connector_run() -> None: + """invoke_tool should apply normalization before BaseConnector.run (SDK action).""" + from bindings.mcp_server.server import McpServer + from runtime.models import ConnectorResponse + + server = McpServer(connector_ids=["fhir_cerner"]) + cerner = server._factory.get_for_protocol("fhir_cerner", "mcp") + assert cerner is not None + + captured: dict = {} + + async def fake_run(raw_input): + captured["payload"] = dict(raw_input) + return ConnectorResponse(success=True, data={"resource": {"id": "12724066"}}, trace_id="t") + + orig_run = cerner.run + try: + cerner.run = fake_run + await server.invoke_tool("fhir_cerner.read_patient", {"patientId": "12724066"}) + finally: + cerner.run = orig_run + + assert captured["payload"]["resource_id"] == "12724066" + assert captured["payload"].get("action") == "read_patient" + + +@pytest.mark.asyncio +async def test_mcp_server_invoke_google_drive_files_upload_normalizes_payload() -> None: + """invoke_tool should normalize Drive upload aliases before connector.run.""" + from bindings.mcp_server.server import McpServer + from runtime.models import ConnectorResponse + + server = McpServer(connector_ids=["google_drive"]) + gdrive = server._factory.get_for_protocol("google_drive", "mcp") + assert gdrive is not None + + captured: dict = {} + + async def fake_run(raw_input): + captured["payload"] = dict(raw_input) + return ConnectorResponse(success=True, data={"raw": {}}, trace_id="t") + + orig_run = gdrive.run + try: + gdrive.run = fake_run + await server.invoke_tool( + "google_drive.files.upload", + { + "mimeType": "text/plain", + "name": "patient_summary.txt", + "parents": ["folder-id"], + "content": "summary", + "media": {"content": "ignored"}, + "action": "upload", + }, + ) + finally: + gdrive.run = orig_run + + assert captured["payload"]["mime_type"] == "text/plain" + assert captured["payload"]["action"] == "files.upload" + assert "mimeType" not in captured["payload"] + assert "media" not in captured["payload"] diff --git a/tests/test_toolhive_agent.py b/tests/test_toolhive_agent.py index 50aa61a..8cd10a8 100644 --- a/tests/test_toolhive_agent.py +++ b/tests/test_toolhive_agent.py @@ -22,7 +22,45 @@ LLMResponse, ToolCall, ) -from agents.toolhive import AgentRunResult, ToolHiveAgent, ToolHiveMcpClient +from agents.toolhive import ( + AgentRunResult, + ToolHiveAgent, + ToolHiveMcpClient, + _is_tool_failure, + resolve_max_tool_failures, + truncate_tool_result_for_llm, +) + + +def test_truncate_tool_result_for_llm_respects_limit(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("TOOLHIVE_MAX_TOOL_RESULT_CHARS", "20") + long = "x" * 100 + out = truncate_tool_result_for_llm(long) + assert len(out) > 20 + assert out.startswith("x" * 20) + assert "truncated" in out + + +def test_truncate_tool_result_for_llm_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("TOOLHIVE_MAX_TOOL_RESULT_CHARS", "0") + long = "y" * 5000 + assert truncate_tool_result_for_llm(long) == long + + +def test_is_tool_failure_detects_validation_and_error_prefix() -> None: + assert _is_tool_failure("Input validation error: bad") + assert _is_tool_failure("ERROR: connection refused") + assert _is_tool_failure('{"success": false, "message": "x"}') + assert not _is_tool_failure("") + assert not _is_tool_failure('{"success": true, "data": {}}') + + +def test_resolve_max_tool_failures_env_and_override(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("TOOLHIVE_MAX_TOOL_FAILURES", raising=False) + assert resolve_max_tool_failures(None) == 2 + monkeypatch.setenv("TOOLHIVE_MAX_TOOL_FAILURES", "5") + assert resolve_max_tool_failures(None) == 5 + assert resolve_max_tool_failures(3) == 3 # --------------------------------------------------------------------------- @@ -31,37 +69,38 @@ SAMPLE_TOOLS = [ { - "name": "fhir_cerner_read_patient", + "name": "fhir_cerner.read_patient", "description": "Fetch a patient from Cerner FHIR", "input_schema": { "type": "object", - "properties": {"patient_id": {"type": "string"}}, - "required": ["patient_id"], + "properties": {"resource_id": {"type": "string"}}, + "required": ["resource_id"], }, }, { - "name": "google_drive_upload_file", + "name": "google_drive.files.upload", "description": "Upload a file to Google Drive", "input_schema": { "type": "object", "properties": { - "file_name": {"type": "string"}, + "name": {"type": "string"}, + "mime_type": {"type": "string"}, "content": {"type": "string"}, }, - "required": ["file_name", "content"], + "required": ["name", "mime_type", "content"], }, }, { - "name": "smtp_send_email", + "name": "smtp.send_email", "description": "Send an email via SMTP", "input_schema": { "type": "object", "properties": { - "to_email": {"type": "string"}, + "to": {"type": "array", "items": {"type": "string"}}, "subject": {"type": "string"}, "body": {"type": "string"}, }, - "required": ["to_email", "subject", "body"], + "required": ["to", "subject", "body"], }, }, ] @@ -142,19 +181,37 @@ async def test_agent_runs_three_tool_sequence() -> None: # Step 1: Call FHIR LLMResponse( content=None, - tool_calls=[_tool_call("fhir_cerner_read_patient", {"patient_id": "12724066"})], + tool_calls=[_tool_call("fhir_cerner.read_patient", {"resource_id": "12724066"})], stop_reason="tool_calls", ), # Step 2: Call Drive LLMResponse( content=None, - tool_calls=[_tool_call("google_drive_upload_file", {"file_name": "summary.txt", "content": "Patient: John"})], + tool_calls=[ + _tool_call( + "google_drive.files.upload", + { + "name": "summary.txt", + "mime_type": "text/plain", + "content": "Patient: John", + }, + ) + ], stop_reason="tool_calls", ), # Step 3: Send email LLMResponse( content=None, - tool_calls=[_tool_call("smtp_send_email", {"to_email": "doc@example.com", "subject": "Summary", "body": "Patient: John"})], + tool_calls=[ + _tool_call( + "smtp.send_email", + { + "to": ["doc@example.com"], + "subject": "Summary", + "body": "Patient: John", + }, + ) + ], stop_reason="tool_calls", ), # Final answer @@ -173,21 +230,47 @@ async def test_agent_runs_three_tool_sequence() -> None: assert result.success is True assert result.final_answer == "All 3 steps completed successfully." assert len(result.steps) == 3 - assert result.steps[0].tool_called == "fhir_cerner_read_patient" - assert result.steps[1].tool_called == "google_drive_upload_file" - assert result.steps[2].tool_called == "smtp_send_email" + assert result.steps[0].tool_called == "fhir_cerner.read_patient" + assert result.steps[1].tool_called == "google_drive.files.upload" + assert result.steps[2].tool_called == "smtp.send_email" # Verify MCP was called exactly 3 times assert mock_mcp.call_tool.await_count == 3 +@pytest.mark.asyncio +async def test_agent_id_first_turn_calls_read_patient_with_resource_id() -> None: + """Document ID-first flow: Cerner read uses canonical resource_id (not search_patients).""" + responses = [ + LLMResponse( + content=None, + tool_calls=[_tool_call("fhir_cerner.read_patient", {"resource_id": "12724066"})], + stop_reason="tool_calls", + ), + LLMResponse(content="Patient retrieved.", tool_calls=[], stop_reason="stop"), + ] + provider = _MockLLMProvider(responses) + mock_mcp = AsyncMock(spec=ToolHiveMcpClient) + mock_mcp.list_tools.return_value = SAMPLE_TOOLS + mock_mcp.call_tool.return_value = '{"success": true}' + + agent = ToolHiveAgent(mcp_client=mock_mcp, llm_provider=provider, max_steps=10) + result = await agent.run("Patient ID 12724066 — fetch from Cerner") + + assert result.success is True + mock_mcp.call_tool.assert_awaited_once() + call = mock_mcp.call_tool.await_args + assert call[0][0] == "fhir_cerner.read_patient" + assert call[0][1]["resource_id"] == "12724066" + + @pytest.mark.asyncio async def test_agent_respects_max_steps() -> None: """Agent should stop and return an error if max_steps is reached.""" # LLM always returns a tool call — never finishes infinite_response = LLMResponse( content=None, - tool_calls=[_tool_call("fhir_cerner_read_patient", {"patient_id": "x"})], + tool_calls=[_tool_call("fhir_cerner.read_patient", {"resource_id": "x"})], stop_reason="tool_calls", ) provider = _MockLLMProvider([infinite_response]) @@ -211,7 +294,7 @@ async def test_agent_handles_tool_error_gracefully() -> None: responses = [ LLMResponse( content=None, - tool_calls=[_tool_call("fhir_cerner_read_patient", {"patient_id": "bad"})], + tool_calls=[_tool_call("fhir_cerner.read_patient", {"resource_id": "bad"})], stop_reason="tool_calls", ), LLMResponse(content="Unable to fetch patient — error recorded.", tool_calls=[], stop_reason="stop"), @@ -244,100 +327,150 @@ async def test_agent_fails_when_mcp_unreachable() -> None: assert "Failed to list MCP tools" in (result.error or "") +@pytest.mark.asyncio +async def test_agent_stops_after_repeated_tool_failures() -> None: + """After max_tool_failures for the same tool, stop without further LLM steps.""" + fail_msg = "Input validation error: bad args" + responses = [ + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {"name": "a.txt"})], + stop_reason="tool_calls", + ), + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {"name": "a.txt"})], + stop_reason="tool_calls", + ), + LLMResponse(content="should not run", tool_calls=[], stop_reason="stop"), + ] + provider = _MockLLMProvider(responses) + mock_mcp = AsyncMock(spec=ToolHiveMcpClient) + mock_mcp.list_tools.return_value = SAMPLE_TOOLS + mock_mcp.call_tool.return_value = fail_msg + + agent = ToolHiveAgent( + mcp_client=mock_mcp, + llm_provider=provider, + max_steps=10, + max_tool_failures=2, + ) + result = await agent.run("Upload to Drive") + + assert result.success is False + assert len(result.steps) == 2 + assert "google_drive.files.upload" in (result.error or "") + assert "failed 2 times" in (result.final_answer or result.error or "").lower() + assert mock_mcp.call_tool.await_count == 2 + assert provider._call_count == 2 + + +@pytest.mark.asyncio +async def test_agent_success_then_two_failures_same_tool_aborts() -> None: + """Failures only increment on failed tool results; abort after second failure.""" + ok = '{"success": true, "data": {}}' + fail_msg = "Input validation error: x" + responses = [ + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {})], + stop_reason="tool_calls", + ), + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {})], + stop_reason="tool_calls", + ), + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {})], + stop_reason="tool_calls", + ), + ] + provider = _MockLLMProvider(responses) + mock_mcp = AsyncMock(spec=ToolHiveMcpClient) + mock_mcp.list_tools.return_value = SAMPLE_TOOLS + mock_mcp.call_tool.side_effect = [ok, fail_msg, fail_msg] + + agent = ToolHiveAgent( + mcp_client=mock_mcp, + llm_provider=provider, + max_steps=10, + max_tool_failures=2, + ) + result = await agent.run("x") + + assert result.success is False + assert len(result.steps) == 3 + assert mock_mcp.call_tool.await_count == 3 + + # --------------------------------------------------------------------------- # MCP entrypoint smoke test # --------------------------------------------------------------------------- -def test_mcp_entrypoint_registers_eight_tools() -> None: - """The FastMCP server should expose the full FHIR + integration tool surface.""" - # We patch all external deps before importing the module to avoid side effects - with ( - patch("bindings.factory.ConnectorFactory") as mock_factory_cls, - patch("connectors.auto_register"), - patch("mcp.server.fastmcp.FastMCP", autospec=False) as mock_fastmcp_cls, - ): - mock_factory = MagicMock() - mock_factory._connectors = { - "fhir_cerner": MagicMock(), - "fhir_epic": MagicMock(), - "google_drive": MagicMock(), - "smtp": MagicMock(), - } - mock_factory_cls.return_value = mock_factory - - mock_mcp_instance = MagicMock() - registered_tools: List[str] = [] - - def fake_tool(*args: Any, **kwargs: Any): - name = kwargs.get("name") or (args[0] if args else "unknown") - registered_tools.append(name) - return lambda fn: fn # decorator passthrough - - mock_mcp_instance.tool = fake_tool - mock_fastmcp_cls.return_value = mock_mcp_instance - - # Import inside the test to ensure it picks up the mocks - from agents.mcp_entrypoint import _make_server - _make_server() - - assert len(registered_tools) == 8 - assert "fhir_cerner_read_patient" in registered_tools - assert "fhir_cerner_search_patients" in registered_tools - assert "fhir_cerner_search_encounters" in registered_tools - assert "fhir_epic_read_patient" in registered_tools - assert "fhir_epic_search_patients" in registered_tools - assert "fhir_epic_search_encounters" in registered_tools - assert "google_drive_upload_file" in registered_tools - assert "smtp_send_email" in registered_tools +def test_mcp_entrypoint_exposes_manifest_tools() -> None: + """Unified MCP server lists all connectors enabled for MCP in config.""" + from bindings.mcp_server.server import McpServer + + server = McpServer(server_name="node-wire") + names = {t["name"] for t in server.list_tools()} + assert "fhir_cerner.read_patient" in names + assert "fhir_epic.read_patient" in names + assert "google_drive.files.upload" in names + assert "smtp.send_email" in names + assert "stripe.charge" in names + assert "http_generic.request" in names + # Broader surface than the old 8 FastMCP tools + assert len(names) >= 18 # --------------------------------------------------------------------------- -# Individual MCP server smoke tests +# Individual MCP entrypoint modules (thin wrappers) # --------------------------------------------------------------------------- -def _make_server_smoke(module_path: str, expected_tool: str) -> None: - """Helper: verify a per-connector _make_server() registers exactly one tool.""" - with ( - patch("bindings.factory.ConnectorFactory") as mock_factory_cls, - patch("connectors.auto_register"), - patch("mcp.server.fastmcp.FastMCP", autospec=False) as mock_fastmcp_cls, - ): - mock_factory = MagicMock() - mock_factory._connectors = {} - mock_factory_cls.return_value = mock_factory - - mock_mcp_instance = MagicMock() - registered_tools: List[str] = [] - - def fake_tool(*args: Any, **kwargs: Any): - name = kwargs.get("name") or (args[0] if args else "unknown") - registered_tools.append(name) - return lambda fn: fn - - mock_mcp_instance.tool = fake_tool - mock_fastmcp_cls.return_value = mock_mcp_instance - - import importlib - mod = importlib.import_module(module_path) - mod._make_server() - - assert registered_tools == [expected_tool], ( - f"{module_path}: expected [{expected_tool}], got {registered_tools}" - ) +def test_fhir_cerner_mcp_main_callable() -> None: + from agents.fhir_cerner_mcp import main + + assert callable(main) + + +def test_fhir_epic_mcp_main_callable() -> None: + from agents.fhir_epic_mcp import main + + assert callable(main) + + +def test_google_drive_mcp_main_callable() -> None: + from agents.google_drive_mcp import main + + assert callable(main) + + +def test_smtp_mcp_main_callable() -> None: + from agents.smtp_mcp import main + + assert callable(main) + + +def test_mcp_server_matches_per_connector_entrypoints() -> None: + """Per-connector scripts use connector_ids filter; tool prefixes must match.""" + from bindings.mcp_server.server import McpServer -def test_fhir_cerner_mcp_registers_one_tool() -> None: - """fhir_cerner_mcp._make_server() should expose exactly fhir_cerner_read_patient.""" - _make_server_smoke("agents.fhir_cerner_mcp", "fhir_cerner_read_patient") + full = {t["name"] for t in McpServer().list_tools()} + cerner = {t["name"] for t in McpServer(connector_ids=["fhir_cerner"]).list_tools()} + assert cerner == {n for n in full if n.startswith("fhir_cerner.")} -def test_fhir_epic_mcp_registers_one_tool() -> None: - """fhir_epic_mcp._make_server() should expose exactly fhir_epic_read_patient.""" - _make_server_smoke("agents.fhir_epic_mcp", "fhir_epic_read_patient") + epic = {t["name"] for t in McpServer(connector_ids=["fhir_epic"]).list_tools()} + assert epic == {n for n in full if n.startswith("fhir_epic.")} + drive = {t["name"] for t in McpServer(connector_ids=["google_drive"]).list_tools()} + assert drive == {n for n in full if n.startswith("google_drive.")} + assert "google_drive.files.upload" in drive -def test_google_drive_mcp_registers_one_tool() -> None: - """google_drive_mcp._make_server() should expose exactly google_drive_upload_file.""" - _make_server_smoke("agents.google_drive_mcp", "google_drive_upload_file") + smtp = {t["name"] for t in McpServer(connector_ids=["smtp"]).list_tools()} + assert smtp == {"smtp.send_email"} From 88de5424d988afd6fd5e209535872730d9e2b19f Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Wed, 1 Apr 2026 04:05:15 -0700 Subject: [PATCH 08/60] Updated Docs and UI --- README.md | 12 ++- Setup.md | 38 +++++++++- docs/creating-a-connector.md | 128 ++++++++++++++++++++++++++++++++ docs/toolhive_agent_scenario.md | 9 +-- playground/README.md | 18 +++-- playground/index.html | 6 +- 6 files changed, 191 insertions(+), 20 deletions(-) create mode 100644 docs/creating-a-connector.md diff --git a/README.md b/README.md index 66d68ab..8eadd65 100644 --- a/README.md +++ b/README.md @@ -123,14 +123,14 @@ Examples: Google Drive has a full doc at `src/connectors/google_drive/README.md` ### gRPC / MCP - **gRPC:** Started when `MODE=GRPC`; server listens on port 50051. -- **MCP:** Started when `MODE=MCP`; server exposes tools for discovery and invocation. +- **MCP:** `MODE=MCP` starts a minimal MCP-style placeholder server (sufficient for local, manual inspection), but it is not the full stdio MCP server used for ToolHive and the agent layer. ### Entrypoint - Run with `python -m bindings_entrypoint` (or the `node-wire` script after install). The **MODE** environment variable selects: - **API** (default) – REST API on port 8000. - **GRPC** – gRPC server on port 50051. - - **MCP** – MCP server. + - **MCP** – minimal MCP-style placeholder server (see note above). --- @@ -193,3 +193,11 @@ $env:GOOGLE_DRIVE_SA_JSON = Get-Content -Path $saPath -Raw ## Dependencies All dependencies are declared in `pyproject.toml` (Python >=3.11). They include: pydantic, FastAPI, uvicorn, tenacity, pybreaker, OpenTelemetry, grpcio, and connector-specific libraries (httpx, aiosmtplib, stripe, google-auth, google-api-python-client, etc.). See `pyproject.toml` for the full list and versions. + +--- + +## Setup and development docs + +- Platform setup (REST/gRPC/agents MCP): [Setup.md](Setup.md) +- Individual connector MCP servers (ToolHive): [docs/mcp-servers.md](docs/mcp-servers.md) +- Creating a new connector: [docs/creating-a-connector.md](docs/creating-a-connector.md) diff --git a/Setup.md b/Setup.md index 7be559f..069d512 100644 --- a/Setup.md +++ b/Setup.md @@ -23,10 +23,11 @@ Node Wire is a Python framework that runs connector adapters (Google Drive, SMTP | Requirement | Version | Notes | | ----------- | ------- | --------------------------------------- | -| Python | 3.12+ | `python --version` to check | +| Python | 3.11+ | `python --version` to check | | pip or uv | Latest | `pip install --upgrade pip` | | Git | Any | To clone the repo | | Docker | Latest | Only needed for ToolHive MCP deployment | +| Node.js | Any LTS | Only needed for `npx @modelcontextprotocol/inspector` | --- @@ -36,7 +37,7 @@ Node Wire is a Python framework that runs connector adapters (Google Drive, SMTP ```bash # 1. Clone the repository git clone -cd connector-platform +cd node-wire # 2. Install dependencies (recommended: uv) uv sync --extra agents @@ -45,6 +46,8 @@ uv sync --extra agents uv run node-wire --help ``` +> **Install uv:** See the official installer docs at `https://docs.astral.sh/uv/`. +> > **REST/gRPC only** (no AI agent features): `uv sync` without the extra is sufficient. > > **Alternative (pip):** If you’re not using `uv`, install editable deps with pip: @@ -67,6 +70,8 @@ cp sample.env .env You only need to fill in the sections for the connectors you plan to use. The platform starts successfully even if some credentials are missing — those connectors will simply return an error when called. +> **Doc convention:** Environment variable names in the docs follow `sample.env`. Some legacy keys (like `stripe_api_key`) are intentionally lower-case because that is what the connector reads. + ### Environment Variable Sections @@ -74,7 +79,7 @@ You only need to fill in the sections for the connectors you plan to use. The pl | ---------------- | ------------------------------------------------------------------------------------------------------------------- | ---------------------- | | **FHIR Epic** | `EPIC_FHIR_BASE_URL`, `EPIC_TOKEN_URL`, `EPIC_CLIENT_ID`, `EPIC_KID`, `EPIC_PRIVATE_KEY` | Epic EHR integration | | **FHIR Cerner** | `CERNER_FHIR_BASE_URL`, `CERNER_TOKEN_URL`, `CERNER_CLIENT_ID`, `CERNER_KID`, `CERNER_PRIVATE_KEY`, `CERNER_SCOPES` | Cerner EHR integration | -| **Google Drive** | `google_drive_sa_json`, `GOOGLE_DRIVE_FOLDER_ID` | Google Drive connector | +| **Google Drive** | `GOOGLE_DRIVE_SA_JSON`, `GOOGLE_DRIVE_FOLDER_ID` | Google Drive connector | | **SMTP** | `SMTP_HOST`, `SMTP_PORT`, `SMTP_USERNAME`, `SMTP_PASSWORD` | Sending emails | | **LLM / Agent** | `LLM_PROVIDER`, `GROQ_API_KEY` (or other provider key) | AI agent / ToolHive | | **ToolHive** | `TOOLHIVE_MCP_URL` (single) or `TOOLHIVE_MCP_URLS` (comma-separated, multi-server) | ToolHive MCP proxy | @@ -95,6 +100,19 @@ The platform supports three modes. Set the `MODE` environment variable to switch | **gRPC** | `MODE=GRPC uv run node-wire` | `50051` | gRPC clients | | **MCP (stdio)** | `python -m agents.mcp_entrypoint` | stdio | AI agents, ToolHive, Claude Desktop | +> **Important:** `MODE=MCP` for `node-wire` / `python -m bindings_entrypoint` starts a minimal MCP-style placeholder server, not the full stdio MCP server used with ToolHive and the agent layer. For ToolHive/Inspector/agents, use `python -m agents.mcp_entrypoint` (or the per-connector MCP servers in `docs/mcp-servers.md`). + +### Configuration file (`config/connectors.yaml`) + +Connectors are loaded from `config/connectors.yaml`. Each connector has: + +- `enabled`: whether the connector is instantiated at startup +- `exposed_via`: which protocols can access it (`rest`, `grpc`, `mcp`) + +If a connector is disabled (or not exposed for a protocol), requests to it will fail with “not configured / not available” even if your `.env` is correct. + +For details on adding a new connector to the runtime, see `docs/creating-a-connector.md`. + ### REST API Quick Start @@ -106,6 +124,12 @@ uv run node-wire PORT=8001 uv run node-wire ``` +Equivalent entrypoint (without `uv`): + +```bash +MODE=API python -m bindings_entrypoint +``` + Once running: - **Health check:** `GET http://localhost:8000/health` @@ -216,7 +240,7 @@ Quick summary of what you'll need: Add to your `.env`: ```env -google_drive_sa_json=/absolute/path/to/service-account.json +GOOGLE_DRIVE_SA_JSON=/absolute/path/to/service-account.json GOOGLE_DRIVE_FOLDER_ID=your-folder-id-from-drive-url ``` @@ -315,6 +339,12 @@ npx @modelcontextprotocol/inspector python -m agents.google_drive_mcp npx @modelcontextprotocol/inspector python -m agents.mcp_entrypoint ``` +### Troubleshooting quick hits + +- **Port 8000 in use**: set `PORT=8001` (or any free port) when starting the REST API. +- **Connector “not configured”**: confirm it is `enabled: true` (and exposed for your protocol) in `config/connectors.yaml`. +- **ToolHive + Google Drive auth failure**: inside ToolHive, `GOOGLE_DRIVE_SA_JSON` must be the JSON **contents** (not a file path). Locally, it can be an absolute file path (see `docs/mcp-servers.md`). + --- ## Running Tests diff --git a/docs/creating-a-connector.md b/docs/creating-a-connector.md new file mode 100644 index 0000000..98f1bce --- /dev/null +++ b/docs/creating-a-connector.md @@ -0,0 +1,128 @@ +# Creating a connector in Node Wire + +This guide explains how to implement a new connector (Layer B) and make it available via REST/gRPC/MCP (Layer C). + +## How connectors plug into the platform + +- **Layer B (`src/connectors/`)**: connector implementations (schemas + logic). +- **Layer C (`src/bindings/`)**: protocol bindings and configuration-driven loading. + +At startup, the REST binding: + +- Imports connector `registration.py` modules via `connectors.auto_register()` so exceptions can be mapped. +- Loads and instantiates enabled connectors via `ConnectorFactory` using `config/connectors.yaml`. + +## Connector shape (single-action) + +Most connectors are a single `BaseConnector` subclass with: + +- `connector_id`: stable identifier (used in URLs and config) +- `action`: the action name (used in URLs and manifests) +- `schema.py`: Pydantic input/output models +- `logic.py`: connector implementation (`internal_execute`) +- `registration.py` (optional): register exception mappings for runtime error taxonomy + +Use the `http_generic` connector as a reference: + +- `src/connectors/http_generic/schema.py` +- `src/connectors/http_generic/logic.py` +- `src/connectors/http_generic/registration.py` (if present) + +### Minimal checklist (single-action) + +1. Create a new package: `src/connectors//`. +2. Define request/response models in `schema.py`. +3. Implement `logic.py` with a `BaseConnector[...]` subclass: + - set `connector_id = ""` + - set `action = ""` + - implement `internal_execute(self, params, *, trace_id)` +4. If you raise connector-specific exceptions, add `registration.py` and register them with the runtime `ErrorMapper` so clients get stable `error_code`/`error_category`. + +## Connector shape (multi-action) + +Some connectors expose multiple actions from a single logical integration (e.g. FHIR). In that pattern, the factory stores **one** object under a `connector_id`, and that object exposes: + +- `list_actions() -> list[BaseConnector]` +- `get_action(name: str) -> BaseConnector | None` + +The factory uses these helpers for discovery and dispatch. + +See the Epic FHIR connector implementation for the pattern: + +- `src/connectors/fhir_epic/logic.py` (`FhirEpicConnector`, `_FhirAction`, `list_actions()`, `get_action()`) + +## Wire the connector into the runtime (required) + +There are two places to update so the platform can load and expose your connector. + +### 1) Add an entry to `config/connectors.yaml` + +Add a new block under `connectors:`: + +```yaml +connectors: + my_connector: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] +``` + +- `enabled: true` controls whether the connector is instantiated. +- `exposed_via` controls which protocols can see it. + +If `enabled` is false, or if a protocol is missing from `exposed_via`, you will see “not configured / not available” errors even if your `.env` is correct. + +### 2) Add factory wiring in `src/bindings/factory.py` + +`ConnectorFactory` instantiates connectors via `_instantiate(connector_id)`. Add a branch for your `connector_id` that returns your connector instance and passes the `secret_provider`. + +For single-action connectors, the factory typically passes the input/output model classes too (example: `http_generic`, `google_drive`). + +For multi-action connectors, the factory stores one instance (example: `fhir_epic`, `fhir_cerner`), and `get_for_protocol()` uses `get_action()` when an action is requested. + +## Registration (`registration.py`) + +`connectors.auto_register()` imports `registration` modules from connector subpackages automatically: + +- A connector package may omit `registration.py` if it doesn’t need custom exception mapping. +- If present, `registration.py` should register exception types with the runtime error taxonomy so clients get predictable categories (`BUSINESS`, `AUTH`, `RETRYABLE`, `FATAL`). + +See `src/connectors/__init__.py` for the auto-discovery behavior. + +## Secrets and configuration conventions + +- Connector secrets are read via the `SecretProvider` (`self.secret_provider.get_secret("KEY")`). +- For local development, secrets are typically defined in `.env` using the names in `sample.env`. +- The platform’s `EnvSecretProvider` is case-insensitive (it checks both `KEY` and `key`), but prefer **one canonical spelling** in documentation and config. + +## How exposure works per protocol + +The REST binding exposes: + +- `POST /connectors/{connector_id}/{action}` + +Routes and schemas come from the connector manifest built over the factory’s `list_for_protocol("rest")` output. + +The same `enabled` / `exposed_via` gating applies to gRPC and the built-in MCP-style manifest. + +## Optional: MCP tools for ToolHive / agents + +This repository also includes MCP servers under `src/agents/` (for ToolHive and other MCP clients). These are separate from the REST/gRPC bindings: + +- **Combined MCP server**: `python -m agents.mcp_entrypoint` +- **Per-connector MCP servers**: `python -m agents._mcp` (see `docs/mcp-servers.md`) + +Adding a connector to the runtime (factory + YAML) does not automatically create a ToolHive-ready MCP server. If you need MCP tools, you typically add a small wrapper in `src/agents/` that calls into the connector via `ConnectorFactory`. + +## Loading flow (simplified) + +```mermaid +flowchart LR + yamlFile[config/connectors.yaml] + factory[ConnectorFactory.load] + instantiate[_instantiate connector_id] + connector[Connector_instance] + yamlFile --> factory + factory --> instantiate + instantiate --> connector +``` + diff --git a/docs/toolhive_agent_scenario.md b/docs/toolhive_agent_scenario.md index 666c39b..40fea59 100644 --- a/docs/toolhive_agent_scenario.md +++ b/docs/toolhive_agent_scenario.md @@ -62,6 +62,7 @@ For modular deployments, each connector can be run as an independent MCP server - `nw-google-drive` (Google Drive) - `nw-smartonfhir-epic` (Epic SMART on FHIR) - `nw-smartonfhir-cerner` (Cerner SMART on FHIR) +- `nw-smtp` (SMTP email) When running multiple MCP servers, configure the agent with **`TOOLHIVE_MCP_URLS`** (comma-separated list of ToolHive proxy URLs). The agent will merge tools across servers. @@ -135,7 +136,7 @@ Below is the full set of environment variables used by the connector platform an | `GROQ_API_KEY` | LLM (Groq) | Your Groq API key | | `GROQ_MODEL` | LLM | Example: `openai/gpt-oss-120b` | | `MCP_TRANSPORT` | ToolHive / local | `stdio` when running in ToolHive container | -| `PYTHONPATH` | Runtime | e.g. `/app/src` for container; `d:\connector-platform\src` locally | +| `PYTHONPATH` | Runtime | e.g. `/app/src` for container; `**/node-wire/src` locally | | `SMTP_HOST` | SMTP connector | Example: `sandbox.smtp.mailtrap.io` | | `SMTP_PORT` | SMTP connector | Example: `2525` | | `SMTP_USERNAME` | SMTP connector | Mailtrap / SMTP user | @@ -160,7 +161,7 @@ Option A — Recommended: ToolHive UI (no code) Option B — Local quick run (Windows PowerShell) -Prerequisite: Install Python 3.10+ and Git. If you cannot install, ask an administrator to run Option A. +Prerequisite: Install Python 3.11+ and Git. If you cannot install, ask an administrator to run Option A. 1. Open PowerShell and clone or navigate to the project folder. 2. Create a simple `.env` file in the project root (replace placeholder values): @@ -204,8 +205,6 @@ Notes for non-developers: From the root of the repository: ```bash -cd connector-platform - docker build -t node-wire:latest . ``` @@ -538,7 +537,7 @@ tests/test_toolhive_agent.py::test_mcp_entrypoint_registers_three_to PASSED ## File layout (`agents`) ``` -connector-platform/ +node-wire/ ├── Dockerfile ← Docker image for ToolHive ├── pyproject.toml ← [agents] extras added ├── sample.env ← env var reference diff --git a/playground/README.md b/playground/README.md index 1ba1290..b7851ab 100644 --- a/playground/README.md +++ b/playground/README.md @@ -97,8 +97,8 @@ The demo is pre-configured with mock/sandbox endpoints for immediate use. To tes To test the Google Drive integration manually, follow these specialized setup steps: 1. **Service Account**: Create a Service Account in the Google Cloud Console with the **Google Drive API** enabled. Download the JSON key. 2. **Secret Configuration**: - * Place the JSON key file in your project directory (e.g., `D:\connector-platform\service_account.json`). - * Update your `.env` file: `GOOGLE_DRIVE_SA_JSON=D:\connector-platform\service_account.json`. + * Place the JSON key file somewhere safe on your machine (e.g., `/service_account.json`). + * Update your `.env` file: `GOOGLE_DRIVE_SA_JSON=/service_account.json`. * *Note: The platform now supports direct file paths for easier local configuration.* 3. **Permissions**: If using a specific **Vault Folder ID**, ensure that folder is shared with the Service Account's email address (found in the JSON) with "Editor" or "Manager" permissions. 4. **Workflow Verification**: @@ -110,7 +110,7 @@ To test the Google Drive integration manually, follow these specialized setup st To enable the AI Agent chat, you need to configure an LLM provider: 1. **Select Provider**: Set `LLM_PROVIDER` to `groq` (default) or `openai` in your `.env`. 2. **Add API Key**: Provide the corresponding key, e.g., `GROQ_API_KEY=your_key_here`. -3. **SMTP Setup**: (Optional) Add SMTP credentials (`SMTP_HOST`, `SMTP_PORT`, `SMTP_USER`, `SMTP_PASS`) to enable the agent to send emails. +3. **SMTP Setup**: (Optional) Add SMTP credentials (`SMTP_HOST`, `SMTP_PORT`, `SMTP_USERNAME`, `SMTP_PASSWORD`) to enable the agent to send emails. 4. **MCP URL**: (Optional) If running the MCP server in a separate container, set `TOOLHIVE_MCP_URL` to point to the MCP proxy. --- @@ -119,8 +119,14 @@ To enable the AI Agent chat, you need to configure an LLM provider: 1. Navigate to the project root. 2. Start the FastAPI server: - ```bash - set MODE=API&& python -m bindings_entrypoint - ``` + +```bash +# Recommended +uv run node-wire + +# Equivalent (no uv) +MODE=API python -m bindings_entrypoint +``` + 3. Open your browser to `http://localhost:8000/playground/` (or the configured port). 4. Switch between **EHR**, **IT Ops**, **Cerner**, **Google Drive Vault**, and **AI Agent** tabs to explore the different workflows. diff --git a/playground/index.html b/playground/index.html index 46978e8..a18a660 100644 --- a/playground/index.html +++ b/playground/index.html @@ -4,7 +4,7 @@ - Node-wire Playground + node-wire Playground @@ -28,7 +28,7 @@
-

Node-Wire

+

node-wire

Autonomous Connector Orchestration Platform

@@ -93,7 +93,7 @@

Connectors

-

Node-Wire MCP via ToolHive

+

node-wire MCP via ToolHive

MCP Agent — Guardrailed
+
+
+ + Transport: stdio +
+
+
@@ -599,4 +606,4 @@

Technical Audit

- \ No newline at end of file + diff --git a/playground/scenarios.py b/playground/scenarios.py index 5beeca0..5de6f5d 100644 --- a/playground/scenarios.py +++ b/playground/scenarios.py @@ -1,12 +1,14 @@ from __future__ import annotations import base64 +import json import logging import uuid from datetime import datetime, timezone from typing import Any, Dict, List, Optional from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse from pydantic import BaseModel, ValidationError, model_validator from dotenv import load_dotenv import os @@ -1105,6 +1107,15 @@ class AgentChatResponse(BaseModel): success: bool +@router.get("/agent-transport") +async def agent_transport() -> Dict[str, str]: + transport = _current_agent_transport() + return { + "transport": transport, + "label": "Streamable HTTP" if transport == "streamable-http" else "stdio", + } + + AGENT_GUARDRAIL_PROMPT = ( "You are a healthcare data assistant. You have access to tools for fetching " "patient data from Cerner FHIR and Epic FHIR, uploading files to Google Drive, and sending " @@ -1146,6 +1157,27 @@ class AgentChatResponse(BaseModel): ) +def _build_agent_chat_task(payload: AgentChatInput) -> str: + history_text_parts = [] + for msg in payload.history: + role = msg.get("role", "user") + content = msg.get("content", "") + history_text_parts.append(f"{role.upper()}: {content}") + + if history_text_parts: + return ( + "Previous conversation:\n" + + "\n".join(history_text_parts) + + f"\n\nUSER (latest): {payload.message}" + ) + return payload.message + + +def _current_agent_transport() -> str: + transport = os.environ.get("NW_MCP_TRANSPORT", "stdio").strip().lower() or "stdio" + return transport if transport in {"stdio", "streamable-http"} else "stdio" + + @router.post("/agent-chat", response_model=AgentChatResponse) @@ -1185,25 +1217,11 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: logger.info("Agent Chat | creating LLM provider: %s", provider_name) llm_provider = LLMProviderFactory.create_from_env() - # Build the task from the conversation history + current message - # The agent will see the full context - history_text_parts = [] - for msg in payload.history: - role = msg.get("role", "user") - content = msg.get("content", "") - history_text_parts.append(f"{role.upper()}: {content}") - - if history_text_parts: - task = ( - "Previous conversation:\n" - + "\n".join(history_text_parts) - + f"\n\nUSER (latest): {payload.message}" - ) - else: - task = payload.message + task = _build_agent_chat_task(payload) # Determine MCP transport — try proxy first, fallback to local stdio - urls = resolve_mcp_urls() + transport = _current_agent_transport() + urls = resolve_mcp_urls() if transport == "streamable-http" else [] run_result = None if urls: @@ -1278,3 +1296,88 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: trace_id=trace_id, success=False, ) + + +@router.post("/agent-chat-stream") +async def agent_chat_stream(payload: AgentChatInput) -> Any: + """ + Stream agent progress and final answer chunks to the playground UI. + + Tool steps are emitted as each tool finishes. The final assistant answer is + emitted as chunks instead of waiting for the browser to receive one large + buffered JSON payload. + """ + + async def stream_events(): + try: + import sys + + from agents.llm_factory import LLMProviderFactory + from agents.toolhive import ( + MultiMcpClient, + StdioMcpClient, + ToolHiveAgent, + ToolHiveMcpClient, + resolve_mcp_urls, + resolve_max_tool_failures, + ) + + if not payload.message.strip(): + yield json.dumps({ + "type": "final_chunk", + "content": "Please type a message to get started.", + }) + "\n" + yield json.dumps({ + "type": "done", + "trace_id": str(uuid.uuid4()), + "success": False, + }) + "\n" + return + + llm_provider = LLMProviderFactory.create_from_env() + task = _build_agent_chat_task(payload) + transport = _current_agent_transport() + urls = resolve_mcp_urls() if transport == "streamable-http" else [] + + if urls: + if len(urls) == 1: + mcp_client = ToolHiveMcpClient(urls[0]) + else: + mcp_client = MultiMcpClient([ToolHiveMcpClient(u) for u in urls]) + agent = ToolHiveAgent( + mcp_client, + llm_provider, + max_steps=10, + max_tool_failures=resolve_max_tool_failures(None), + ) + agent._system_prompt = AGENT_GUARDRAIL_PROMPT + async for event in agent.run_events(task): + yield json.dumps(event) + "\n" + return + + cmd = [sys.executable, "-m", "agents.mcp_entrypoint"] + async with StdioMcpClient(cmd) as mcp_client: + agent = ToolHiveAgent( + mcp_client, + llm_provider, + max_steps=10, + max_tool_failures=resolve_max_tool_failures(None), + ) + agent._system_prompt = AGENT_GUARDRAIL_PROMPT + async for event in agent.run_events(task): + yield json.dumps(event) + "\n" + + except Exception as exc: + logger.error("Agent Chat stream failed: %s", exc, exc_info=True) + trace_id = str(uuid.uuid4()) + yield json.dumps({ + "type": "final_chunk", + "content": f"Sorry, I encountered an error: {exc}. Please check the server configuration and try again.", + }) + "\n" + yield json.dumps({ + "type": "done", + "trace_id": trace_id, + "success": False, + }) + "\n" + + return StreamingResponse(stream_events(), media_type="application/x-ndjson") diff --git a/playground/style.css b/playground/style.css index 28f669b..3b62fe3 100644 --- a/playground/style.css +++ b/playground/style.css @@ -1600,4 +1600,53 @@ textarea:focus { .chat-reset-btn svg { width: 16px; height: 16px; -} \ No newline at end of file +} + +.transport-status-bar { + display: flex; + justify-content: flex-start; + align-items: center; + gap: 1rem; + margin: -0.75rem 0 1rem; + padding: 0.7rem 0.75rem; + background: #f8fafc; + border: 1px solid #e2e8f0; + border-radius: 0.875rem; +} + +.transport-status-pill { + display: flex; + align-items: center; + gap: 0.55rem; + color: var(--brand-accent); + background: white; + border: 1px solid #e2e8f0; + border-radius: 999px; + padding: 0.5rem 0.85rem; + font-weight: 700; + font-size: 0.78rem; + box-shadow: 0 8px 18px rgba(15, 23, 42, 0.04); +} + +.transport-status-dot { + width: 0.55rem; + height: 0.55rem; + border-radius: 999px; + background: var(--success); + box-shadow: 0 0 0 4px rgba(16, 185, 129, 0.14); +} + +.streaming-bubble p { + white-space: pre-wrap; +} + +@media (max-width: 768px) { + .transport-status-bar { + align-items: stretch; + flex-direction: column; + } + + .transport-status-pill { + justify-content: center; + } +} diff --git a/src/agents/toolhive.py b/src/agents/toolhive.py index 2d0a9ae..a0fdbb3 100644 --- a/src/agents/toolhive.py +++ b/src/agents/toolhive.py @@ -41,7 +41,7 @@ import uuid from contextlib import AsyncExitStack from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Protocol +from typing import Any, AsyncIterator, Dict, List, Optional, Protocol from dotenv import load_dotenv @@ -155,6 +155,25 @@ def _tool_failure_abort_message(tool_name: str, max_failures: int) -> str: ) +def _chunk_agent_text(text: str, chunk_size: int = 180) -> List[str]: + """Split final assistant text into UI-friendly chunks.""" + if not text: + return [""] + + chunks: List[str] = [] + current = "" + for part in text.split(" "): + candidate = f"{current} {part}".strip() + if current and len(candidate) > chunk_size: + chunks.append(current + " ") + current = part + else: + current = candidate + if current: + chunks.append(current) + return chunks + + # --------------------------------------------------------------------------- # Result model # --------------------------------------------------------------------------- @@ -589,6 +608,112 @@ async def run(self, task: str) -> AgentRunResult: return result + async def run_events(self, task: str) -> AsyncIterator[Dict[str, Any]]: + """ + Stream agent progress events as the ReAct loop runs. + + The LLM providers currently return complete assistant messages, so final + answer chunks begin after the final LLM call completes. Tool-step events + are emitted immediately after each MCP tool call completes. + """ + trace_id = str(uuid.uuid4()) + logger.info("Streaming agent run started | trace_id=%s", trace_id) + logger.info("Task: %s", task) + + from agents.llm_factory import LLMMessage + + yield {"type": "meta", "trace_id": trace_id} + + try: + tools = await self._mcp.list_tools() + logger.info("Discovered %d MCP tools", len(tools)) + yield {"type": "status", "message": f"Discovered {len(tools)} MCP tools"} + except Exception as exc: + error = f"Failed to list MCP tools: {exc}" + logger.error(error) + yield {"type": "error", "trace_id": trace_id, "message": error} + yield {"type": "done", "trace_id": trace_id, "success": False} + return + + messages: List[LLMMessage] = [ + LLMMessage(role="system", content=self._system_prompt), + LLMMessage(role="user", content=task), + ] + tool_failures: Dict[str, int] = {} + + for step_num in range(1, self._max_steps + 1): + logger.info("Streaming agent step %d / %d", step_num, self._max_steps) + yield {"type": "status", "message": f"Agent reasoning step {step_num}"} + + try: + llm_resp = self._llm.chat_with_tools(messages, tools) + except Exception as exc: + error = f"LLM error at step {step_num}: {exc}" + logger.error(error) + yield {"type": "error", "trace_id": trace_id, "message": error} + yield {"type": "done", "trace_id": trace_id, "success": False} + return + + messages.append(LLMMessage( + role="assistant", + content=llm_resp.content, + tool_calls=llm_resp.tool_calls, + )) + + if not llm_resp.wants_tool_call: + final_answer = llm_resp.content or "" + for chunk in _chunk_agent_text(final_answer): + yield {"type": "final_chunk", "content": chunk} + yield {"type": "done", "trace_id": trace_id, "success": True} + return + + abort_message: Optional[str] = None + for tc in llm_resp.tool_calls: + scrubbed_args = _redact_tool_args_for_log(tc.name, tc.arguments) + logger.info("Calling tool: %s | args=%s", tc.name, scrubbed_args) + + try: + tool_result_str = await self._mcp.call_tool(tc.name, tc.arguments) + logger.info("Tool %s returned: %.200s", tc.name, tool_result_str) + except Exception as exc: + tool_result_str = f"ERROR: {exc}" + logger.error("Tool %s failed: %s", tc.name, exc) + + yield { + "type": "step", + "step": step_num, + "tool": tc.name, + "args": tc.arguments, + "result": tool_result_str, + } + + llm_tool_content = truncate_tool_result_for_llm(tool_result_str) + messages.append(LLMMessage( + role="tool", + content=llm_tool_content, + tool_call_id=tc.id, + name=tc.name, + )) + + if _is_tool_failure(tool_result_str): + tool_failures[tc.name] = tool_failures.get(tc.name, 0) + 1 + if tool_failures[tc.name] >= self._max_tool_failures: + abort_message = _tool_failure_abort_message(tc.name, self._max_tool_failures) + logger.warning("Stopping streaming agent: %s", abort_message) + break + + if abort_message: + for chunk in _chunk_agent_text(abort_message): + yield {"type": "final_chunk", "content": chunk} + yield {"type": "done", "trace_id": trace_id, "success": False} + return + + error = f"Agent reached max_steps ({self._max_steps}) without completing the task." + logger.warning(error) + for chunk in _chunk_agent_text(error): + yield {"type": "final_chunk", "content": chunk} + yield {"type": "done", "trace_id": trace_id, "success": False} + # --------------------------------------------------------------------------- # CLI entrypoint diff --git a/tests/test_factory_and_rest.py b/tests/test_factory_and_rest.py index bd00912..1a537b7 100644 --- a/tests/test_factory_and_rest.py +++ b/tests/test_factory_and_rest.py @@ -28,6 +28,25 @@ def test_health_endpoint(): assert resp.json() == {"status": "ok"} +def test_agent_transport_defaults_to_stdio(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv("NW_MCP_TRANSPORT", raising=False) + client = TestClient(app) + resp = client.get("/scenarios/agent-transport") + assert resp.status_code == 200 + assert resp.json() == {"transport": "stdio", "label": "stdio"} + + +def test_agent_transport_reports_streamable_http(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("NW_MCP_TRANSPORT", "streamable-http") + client = TestClient(app) + resp = client.get("/scenarios/agent-transport") + assert resp.status_code == 200 + assert resp.json() == { + "transport": "streamable-http", + "label": "Streamable HTTP", + } + + def test_rest_post_without_auth_returns_401_when_key_required(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("NW_REST_AUTH_DISABLED", raising=False) monkeypatch.delenv("NW_REST_JWT_SECRET", raising=False) From 41e9a7d6807eef476c0bee2ab28286a175991549 Mon Sep 17 00:00:00 2001 From: gokulvg-aot <92848578+GokulVGAot@users.noreply.github.com> Date: Mon, 4 May 2026 09:26:41 +0530 Subject: [PATCH 21/60] mcp-auth-scenario-tested (#27) * mcp-auth-scenario-tested * auth verification tested successfully and verified * identity propagated to runtime and MCP scope policy is global issue resolved --------- Co-authored-by: My Name --- playground/scenarios.py | 38 +- sample.env | 33 +- src/agents/mcp_entrypoint.py | 6 +- src/agents/toolhive.py | 67 +++- src/bindings/factory.py | 5 + src/bindings/mcp_server/auth.py | 378 ++++++++++-------- src/bindings/mcp_server/server.py | 15 +- src/bindings/rest_api/app.py | 16 +- src/bindings/rest_api/auth.py | 45 ++- src/bindings_entrypoint.py | 4 + src/node_wire_runtime/__init__.py | 3 + src/node_wire_runtime/caller_identity.py | 40 ++ .../policies/mcp_scope_policy.py | 118 ++++-- tests/conftest.py | 4 +- tests/test_factory_and_rest.py | 64 +++ tests/test_mcp_auth.py | 268 ++++++------- tests/test_scope_policy_transport.py | 60 +++ 17 files changed, 758 insertions(+), 406 deletions(-) create mode 100644 src/node_wire_runtime/caller_identity.py create mode 100644 tests/test_scope_policy_transport.py diff --git a/playground/scenarios.py b/playground/scenarios.py index 5de6f5d..f32c5be 100644 --- a/playground/scenarios.py +++ b/playground/scenarios.py @@ -1219,9 +1219,14 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: task = _build_agent_chat_task(payload) - # Determine MCP transport — try proxy first, fallback to local stdio - transport = _current_agent_transport() - urls = resolve_mcp_urls() if transport == "streamable-http" else [] + # Determine MCP transport — try proxy first, optionally fallback to local stdio. + # Default behavior surfaces proxy/auth errors directly in the UI so demos can + # show MCP failures (instead of silently falling back to stdio). + fallback_to_stdio = ( + (os.environ.get("PLAYGROUND_AGENT_PROXY_FALLBACK_TO_STDIO", "false").strip().lower()) + in {"1", "true", "yes", "on"} + ) + urls = resolve_mcp_urls() run_result = None if urls: @@ -1250,11 +1255,30 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: ) ) if proxy_incomplete: - logger.warning("Agent Chat | proxy incomplete, falling back to local stdio") - run_result = None + if fallback_to_stdio: + logger.warning("Agent Chat | proxy incomplete, falling back to local stdio") + run_result = None + else: + logger.warning( + "Agent Chat | proxy incomplete, returning proxy error to UI " + "(set PLAYGROUND_AGENT_PROXY_FALLBACK_TO_STDIO=true to fallback)" + ) except Exception as proxy_err: - logger.warning("Agent Chat | proxy error: %s — falling back to local stdio", proxy_err) - run_result = None + if fallback_to_stdio: + logger.warning("Agent Chat | proxy error: %s — falling back to local stdio", proxy_err) + run_result = None + else: + logger.warning( + "Agent Chat | proxy error: %s — returning error to UI " + "(set PLAYGROUND_AGENT_PROXY_FALLBACK_TO_STDIO=true to fallback)", + proxy_err, + ) + return AgentChatResponse( + reply=f"MCP proxy error: {proxy_err}", + steps=[], + trace_id=trace_id, + success=False, + ) if run_result is None: # Use local stdio transport diff --git a/sample.env b/sample.env index 722864f..beeb748 100644 --- a/sample.env +++ b/sample.env @@ -27,19 +27,30 @@ SMTP_PASSWORD=your-gmail-app-password STRIPE_API_KEY=sk_test_your_key_here # ToolHive / Agent Configuration +# Single MCP proxy URL (backward compatible) +TOOLHIVE_MCP_URL=http://localhost:8081/mcp +# Multi-server MCP URLs (comma-separated; preferred for per-connector servers) TOOLHIVE_MCP_URLS= +# Optional MCP auth credentials sent by the ToolHive client to MCP server +# TOOLHIVE_MCP_API_KEY=replace-with-your-mcp-api-key +# TOOLHIVE_MCP_BEARER_TOKEN=replace-with-jwt-or-api-key +# When false (recommended for demos), proxy errors are returned to UI directly. +# Set true to allow proxy failure fallback to local stdio MCP. +PLAYGROUND_AGENT_PROXY_FALLBACK_TO_STDIO=false # Cap MCP tool JSON size sent back to the LLM (Groq on-demand TPM); default 12000 # TOOLHIVE_MAX_TOOL_RESULT_CHARS=12000 # Native MCP Transport (for agents.mcp_entrypoint and per-connector MCP servers) # ----------------------------------------------------------------------------- -# NW_MCP_TRANSPORT: Selects the communication layer. +# NW_MCP_TRANSPORT: Selects the communication layer. # - stdio: (Default) Required for ToolHive proxying and Claude Desktop. # - streamable-http: Native HTTP/SSE transport for direct web integration. -NW_MCP_TRANSPORT=stdio +NW_MCP_TRANSPORT=streamable-http +NW_MCP_HOST=127.0.0.1 +NW_MCP_PATH=/mcp # NW_MCP_PORT: The port used only when NW_MCP_TRANSPORT=streamable-http. -# - Default: 8080. Ensure it does not conflict with REST API (port 8000). -NW_MCP_PORT=8080 +# - Default: 8081 in local demos. Ensure it does not conflict with REST API (port 8000). +NW_MCP_PORT=8081 # LLM Provider LLM_PROVIDER=groq @@ -61,11 +72,23 @@ ANTHROPIC_API_KEY=your-anthropic-api-key ANTHROPIC_MODEL=claude-3-5-haiku-20241022 # MCP auth (set AUTH_DISABLED=false outside local dev) -NW_MCP_AUTH_DISABLED=true +NW_MCP_AUTH_ENABLED=true NW_MCP_API_KEY=replace-with-strong-random-value NW_MCP_JWT_SECRET=replace-with-hs256-secret # Optional per-tool scope map JSON # NW_MCP_ACTION_SCOPE_MAP_JSON={"smtp.send_email":"mcp:smtp.send_email"} +# Example for FHIR + Google Drive policy gating: +# NW_MCP_ACTION_SCOPE_MAP_JSON={"fhir_epic.read_patient":"mcp:fhir.read_patient","fhir_cerner.read_patient":"mcp:fhir.read_patient","google_drive.files.upload":"mcp:gdrive.files.upload"} +# Scope hook is runtime-level. With current bindings, strict scope enforcement applies +# to identity-aware MCP/REST calls. gRPC enforcement is deferred until gRPC identity +# propagation is implemented. +# ToolHive bearer token is sent to MCP as Authorization + X-API-Key + _meta aliases. +# TOOLHIVE_MCP_BEARER_TOKEN= + +# REST auth for Playground demo (disable for local UI testing) +NW_REST_AUTH_DISABLED=true +NW_REST_LOAD_DOTENV=true +# REST JWTs (NW_REST_JWT_SECRET): claims sub, tenant_id, scopes propagate to connector.run(..., principal, tenant_id, scopes) for ScopePolicyHook # MCP contract (optional; Google Drive legacy payload `action: "upload"`) # NODE_WIRE_LEGACY_GDRIVE_ACTION_UPLOAD=warn diff --git a/src/agents/mcp_entrypoint.py b/src/agents/mcp_entrypoint.py index 77b1887..3e7e65f 100644 --- a/src/agents/mcp_entrypoint.py +++ b/src/agents/mcp_entrypoint.py @@ -6,8 +6,10 @@ from dotenv import load_dotenv -load_dotenv() -load_dotenv(os.path.join(os.path.dirname(__file__), ".env")) +# Override inherited shell env so MCP auth/policy settings in project .env +# are applied predictably across local runs. +load_dotenv(override=True) +load_dotenv(os.path.join(os.path.dirname(__file__), ".env"), override=True) logging.basicConfig(level=logging.INFO) logger = logging.getLogger("agents.mcp_entrypoint") diff --git a/src/agents/toolhive.py b/src/agents/toolhive.py index a0fdbb3..5ce86ad 100644 --- a/src/agents/toolhive.py +++ b/src/agents/toolhive.py @@ -23,6 +23,8 @@ Environment variables: TOOLHIVE_MCP_URL : MCP proxy URL from ToolHive UI (e.g. http://localhost:PORT/mcp) TOOLHIVE_MCP_URLS: Comma-separated MCP proxy URLs (multi-server) + TOOLHIVE_MCP_API_KEY: Optional inbound MCP auth key (sent as Bearer + X-API-Key) + TOOLHIVE_MCP_BEARER_TOKEN: Optional inbound MCP bearer token (JWT/API key) TOOLHIVE_MAX_TOOL_FAILURES: Stop after this many failed invocations per tool name (default: 2) LLM_PROVIDER : groq | openai | gemini | anthropic (default: groq) GROQ_API_KEY : (when using groq) @@ -225,6 +227,43 @@ def __init__(self, base_url: str) -> None: self._base_url = base_url.rstrip("/") self._session_id: Optional[str] = None self._initialized: bool = False + self._auth_token: Optional[str] = ( + os.environ.get("TOOLHIVE_MCP_BEARER_TOKEN") + or os.environ.get("TOOLHIVE_MCP_API_KEY") + ) + + def _build_request_headers(self) -> Dict[str, str]: + headers: Dict[str, str] = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + if self._session_id: + headers["Mcp-Session-Id"] = self._session_id + # For MCP auth-gated servers, send both forms for compatibility. + if self._auth_token: + headers["Authorization"] = f"Bearer {self._auth_token}" + headers["X-API-Key"] = self._auth_token + return headers + + def _inject_auth_meta(self, params: Dict[str, Any]) -> Dict[str, Any]: + if not self._auth_token: + return dict(params) + out = dict(params) + meta = out.get("_meta") + if isinstance(meta, dict): + merged_meta = dict(meta) + else: + merged_meta = {} + # Include common aliases to maximize compatibility with MCP servers. + merged_meta.setdefault("authorization", f"Bearer {self._auth_token}") + merged_meta.setdefault("Authorization", f"Bearer {self._auth_token}") + merged_meta.setdefault("x-api-key", self._auth_token) + merged_meta.setdefault("X-API-Key", self._auth_token) + merged_meta.setdefault("token", self._auth_token) + merged_meta.setdefault("api_key", self._auth_token) + merged_meta.setdefault("apiKey", self._auth_token) + out["_meta"] = merged_meta + return out async def _initialize(self) -> None: """Send MCP initialize + initialized handshake; store session ID.""" @@ -244,10 +283,7 @@ async def _initialize(self) -> None: resp = await client.post( self._base_url, json=init_payload, - headers={ - "Content-Type": "application/json", - "Accept": "application/json, text/event-stream", - }, + headers=self._build_request_headers(), ) resp.raise_for_status() session_id = resp.headers.get("Mcp-Session-Id") @@ -259,14 +295,8 @@ async def _initialize(self) -> None: # Send the initialized notification (fire-and-forget; no id = notification) notif = {"jsonrpc": "2.0", "method": "notifications/initialized"} - headers: Dict[str, str] = { - "Content-Type": "application/json", - "Accept": "application/json, text/event-stream", - } - if self._session_id: - headers["Mcp-Session-Id"] = self._session_id try: - await client.post(self._base_url, json=notif, headers=headers) + await client.post(self._base_url, json=notif, headers=self._build_request_headers()) except Exception: pass # Notifications have no response; ignore transport errors @@ -283,19 +313,14 @@ async def _rpc(self, method: str, params: Dict[str, Any]) -> Any: "id": str(uuid.uuid4()), "method": method, } - if params: - payload["params"] = params - - headers: Dict[str, str] = { - "Content-Type": "application/json", - "Accept": "application/json, text/event-stream", - } - if self._session_id: - headers["Mcp-Session-Id"] = self._session_id + # Always include params when auth token is present so _meta is sent even + # for methods like tools/list that otherwise pass {}. + if params or self._auth_token: + payload["params"] = self._inject_auth_meta(params) url = self._base_url async with httpx.AsyncClient(timeout=60.0) as client: - resp = await client.post(url, json=payload, headers=headers) + resp = await client.post(url, json=payload, headers=self._build_request_headers()) resp.raise_for_status() data = resp.json() if "error" in data: diff --git a/src/bindings/factory.py b/src/bindings/factory.py index 217edad..63267bd 100644 --- a/src/bindings/factory.py +++ b/src/bindings/factory.py @@ -78,7 +78,12 @@ def _build_secret_provider() -> SecretProvider: def _build_policy_hook() -> PolicyHook | None: action_scope_map = load_scope_map_from_env() if not action_scope_map: + logger.info("Policy hook disabled (no action scope map)") return None + logger.info( + "Policy hook enabled", + extra={"scope_map_entries": len(action_scope_map)}, + ) return ScopePolicyHook(action_scope_map) diff --git a/src/bindings/mcp_server/auth.py b/src/bindings/mcp_server/auth.py index 67d49f0..b368589 100644 --- a/src/bindings/mcp_server/auth.py +++ b/src/bindings/mcp_server/auth.py @@ -1,177 +1,201 @@ -from __future__ import annotations - -import os -from dataclasses import dataclass -from typing import Any, Mapping - -import jwt - - -def _truthy(val: str | None) -> bool: - if val is None: - return False - return val.strip().lower() in ("1", "true", "yes", "on") - - -@dataclass(frozen=True) -class McpIdentity: - principal: str - tenant_id: str | None - scopes: tuple[str, ...] - claims: Mapping[str, Any] - auth_type: str - - -class McpAuthError(PermissionError): - def __init__( - self, - detail: str, - *, - status_code: int, - error_code: str, - www_authenticate: str | None = None, - ) -> None: - super().__init__(detail) - self.detail = detail - self.status_code = status_code - self.error_code = error_code - self.www_authenticate = www_authenticate - - def to_payload(self) -> dict[str, Any]: - payload: dict[str, Any] = { - "detail": self.detail, - "error_code": self.error_code, - "status_code": self.status_code, - } - if self.www_authenticate: - payload["www_authenticate"] = self.www_authenticate - return payload - - -class McpAuthRequiredError(McpAuthError): - def __init__(self) -> None: - super().__init__( - "Authentication required", - status_code=401, - error_code="MCP_AUTH_REQUIRED", - www_authenticate='Bearer realm="node-wire"', - ) - - -class McpAuthInvalidError(McpAuthError): - def __init__(self) -> None: - super().__init__( - "Invalid API key or token", - status_code=403, - error_code="MCP_AUTH_INVALID", - www_authenticate='Bearer realm="node-wire"', - ) - - -class McpAuthNotConfiguredError(McpAuthError): - def __init__(self) -> None: - super().__init__( - ( - "MCP authentication is not configured. Set NW_MCP_API_KEY " - "(and optionally NW_MCP_JWT_SECRET), or set NW_MCP_AUTH_DISABLED=true " - "for local development only." - ), - status_code=503, - error_code="MCP_AUTH_NOT_CONFIGURED", - ) - - -def mcp_auth_disabled() -> bool: - return _truthy(os.environ.get("NW_MCP_AUTH_DISABLED")) - - -def mcp_auth_configured() -> bool: - return bool(os.environ.get("NW_MCP_API_KEY") or os.environ.get("NW_MCP_JWT_SECRET")) - - -def _get_meta_value(meta: Mapping[str, Any] | None, keys: tuple[str, ...]) -> str | None: - if not meta: - return None - for key in keys: - val = meta.get(key) - if isinstance(val, str) and val.strip(): - return val.strip() - return None - - -def extract_token( - *, - headers: Mapping[str, Any] | None = None, - meta: Mapping[str, Any] | None = None, -) -> str | None: - if headers: - auth = headers.get("authorization") or headers.get("Authorization") - if isinstance(auth, str) and auth.lower().startswith("bearer "): - return auth[7:].strip() - x_api_key = headers.get("x-api-key") or headers.get("X-API-Key") - if isinstance(x_api_key, str) and x_api_key.strip(): - return x_api_key.strip() - - auth_meta = _get_meta_value(meta, ("authorization", "Authorization")) - if auth_meta and auth_meta.lower().startswith("bearer "): - return auth_meta[7:].strip() - - return _get_meta_value(meta, ("x-api-key", "X-API-Key", "api_key", "apiKey", "token")) - - -def verify_mcp_token(token: str) -> tuple[dict[str, Any], str]: - api_key = os.getenv("NW_MCP_API_KEY") - jwt_secret = os.getenv("NW_MCP_JWT_SECRET") - - if api_key and token == api_key: - return ({"sub": "api-key-user", "tenant_id": None, "scopes": ["*"]}, "api_key") - - if jwt_secret and token.count(".") == 2: - try: - claims = jwt.decode(token, jwt_secret, algorithms=["HS256"]) - return (claims, "jwt") - except jwt.PyJWTError as exc: - raise McpAuthInvalidError() from exc - - raise McpAuthInvalidError() - - -def build_identity(claims: Mapping[str, Any], auth_type: str) -> McpIdentity: - principal = str(claims.get("sub") or claims.get("client_id") or "unknown") - tenant_val = claims.get("tenant_id") - tenant_id = str(tenant_val) if tenant_val is not None else None - raw_scopes = claims.get("scopes") - if raw_scopes is None: - raw_scopes = claims.get("scope") - if isinstance(raw_scopes, str): - scopes = tuple(s for s in raw_scopes.split(" ") if s) - elif isinstance(raw_scopes, (list, tuple, set)): - scopes = tuple(str(s) for s in raw_scopes if str(s).strip()) - else: - scopes = tuple() - return McpIdentity( - principal=principal, - tenant_id=tenant_id, - scopes=scopes, - claims=dict(claims), - auth_type=auth_type, - ) - - -def authenticate_mcp_request( - *, - headers: Mapping[str, Any] | None = None, - meta: Mapping[str, Any] | None = None, -) -> McpIdentity | None: - if mcp_auth_disabled(): - return None - - if not mcp_auth_configured(): - raise McpAuthNotConfiguredError() - - token = extract_token(headers=headers, meta=meta) - if not token: - raise McpAuthRequiredError() - - claims, auth_type = verify_mcp_token(token) - return build_identity(claims, auth_type) +from __future__ import annotations + +import os +import logging +from pathlib import Path +from typing import Any, Mapping + +import jwt +from dotenv import load_dotenv + +from node_wire_runtime.caller_identity import CallerIdentity, build_caller_identity + +logger = logging.getLogger("bindings.mcp_server.auth") + +# Back-compat: callers may still import ``McpIdentity`` / ``build_identity`` from MCP auth. +McpIdentity = CallerIdentity + + +def _truthy(val: str | None) -> bool: + if val is None: + return False + return val.strip().lower() in ("1", "true", "yes", "on") + + +class McpAuthError(PermissionError): + def __init__( + self, + detail: str, + *, + status_code: int, + error_code: str, + www_authenticate: str | None = None, + ) -> None: + super().__init__(detail) + self.detail = detail + self.status_code = status_code + self.error_code = error_code + self.www_authenticate = www_authenticate + + def to_payload(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "detail": self.detail, + "error_code": self.error_code, + "status_code": self.status_code, + } + if self.www_authenticate: + payload["www_authenticate"] = self.www_authenticate + return payload + + +class McpAuthRequiredError(McpAuthError): + def __init__(self) -> None: + super().__init__( + "Authentication required", + status_code=401, + error_code="MCP_AUTH_REQUIRED", + www_authenticate='Bearer realm="node-wire"', + ) + + +class McpAuthInvalidError(McpAuthError): + def __init__(self) -> None: + super().__init__( + "Invalid API key or token", + status_code=403, + error_code="MCP_AUTH_INVALID", + www_authenticate='Bearer realm="node-wire"', + ) + + +class McpAuthNotConfiguredError(McpAuthError): + def __init__(self) -> None: + super().__init__( + ( + "MCP authentication is not configured. Set NW_MCP_API_KEY " + "(and optionally NW_MCP_JWT_SECRET), or set NW_MCP_AUTH_ENABLED=true " + "for local development only." + ), + status_code=503, + error_code="MCP_AUTH_NOT_CONFIGURED", + ) + + +_mcp_auth_env_bootstrapped = False + + +def _bootstrap_mcp_auth_env() -> None: + global _mcp_auth_env_bootstrapped + if _mcp_auth_env_bootstrapped: + return + + # Some launch paths on Windows can miss .env loading for the MCP worker. + # If MCP auth vars are missing/empty, try loading project .env once. + if os.environ.get("NW_MCP_API_KEY") or os.environ.get("NW_MCP_JWT_SECRET"): + _mcp_auth_env_bootstrapped = True + return + + repo_root_env = Path(__file__).resolve().parents[3] / ".env" + load_dotenv(override=True) + load_dotenv(repo_root_env, override=True) + _mcp_auth_env_bootstrapped = True + + +def mcp_auth_disabled() -> bool: + return _truthy(os.environ.get("NW_MCP_AUTH_ENABLED")) + + +def mcp_auth_configured() -> bool: + _bootstrap_mcp_auth_env() + return bool(os.environ.get("NW_MCP_API_KEY") or os.environ.get("NW_MCP_JWT_SECRET")) + + +def _get_meta_value(meta: Mapping[str, Any] | None, keys: tuple[str, ...]) -> str | None: + if not meta: + return None + for key in keys: + val = meta.get(key) + if isinstance(val, str) and val.strip(): + return val.strip() + return None + + +def extract_token( + *, + headers: Mapping[str, Any] | None = None, + meta: Mapping[str, Any] | None = None, +) -> str | None: + if headers: + auth = headers.get("authorization") or headers.get("Authorization") + if isinstance(auth, str) and auth.lower().startswith("bearer "): + return auth[7:].strip() + x_api_key = headers.get("x-api-key") or headers.get("X-API-Key") + if isinstance(x_api_key, str) and x_api_key.strip(): + return x_api_key.strip() + + auth_meta = _get_meta_value(meta, ("authorization", "Authorization")) + if auth_meta and auth_meta.lower().startswith("bearer "): + return auth_meta[7:].strip() + + return _get_meta_value(meta, ("x-api-key", "X-API-Key", "api_key", "apiKey", "token")) + + +def verify_mcp_token(token: str) -> tuple[dict[str, Any], str]: + api_key = os.getenv("NW_MCP_API_KEY") + jwt_secret = os.getenv("NW_MCP_JWT_SECRET") + + if api_key and token == api_key: + return ({"sub": "api-key-user", "tenant_id": None, "scopes": ["*"]}, "api_key") + + if jwt_secret and token.count(".") == 2: + try: + claims = jwt.decode(token, jwt_secret, algorithms=["HS256"]) + logger.info("MCP token verified as JWT") + return (claims, "jwt") + except jwt.PyJWTError as exc: + raise McpAuthInvalidError() from exc + + raise McpAuthInvalidError() + + +def build_identity(claims: Mapping[str, Any], auth_type: str) -> CallerIdentity: + """Deprecated alias for :func:`build_caller_identity`; prefer that name in new code.""" + return build_caller_identity(claims, auth_type) + + +def authenticate_mcp_request( + *, + headers: Mapping[str, Any] | None = None, + meta: Mapping[str, Any] | None = None, +) -> CallerIdentity | None: + logger.info( + "MCP auth gate status", + extra={ + "auth_disabled": mcp_auth_disabled(), + "auth_configured": mcp_auth_configured(), + "has_api_key": bool(os.environ.get("NW_MCP_API_KEY")), + "has_jwt_secret": bool(os.environ.get("NW_MCP_JWT_SECRET")), + }, + ) + if mcp_auth_disabled(): + return None + + if not mcp_auth_configured(): + raise McpAuthNotConfiguredError() + + token = extract_token(headers=headers, meta=meta) + if not token: + raise McpAuthRequiredError() + + claims, auth_type = verify_mcp_token(token) + identity = build_caller_identity(claims, auth_type) + logger.info( + "MCP auth accepted", + extra={ + "auth_type": identity.auth_type, + "principal": identity.principal, + "tenant_id": identity.tenant_id or "", + "scopes": list(identity.scopes), + }, + ) + return identity diff --git a/src/bindings/mcp_server/server.py b/src/bindings/mcp_server/server.py index 6478537..7dd165c 100644 --- a/src/bindings/mcp_server/server.py +++ b/src/bindings/mcp_server/server.py @@ -6,11 +6,8 @@ from typing import Any, Dict, List, Mapping, Optional from bindings.factory import ConnectorFactory -from bindings.mcp_server.auth import ( - McpAuthError, - McpIdentity, - authenticate_mcp_request, -) +from bindings.mcp_server.auth import McpAuthError, authenticate_mcp_request +from node_wire_runtime.caller_identity import CallerIdentity from node_wire_runtime.connector_registry import auto_register from node_wire_runtime.manifest import MCP_MANIFEST_CONTRACT_VERSION, build_manifest from node_wire_runtime import BaseConnector, ConnectorResponse, ErrorCategory @@ -54,7 +51,7 @@ def __init__( _pkg_ver, ) - def list_tools(self, *, identity: McpIdentity | None = None) -> List[Dict[str, Any]]: + def list_tools(self, *, identity: CallerIdentity | None = None) -> List[Dict[str, Any]]: self._ensure_identity(identity=identity) return self._list_tools_impl() @@ -87,9 +84,9 @@ def _list_tools_impl(self) -> List[Dict[str, Any]]: def _ensure_identity( self, *, - identity: McpIdentity | None, + identity: CallerIdentity | None, meta: Mapping[str, Any] | None = None, - ) -> McpIdentity | None: + ) -> CallerIdentity | None: if identity is not None: return identity return authenticate_mcp_request(meta=meta) @@ -117,7 +114,7 @@ async def invoke_tool( name: str, arguments: Dict[str, Any], *, - identity: McpIdentity | None = None, + identity: CallerIdentity | None = None, ) -> Dict[str, Any]: identity = self._ensure_identity(identity=identity) try: diff --git a/src/bindings/rest_api/app.py b/src/bindings/rest_api/app.py index 0b27064..0c65dff 100644 --- a/src/bindings/rest_api/app.py +++ b/src/bindings/rest_api/app.py @@ -7,14 +7,15 @@ import sys from pathlib import Path -from fastapi import Depends, FastAPI, HTTPException +from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from dotenv import load_dotenv # Production: set NW_REST_LOAD_DOTENV=false to rely on injected env only (no .env file). if os.environ.get("NW_REST_LOAD_DOTENV", "true").lower() not in ("0", "false", "no"): - load_dotenv() + # Override inherited shell env so local .env edits are honored consistently. + load_dotenv(override=True) from bindings.factory import ConnectorFactory from node_wire_runtime.connector_registry import auto_register @@ -25,7 +26,7 @@ from opentelemetry.trace import Status, StatusCode from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor -from bindings.rest_api.auth import RestAuthMiddleware +from bindings.rest_api.auth import RestAuthMiddleware, get_rest_caller_identity # Add project root to sys.path to allow importing from 'playground' package PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent @@ -78,6 +79,7 @@ def _http_status_for_category(category: ErrorCategory | None) -> int: def _make_endpoint(cid: str, act: str) -> Any: async def endpoint( + request: Request, payload: Dict[str, Any], factory_dep: ConnectorFactory = Depends(get_factory), ) -> JSONResponse: @@ -101,7 +103,13 @@ async def endpoint( run_payload["action"] = act # Let the runtime (Layer A) perform full schema validation. # Any validation errors will be mapped into ConnectorResponse. - response: ConnectorResponse = await connector.run(run_payload) + rest_id = get_rest_caller_identity(request) + response: ConnectorResponse = await connector.run( + run_payload, + principal=rest_id.principal if rest_id else None, + tenant_id=rest_id.tenant_id if rest_id else None, + scopes=rest_id.scopes if rest_id else None, + ) status = _http_status_for_category(response.error_category) if not response.success: diff --git a/src/bindings/rest_api/auth.py b/src/bindings/rest_api/auth.py index a8962b9..539eefe 100644 --- a/src/bindings/rest_api/auth.py +++ b/src/bindings/rest_api/auth.py @@ -7,6 +7,9 @@ NW_REST_AUTH_DISABLED — if ``true``/``1``/``yes``, skip auth (local dev only; do not use in production). Public (unauthenticated): ``GET /health`` only. OpenAPI UI requires auth. + +After successful auth, normalized caller identity (principal / tenant_id / scopes) is stored on +``request.state.nw_rest_caller_identity`` and forwarded to ``connector.run`` for policy hooks. """ from __future__ import annotations @@ -18,6 +21,16 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response +from node_wire_runtime.caller_identity import CallerIdentity, build_caller_identity + + +REST_CALLER_STATE_KEY = "nw_rest_caller_identity" + + +def get_rest_caller_identity(request: Request) -> CallerIdentity | None: + """Return JWT/API-key caller identity attached by middleware, if any.""" + return getattr(request.state, REST_CALLER_STATE_KEY, None) + def _truthy(val: str | None) -> bool: if val is None: @@ -42,16 +55,32 @@ def _extract_bearer_or_api_key(request: Request) -> str | None: return None -def _verify_token(token: str, *, api_key: str | None, jwt_secret: str | None) -> bool: +def verify_rest_token_and_identity( + token: str, + *, + api_key: str | None, + jwt_secret: str | None, +) -> tuple[bool, CallerIdentity | None]: + """ + Validate REST bearer/API-key token and build caller identity (same shape as MCP). + + Shared API key behaves like MCP: wildcard scopes for ScopePolicyHook compatibility. + """ if api_key and token == api_key: - return True + ident = build_caller_identity( + {"sub": "api-key-user", "tenant_id": None, "scopes": ["*"]}, + auth_type="rest_api_key", + ) + return True, ident + if jwt_secret and token.count(".") == 2: try: - jwt.decode(token, jwt_secret, algorithms=["HS256"]) - return True + claims = jwt.decode(token, jwt_secret, algorithms=["HS256"]) except jwt.PyJWTError: - return False - return False + return False, None + return True, build_caller_identity(claims, auth_type="jwt") + + return False, None class RestAuthMiddleware(BaseHTTPMiddleware): @@ -91,11 +120,13 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: headers={"WWW-Authenticate": 'Bearer realm="node-wire"'}, ) - if not _verify_token(token, api_key=api_key, jwt_secret=jwt_secret): + ok, identity = verify_rest_token_and_identity(token, api_key=api_key, jwt_secret=jwt_secret) + if not ok or identity is None: return JSONResponse( status_code=403, content={"detail": "Invalid API key or token"}, headers={"WWW-Authenticate": 'Bearer realm="node-wire"'}, ) + setattr(request.state, REST_CALLER_STATE_KEY, identity) return await call_next(request) diff --git a/src/bindings_entrypoint.py b/src/bindings_entrypoint.py index 7aa5ebb..db1bdd5 100644 --- a/src/bindings_entrypoint.py +++ b/src/bindings_entrypoint.py @@ -4,11 +4,15 @@ import os import uvicorn +from dotenv import load_dotenv from bindings.rest_api.app import app as rest_app from bindings.mcp_server.server import McpServer from node_wire_runtime.observability import init_observability +# Load project .env early so all modes (API/GRPC/MCP) see consistent config. +load_dotenv(override=True) + logging.basicConfig(level=logging.INFO) logger = logging.getLogger("bindings.entrypoint") logging.getLogger("opentelemetry.exporter.otlp.proto.http").setLevel(logging.DEBUG) diff --git a/src/node_wire_runtime/__init__.py b/src/node_wire_runtime/__init__.py index 401383d..99af6ba 100644 --- a/src/node_wire_runtime/__init__.py +++ b/src/node_wire_runtime/__init__.py @@ -2,6 +2,7 @@ from .errors import ErrorMapper from .secrets import SecretProvider, EnvSecretProvider, SecretNotFoundError, SecretProviderError from .policy import PolicyHook, PolicyDenied +from .caller_identity import CallerIdentity, build_caller_identity from .auth import AuthProvider, NoAuthProvider, StaticTokenAuthProvider, OAuth2AuthProvider, ServiceAccountAuthProvider from .base_connector import BaseConnector, nw_action, sdk_action, _CONNECTOR_REGISTRY from .sdk_action_spec import ( @@ -21,6 +22,8 @@ "SecretProviderError", "PolicyHook", "PolicyDenied", + "CallerIdentity", + "build_caller_identity", "AuthProvider", "NoAuthProvider", "StaticTokenAuthProvider", diff --git a/src/node_wire_runtime/caller_identity.py b/src/node_wire_runtime/caller_identity.py new file mode 100644 index 0000000..d1c7150 --- /dev/null +++ b/src/node_wire_runtime/caller_identity.py @@ -0,0 +1,40 @@ +"""Transport-neutral caller identity for connector execution and policy hooks.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Mapping + + +@dataclass(frozen=True) +class CallerIdentity: + """Who is calling ``connector.run`` (REST, MCP, or other bindings).""" + + principal: str + tenant_id: str | None + scopes: tuple[str, ...] + claims: Mapping[str, Any] + auth_type: str + + +def build_caller_identity(claims: Mapping[str, Any], auth_type: str) -> CallerIdentity: + """Build identity from JWT-style claims (``sub``, ``tenant_id``, ``scopes`` / ``scope``).""" + principal = str(claims.get("sub") or claims.get("client_id") or "unknown") + tenant_val = claims.get("tenant_id") + tenant_id = str(tenant_val) if tenant_val is not None else None + raw_scopes = claims.get("scopes") + if raw_scopes is None: + raw_scopes = claims.get("scope") + if isinstance(raw_scopes, str): + scopes = tuple(s for s in raw_scopes.split(" ") if s) + elif isinstance(raw_scopes, (list, tuple, set)): + scopes = tuple(str(s) for s in raw_scopes if str(s).strip()) + else: + scopes = tuple() + return CallerIdentity( + principal=principal, + tenant_id=tenant_id, + scopes=scopes, + claims=dict(claims), + auth_type=auth_type, + ) diff --git a/src/node_wire_runtime/policies/mcp_scope_policy.py b/src/node_wire_runtime/policies/mcp_scope_policy.py index a5694c5..a31d14f 100644 --- a/src/node_wire_runtime/policies/mcp_scope_policy.py +++ b/src/node_wire_runtime/policies/mcp_scope_policy.py @@ -1,38 +1,80 @@ -from __future__ import annotations - -import json -import os -from typing import Mapping - -from node_wire_runtime.policy import PolicyContext, PolicyDenied, PolicyHook - - -class ScopePolicyHook(PolicyHook): - def __init__(self, action_scope_map: Mapping[str, str]) -> None: - self._map = dict(action_scope_map) - - def check(self, context: PolicyContext) -> None: - required = self._map.get(f"{context.connector_id}.{context.action}") - if not required: - return - scopes = set(context.scopes or ()) - if required in scopes or "*" in scopes: - return - raise PolicyDenied(f"Missing required scope: {required}") - - -def load_scope_map_from_env() -> dict[str, str]: - raw = os.environ.get("NW_MCP_ACTION_SCOPE_MAP_JSON") - if not raw: - return {} - parsed = json.loads(raw) - if not isinstance(parsed, dict): - raise ValueError("NW_MCP_ACTION_SCOPE_MAP_JSON must be a JSON object.") - out: dict[str, str] = {} - for key, value in parsed.items(): - if not isinstance(key, str) or not isinstance(value, str): - raise ValueError( - "NW_MCP_ACTION_SCOPE_MAP_JSON must map string action keys to string scopes." - ) - out[key] = value - return out +from __future__ import annotations + +import json +import logging +import os +from pathlib import Path +from typing import Mapping + +from dotenv import load_dotenv + +from node_wire_runtime.policy import PolicyContext, PolicyDenied, PolicyHook + +logger = logging.getLogger("runtime.policy.scope") + + +class ScopePolicyHook(PolicyHook): + def __init__(self, action_scope_map: Mapping[str, str]) -> None: + self._map = dict(action_scope_map) + + def check(self, context: PolicyContext) -> None: + action_key = f"{context.connector_id}.{context.action}" + required = self._map.get(action_key) + scopes = tuple(context.scopes or ()) + # Defer transport-specific authz until caller identity is propagated. + # This prevents non-identity paths (e.g. current gRPC) from being + # denied solely because MCP scope map is configured. + if required and not context.principal and not scopes: + logger.info( + "Scope policy bypassed due to missing caller identity", + extra={ + "action_key": action_key, + "required_scope": required, + }, + ) + return + logger.info( + "Scope policy evaluating action", + extra={ + "action_key": action_key, + "required_scope": required or "", + "principal": context.principal or "", + "tenant_id": context.tenant_id or "", + "scopes": list(scopes), + }, + ) + if not required: + return + scope_set = set(scopes) + if required in scope_set or "*" in scope_set: + return + raise PolicyDenied(f"Missing required scope: {required}") + + +def load_scope_map_from_env() -> dict[str, str]: + raw = os.environ.get("NW_MCP_ACTION_SCOPE_MAP_JSON") + if not raw: + # Mirror MCP auth bootstrap behavior: recover config from project .env + # when launch paths inherit incomplete shell env. + repo_root_env = Path(__file__).resolve().parents[3] / ".env" + load_dotenv(override=True) + load_dotenv(repo_root_env, override=True) + raw = os.environ.get("NW_MCP_ACTION_SCOPE_MAP_JSON") + if not raw: + logger.info("Scope policy map not configured (env empty)") + return {} + parsed = json.loads(raw) + if not isinstance(parsed, dict): + raise ValueError("NW_MCP_ACTION_SCOPE_MAP_JSON must be a JSON object.") + out: dict[str, str] = {} + for key, value in parsed.items(): + if not isinstance(key, str) or not isinstance(value, str): + raise ValueError( + "NW_MCP_ACTION_SCOPE_MAP_JSON must map string action keys to string scopes." + ) + out[key] = value + logger.info( + "Scope policy map loaded", + extra={"entries": len(out), "action_keys": sorted(out.keys())}, + ) + return out diff --git a/tests/conftest.py b/tests/conftest.py index b2f760b..30e96c7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ """Shared pytest configuration. REST API tests default to ``NW_REST_AUTH_DISABLED=true`` so existing tests do not need -headers. MCP tests default to ``NW_MCP_AUTH_DISABLED=true`` for the same reason. +headers. MCP tests default to ``NW_MCP_AUTH_ENABLED=true`` for the same reason. Tests that assert authentication behavior override these env vars. """ from __future__ import annotations @@ -12,4 +12,4 @@ @pytest.fixture(autouse=True) def _rest_auth_disabled_for_tests(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("NW_REST_AUTH_DISABLED", "true") - monkeypatch.setenv("NW_MCP_AUTH_DISABLED", "true") + monkeypatch.setenv("NW_MCP_AUTH_ENABLED", "true") diff --git a/tests/test_factory_and_rest.py b/tests/test_factory_and_rest.py index 1a537b7..b567d60 100644 --- a/tests/test_factory_and_rest.py +++ b/tests/test_factory_and_rest.py @@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, MagicMock +import jwt import pytest from fastapi.testclient import TestClient @@ -89,6 +90,69 @@ def test_rest_post_with_bearer_succeeds_when_key_required(monkeypatch: pytest.Mo assert r.status_code == 200 +def test_rest_post_propagates_api_key_identity_to_connector_run( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("NW_REST_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_REST_JWT_SECRET", raising=False) + monkeypatch.setenv("NW_REST_API_KEY", "unit-test-secret") + + mock_factory = MagicMock() + mock_factory.get_for_protocol.return_value = _stub_connector( + ConnectorResponse(success=True, data={}, trace_id="t-p") + ) + app.dependency_overrides[get_factory] = lambda: mock_factory + try: + client = TestClient(app) + client.post( + "/connectors/http_generic/request", + json={"method": "GET", "url": "https://example.com"}, + headers={"Authorization": "Bearer unit-test-secret"}, + ) + finally: + app.dependency_overrides.clear() + + stub = mock_factory.get_for_protocol.return_value + kwargs = stub.run.await_args.kwargs + assert kwargs["principal"] == "api-key-user" + assert kwargs["tenant_id"] is None + assert kwargs["scopes"] == ("*",) + + +def test_rest_post_propagates_jwt_claims_to_connector_run(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_REST_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_REST_API_KEY", raising=False) + secret = "rest-jwt-test-secret-at-least-32bytes!!" + monkeypatch.setenv("NW_REST_JWT_SECRET", secret) + + tok = jwt.encode( + {"sub": "alice", "tenant_id": "t-1", "scopes": ["mcp:test.scope"]}, + secret, + algorithm="HS256", + ) + + mock_factory = MagicMock() + mock_factory.get_for_protocol.return_value = _stub_connector( + ConnectorResponse(success=True, data={}, trace_id="t-j") + ) + app.dependency_overrides[get_factory] = lambda: mock_factory + try: + client = TestClient(app) + client.post( + "/connectors/http_generic/request", + json={"method": "GET", "url": "https://example.com"}, + headers={"Authorization": f"Bearer {tok}"}, + ) + finally: + app.dependency_overrides.clear() + + stub = mock_factory.get_for_protocol.return_value + kw = stub.run.await_args.kwargs + assert kw["principal"] == "alice" + assert kw["tenant_id"] == "t-1" + assert kw["scopes"] == ("mcp:test.scope",) + + def test_rest_not_configured_returns_503_when_no_key_and_not_disabled(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("NW_REST_AUTH_DISABLED", raising=False) monkeypatch.delenv("NW_REST_API_KEY", raising=False) diff --git a/tests/test_mcp_auth.py b/tests/test_mcp_auth.py index e9ac912..acff2da 100644 --- a/tests/test_mcp_auth.py +++ b/tests/test_mcp_auth.py @@ -1,134 +1,134 @@ -from __future__ import annotations - -import jwt -import pytest - -from bindings.mcp_server.auth import ( - McpAuthInvalidError, - McpAuthRequiredError, - authenticate_mcp_request, -) -from bindings.mcp_server.server import McpServer - - -def test_mcp_auth_missing_token_returns_401(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) - monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") - monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) - - with pytest.raises(McpAuthRequiredError) as exc_info: - authenticate_mcp_request() - assert exc_info.value.status_code == 401 - assert exc_info.value.detail == "Authentication required" - - -def test_mcp_auth_invalid_token_returns_403(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) - monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") - monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) - - with pytest.raises(McpAuthInvalidError) as exc_info: - authenticate_mcp_request(meta={"token": "wrong-secret"}) - assert exc_info.value.status_code == 403 - assert exc_info.value.detail == "Invalid API key or token" - - -def test_mcp_auth_valid_token_allows_tools_list(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) - monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") - monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) - - identity = authenticate_mcp_request(meta={"token": "unit-test-secret"}) - assert identity is not None - - server = McpServer(connector_ids=["smtp"]) - tools = server.list_tools(identity=identity) - assert any(t["name"] == "smtp.send_email" for t in tools) - - -@pytest.mark.asyncio -async def test_mcp_authz_denies_tool_without_scope(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) - monkeypatch.delenv("NW_MCP_API_KEY", raising=False) - monkeypatch.setenv("NW_MCP_JWT_SECRET", "jwt-secret") - monkeypatch.setenv( - "NW_MCP_ACTION_SCOPE_MAP_JSON", - '{"smtp.send_email":"mcp:smtp.send_email"}', - ) - - token = jwt.encode( - {"sub": "alice", "tenant_id": "tenant-a", "scopes": ["mcp:other.scope"]}, - "jwt-secret", - algorithm="HS256", - ) - identity = authenticate_mcp_request(meta={"authorization": f"Bearer {token}"}) - assert identity is not None - - server = McpServer(connector_ids=["smtp"]) - resp = await server.invoke_tool( - "smtp.send_email", - { - "from_email": "sender@example.com", - "to": ["recipient@example.com"], - "subject": "x", - "body": "y", - }, - identity=identity, - ) - - assert resp["success"] is False - assert resp["error_code"] == "POLICY_DENIED" - assert resp["message"] == "Missing required scope: mcp:smtp.send_email" - - -@pytest.mark.asyncio -async def test_mcp_execution_passes_principal_and_tenant( - monkeypatch: pytest.MonkeyPatch, -) -> None: - monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) - monkeypatch.delenv("NW_MCP_API_KEY", raising=False) - monkeypatch.setenv("NW_MCP_JWT_SECRET", "jwt-secret") - monkeypatch.delenv("NW_MCP_ACTION_SCOPE_MAP_JSON", raising=False) - - token = jwt.encode( - {"sub": "service-account", "tenant_id": "tenant-42", "scopes": ["*"]}, - "jwt-secret", - algorithm="HS256", - ) - identity = authenticate_mcp_request(meta={"authorization": f"Bearer {token}"}) - assert identity is not None - - server = McpServer(connector_ids=["smtp"]) - smtp = server._factory.get_for_protocol("smtp", "mcp") - assert smtp is not None - - captured: dict[str, object] = {} - - async def fake_run(raw_input, *, principal=None, tenant_id=None, scopes=None): - captured["payload"] = dict(raw_input) - captured["principal"] = principal - captured["tenant_id"] = tenant_id - captured["scopes"] = tuple(scopes or ()) - from node_wire_runtime.models import ConnectorResponse - - return ConnectorResponse(success=True, data={"ok": True}, trace_id="trace-test") - - orig_run = smtp.run - try: - smtp.run = fake_run - await server.invoke_tool( - "smtp.send_email", - { - "from_email": "sender@example.com", - "to": ["recipient@example.com"], - "subject": "x", - "body": "y", - }, - identity=identity, - ) - finally: - smtp.run = orig_run - - assert captured["principal"] == "service-account" - assert captured["tenant_id"] == "tenant-42" - assert captured["scopes"] == ("*",) +from __future__ import annotations + +import jwt +import pytest + +from bindings.mcp_server.auth import ( + McpAuthInvalidError, + McpAuthRequiredError, + authenticate_mcp_request, +) +from bindings.mcp_server.server import McpServer + + +def test_mcp_auth_missing_token_returns_401(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + with pytest.raises(McpAuthRequiredError) as exc_info: + authenticate_mcp_request() + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Authentication required" + + +def test_mcp_auth_invalid_token_returns_403(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + with pytest.raises(McpAuthInvalidError) as exc_info: + authenticate_mcp_request(meta={"token": "wrong-secret"}) + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "Invalid API key or token" + + +def test_mcp_auth_valid_token_allows_tools_list(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + identity = authenticate_mcp_request(meta={"token": "unit-test-secret"}) + assert identity is not None + + server = McpServer(connector_ids=["smtp"]) + tools = server.list_tools(identity=identity) + assert any(t["name"] == "smtp.send_email" for t in tools) + + +@pytest.mark.asyncio +async def test_mcp_authz_denies_tool_without_scope(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_API_KEY", raising=False) + monkeypatch.setenv("NW_MCP_JWT_SECRET", "jwt-secret") + monkeypatch.setenv( + "NW_MCP_ACTION_SCOPE_MAP_JSON", + '{"smtp.send_email":"mcp:smtp.send_email"}', + ) + + token = jwt.encode( + {"sub": "alice", "tenant_id": "tenant-a", "scopes": ["mcp:other.scope"]}, + "jwt-secret", + algorithm="HS256", + ) + identity = authenticate_mcp_request(meta={"authorization": f"Bearer {token}"}) + assert identity is not None + + server = McpServer(connector_ids=["smtp"]) + resp = await server.invoke_tool( + "smtp.send_email", + { + "from_email": "sender@example.com", + "to": ["recipient@example.com"], + "subject": "x", + "body": "y", + }, + identity=identity, + ) + + assert resp["success"] is False + assert resp["error_code"] == "POLICY_DENIED" + assert resp["message"] == "Missing required scope: mcp:smtp.send_email" + + +@pytest.mark.asyncio +async def test_mcp_execution_passes_principal_and_tenant( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_API_KEY", raising=False) + monkeypatch.setenv("NW_MCP_JWT_SECRET", "jwt-secret") + monkeypatch.delenv("NW_MCP_ACTION_SCOPE_MAP_JSON", raising=False) + + token = jwt.encode( + {"sub": "service-account", "tenant_id": "tenant-42", "scopes": ["*"]}, + "jwt-secret", + algorithm="HS256", + ) + identity = authenticate_mcp_request(meta={"authorization": f"Bearer {token}"}) + assert identity is not None + + server = McpServer(connector_ids=["smtp"]) + smtp = server._factory.get_for_protocol("smtp", "mcp") + assert smtp is not None + + captured: dict[str, object] = {} + + async def fake_run(raw_input, *, principal=None, tenant_id=None, scopes=None): + captured["payload"] = dict(raw_input) + captured["principal"] = principal + captured["tenant_id"] = tenant_id + captured["scopes"] = tuple(scopes or ()) + from node_wire_runtime.models import ConnectorResponse + + return ConnectorResponse(success=True, data={"ok": True}, trace_id="trace-test") + + orig_run = smtp.run + try: + smtp.run = fake_run + await server.invoke_tool( + "smtp.send_email", + { + "from_email": "sender@example.com", + "to": ["recipient@example.com"], + "subject": "x", + "body": "y", + }, + identity=identity, + ) + finally: + smtp.run = orig_run + + assert captured["principal"] == "service-account" + assert captured["tenant_id"] == "tenant-42" + assert captured["scopes"] == ("*",) diff --git a/tests/test_scope_policy_transport.py b/tests/test_scope_policy_transport.py new file mode 100644 index 0000000..972c654 --- /dev/null +++ b/tests/test_scope_policy_transport.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import asyncio +from typing import Literal + +from pydantic import BaseModel + +from node_wire_runtime import BaseConnector, nw_action +from node_wire_runtime.policies.mcp_scope_policy import ScopePolicyHook + + +class _Input(BaseModel): + action: Literal["read_patient"] = "read_patient" + resource_id: str + + +class _Output(BaseModel): + ok: bool + + +class _PolicyTestConnector(BaseConnector): + connector_id = "fhir_epic" + output_model = _Output + + @nw_action("read_patient") + async def read_patient(self, params: _Input, *, trace_id: str) -> _Output: + return _Output(ok=True) + + +def _connector_with_scope_map() -> _PolicyTestConnector: + return _PolicyTestConnector( + policy_hook=ScopePolicyHook({"fhir_epic.read_patient": "mcp:fhir.read_patient"}) + ) + + +def test_scope_policy_bypasses_when_identity_missing_like_grpc() -> None: + connector = _connector_with_scope_map() + response = asyncio.run( + connector.run({"action": "read_patient", "resource_id": "x"}) + ) + + assert response.success is True + assert response.error_code is None + + +def test_scope_policy_denies_when_identity_present_without_required_scope() -> None: + connector = _connector_with_scope_map() + response = asyncio.run( + connector.run( + {"action": "read_patient", "resource_id": "x"}, + principal="alice", + tenant_id="tenant-1", + scopes=("mcp:other.scope",), + ) + ) + + assert response.success is False + assert response.error_code == "POLICY_DENIED" + assert response.message == "Missing required scope: mcp:fhir.read_patient" + From ab75ce91057be702e33fa4a25d5a8d189393e87b Mon Sep 17 00:00:00 2001 From: Rahul Ap Date: Tue, 5 May 2026 13:10:26 +0530 Subject: [PATCH 22/60] Cnp 49 migrate stripe connector to node wire (#23) * Add Stripe connector with payment and subscription functionalities * Refactor error handling and response structure in StripeConnector methods * Implement Stripe payment and subscription management features in the playground * Add idempotency_key to ChargeInput, CancelSubscriptionInput, CreatePaymentIntentInput, CreateSubscriptionInput, and IssueRefundInput for duplicate operation prevention * Fix StripeConnector to handle optional price_id and update tests for charge action and REST exposure * Refactor StripeConnector to use asyncio for retrieving SetupIntent, Invoice, and PaymentIntent, improving performance and responsiveness. * Replace Stripe connector icon with SVG for improved visual consistency --- README.md | 2 +- config/connectors.yaml | 4 +- docker-compose.mcp.yml | 7 + docker/stripe/Dockerfile | 34 +++ docs/local-packages-to-images.md | 6 +- docs/mcp-servers.md | 13 +- docs/packaging.md | 1 + docs/toolhive_agent_scenario.md | 2 + playground/app.js | 132 ++++++++++++ playground/index.html | 121 ++++++++++- playground/scenarios.py | 290 ++++++++++++++++++++++++++ scripts/build-mcp-images.sh | 6 + src/agents/stripe_mcp.py | 27 +++ src/node_wire_stripe/README.md | 68 ++++++ src/node_wire_stripe/logic.py | 264 +++++++++++++++++++++-- src/node_wire_stripe/schema.py | 117 +++++++++-- tests/test_base_connector_manifest.py | 4 +- tests/test_connectors_basic.py | 2 +- tests/test_connectors_io.py | 2 +- tests/test_factory_and_rest.py | 2 +- tests/test_stripe.py | 203 ++++++++++++++++++ 21 files changed, 1255 insertions(+), 52 deletions(-) create mode 100644 docker/stripe/Dockerfile create mode 100644 src/agents/stripe_mcp.py create mode 100644 src/node_wire_stripe/README.md create mode 100644 tests/test_stripe.py diff --git a/README.md b/README.md index 132f768..93cea35 100644 --- a/README.md +++ b/README.md @@ -91,7 +91,7 @@ The platform is split into three layers: |----------------|--------------------------------------------------|---------------|-----------------------------| | **http_generic** | Generic HTTP request (any URL, method, headers) | `request` | rest, grpc, mcp | | **smtp** | Send email via SMTP | `send_email` | rest, grpc, mcp | -| **stripe** | Stripe charge | `charge` | grpc, mcp (no rest in config)| +| **stripe** | Multi-action: `charge`, `create_payment_intent`, `create_subscription`, `cancel_subscription`, `issue_refund` | `rest`, `grpc`, `mcp` | | **google_drive**| Google Drive (list, create, get, update, upload, delete, permissions) | `execute` (payload discriminator) | rest, grpc, mcp | | **fhir_epic** | FHIR R4 integration for Epic (multi-action) | `read_patient`, `search_patients`, `search_encounter`, `create_document_reference`, `search_document_reference` | rest, grpc, mcp | | **fhir_cerner** | FHIR R4 integration for Cerner (multi-action) | `read_patient`, `search_patients`, `search_encounter`, `create_document_reference`, `search_document_reference` | rest, grpc, mcp | diff --git a/config/connectors.yaml b/config/connectors.yaml index 6f350b5..2d0f069 100644 --- a/config/connectors.yaml +++ b/config/connectors.yaml @@ -26,7 +26,7 @@ connectors: stripe: enabled: true - exposed_via: ["grpc", "mcp"] + exposed_via: ["rest", "grpc", "mcp"] auth: provider: static_token secret_key: stripe_api_key @@ -72,4 +72,4 @@ connectors: - system/Patient.read - system/Encounter.read - system/DocumentReference.read - - system/DocumentReference.write + - system/DocumentReference.write \ No newline at end of file diff --git a/docker-compose.mcp.yml b/docker-compose.mcp.yml index 1004b3f..e4024df 100644 --- a/docker-compose.mcp.yml +++ b/docker-compose.mcp.yml @@ -26,3 +26,10 @@ services: stdin_open: true tty: true restart: unless-stopped + + nw-stripe: + image: nw-stripe:latest + env_file: .env + stdin_open: true + tty: true + restart: unless-stopped diff --git a/docker/stripe/Dockerfile b/docker/stripe/Dockerfile new file mode 100644 index 0000000..552c46a --- /dev/null +++ b/docker/stripe/Dockerfile @@ -0,0 +1,34 @@ +FROM python:3.12-slim@sha256:3d5ed973e45820f5ba5e46bd065bd88b3a504ff0724d85980dcd05eab361fcf4 + +LABEL org.opencontainers.image.title="nw-stripe" \ + org.opencontainers.image.description="Node Wire — Stripe MCP server" \ + org.opencontainers.image.source="https://github.com/AOT-Technologies/node-wire" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY src/ ./src/ +COPY config/ ./config/ +COPY packages/runtime/dist/*.whl /wheels/ +COPY packages/connectors/stripe/dist/*.whl /wheels/ + +ENV PYTHONPATH=/app/src \ + NW_ALLOWED_CONNECTORS=stripe + +RUN pip install --no-cache-dir --find-links=/wheels \ + node-wire-runtime node-wire-stripe "mcp>=1.6.0" \ + && rm -rf /wheels + +RUN groupadd --system --gid 1000 app \ + && useradd --system --uid 1000 --gid app --home /app app \ + && chown -R app:app /app + +USER app + +HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ + python -c "from agents.stripe_mcp import main; assert callable(main); print('ok')" || exit 1 + +CMD ["python", "-m", "agents.stripe_mcp"] diff --git a/docs/local-packages-to-images.md b/docs/local-packages-to-images.md index c57defa..ae143ac 100644 --- a/docs/local-packages-to-images.md +++ b/docs/local-packages-to-images.md @@ -37,7 +37,8 @@ Build only specific packages (faster when iterating): ```bash bash scripts/build-packages.sh \ packages/runtime \ - packages/connectors/smtp + packages/connectors/smtp \ + packages/connectors/stripe ``` The script (`scripts/build-packages.sh` in default mode, not `--all`): @@ -56,6 +57,7 @@ Quick check (example for SMTP): ```bash ls packages/runtime/dist/*.whl ls packages/connectors/smtp/dist/*.whl +ls packages/connectors/stripe/dist/*.whl ``` If `ls` fails, rebuild that package before continuing. @@ -81,6 +83,7 @@ This builds: - `nw-smartonfhir-epic` - `nw-smartonfhir-cerner` - `nw-smtp` +- `nw-stripe` ### Build one image manually @@ -100,6 +103,7 @@ Each Dockerfile expects specific wheel files to exist in `dist/`: | `docker/google-drive/Dockerfile` | `packages/runtime/dist/*.whl`, `packages/connectors/google_drive/dist/*.whl` | | `docker/fhir-epic/Dockerfile` | `packages/runtime/dist/*.whl`, `packages/connectors/fhir_epic/dist/*.whl` | | `docker/fhir-cerner/Dockerfile` | `packages/runtime/dist/*.whl`, `packages/connectors/fhir_cerner/dist/*.whl` | +| `docker/stripe/Dockerfile` | `packages/runtime/dist/*.whl`, `packages/connectors/stripe/dist/*.whl` | | `Dockerfile` (unified MCP server) | runtime + all connector wheels (`http_generic`, `stripe`, `smtp`, `google_drive`, `fhir_epic`, `fhir_cerner`) | --- diff --git a/docs/mcp-servers.md b/docs/mcp-servers.md index 964caee..d54ffc7 100644 --- a/docs/mcp-servers.md +++ b/docs/mcp-servers.md @@ -49,8 +49,9 @@ flowchart TD | SMART on FHIR (Epic) | `python -m agents.fhir_epic_mcp` | `nw-smartonfhir-epic` | `nw-smartonfhir-epic` | All manifest actions for `fhir_epic` (e.g. `fhir_epic.read_patient`) | | SMART on FHIR (Cerner) | `python -m agents.fhir_cerner_mcp` | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner` | All manifest actions for `fhir_cerner` (e.g. `fhir_cerner.read_patient`) | | SMTP | `python -m agents.smtp_mcp` | `nw-smtp` | `nw-smtp` | `smtp.send_email` | +| Stripe | `python -m agents.stripe_mcp` | `nw-stripe` | `nw-stripe` | All manifest actions for `stripe` (e.g., `stripe.charge`) | -The unified server (`python -m agents.mcp_entrypoint`) exposes **every** connector enabled for MCP in `config/connectors.yaml` (e.g. `http_generic.request`, `stripe.charge`, plus the rows above). +The unified server (`python -m agents.mcp_entrypoint`) exposes **every** connector enabled for MCP in `config/connectors.yaml` (e.g. `http_generic.request`, `stripe.charge`, `stripe.create_payment_intent`, `stripe.create_subscription`, `stripe.cancel_subscription`, `stripe.issue_refund`, plus the rows above). ### Tool arguments and security @@ -296,6 +297,16 @@ SMTP_PASSWORD=your-gmail-app-password FROM_EMAIL=your-email@gmail.com ``` +#### `nw-stripe` + +| Variable | Description | +|---|---| +| `STRIPE_API_KEY` | Your Stripe secret API key (starts with `sk_test_` or `sk_live_`) | + +```env +STRIPE_API_KEY=sk_test_4eC39HqLyjWDarjtT1zdp7dc +``` + ### ToolHive / Agent settings | Variable | Description | diff --git a/docs/packaging.md b/docs/packaging.md index 68c81b2..9a0be46 100644 --- a/docs/packaging.md +++ b/docs/packaging.md @@ -194,6 +194,7 @@ docker build -f docker/smtp/Dockerfile -t nw-smtp . docker build -f docker/google-drive/Dockerfile -t nw-google-drive . docker build -f docker/fhir-epic/Dockerfile -t nw-smartonfhir-epic . docker build -f docker/fhir-cerner/Dockerfile -t nw-smartonfhir-cerner . +docker build -f docker/stripe/Dockerfile -t nw-stripe . ``` For compose and ToolHive registration see `docs/mcp-servers.md`. diff --git a/docs/toolhive_agent_scenario.md b/docs/toolhive_agent_scenario.md index e3a6e5f..8e7c562 100644 --- a/docs/toolhive_agent_scenario.md +++ b/docs/toolhive_agent_scenario.md @@ -39,6 +39,7 @@ ToolHive UI ────────────────────── │ ├── Tool: fhir_cerner.read_patient ← fetch patient from Cerner │ │ ├── Tool: fhir_epic.read_patient ← fetch patient from Epic │ │ ├── Tool: google_drive.files.upload ← write file to Drive │ +│ ├── Tool: stripe.charge ← process payment │ │ └── Tool: smtp.send_email ← email the summary │ │ ↕ stdio → HTTP proxy │ ────────────────────────────────────────────────────────────────── @@ -90,6 +91,7 @@ When running **this scenario’s** minimal multi-connector stack (one MCP server | `fhir_cerner.read_patient` | Fetch a patient's record from a Cerner FHIR R4 endpoint | | `fhir_epic.read_patient` | Fetch a patient's record from an Epic FHIR R4 endpoint | | `google_drive.files.upload` | Create and upload a text file to Google Drive | +| `stripe.charge` | Process a payment | | `smtp.send_email` | Send an email via SMTP | The agent uses an LLM's tool-calling capability to decide which tools to call, in what order, and with what parameters. diff --git a/playground/app.js b/playground/app.js index 7741f9f..4c17acf 100644 --- a/playground/app.js +++ b/playground/app.js @@ -50,7 +50,23 @@ document.addEventListener('DOMContentLoaded', () => { const previewName = fileChosenPreview?.querySelector('.preview-name'); const removeFileBtn = fileChosenPreview?.querySelector('.remove-file-btn'); + const stripeForm = document.getElementById('stripe-form'); + const stripeRunBtn = document.getElementById('stripe-run-btn'); + const stripeSpinner = stripeRunBtn.querySelector('.loading-spinner'); + const stripeBtnText = stripeRunBtn.querySelector('.btn-lbl'); + const stripePanel = document.getElementById('stripe-panel'); + + const stripeActionSelect = document.getElementById('stripe-action-select'); + const stripeSections = { + charge: document.getElementById('stripe-section-charge'), + payment_intent: document.getElementById('stripe-section-pi'), + subscription: document.getElementById('stripe-section-sub'), + cancel_subscription: document.getElementById('stripe-section-cancel'), + refund: document.getElementById('stripe-section-refund') + }; + let currentSubMode = 'file'; + let currentStripeSubMode = 'charge'; const connectorStatus = document.getElementById('connector-status'); const brandLabel = document.querySelector('.brand-text h1 span.accent'); const tagline = document.querySelector('.tagline'); @@ -104,6 +120,31 @@ document.addEventListener('DOMContentLoaded', () => { "Apply file update", "Verify file metadata", "Complete update" + ], + stripe_charge: [ + "Initialize Payment", + "Process Charge", + "Verify Transaction" + ], + stripe_payment_intent: [ + "Initialize Session", + "Create Payment Intent", + "Verify Allocation" + ], + stripe_subscription: [ + "Validate Customer", + "Create Subscription", + "Verify Provisioning" + ], + stripe_cancel_subscription: [ + "Locate Resource", + "Cancel Subscription", + "Verify Termination" + ], + stripe_refund: [ + "Validate Charge", + "Process Refund", + "Verify Refund" ] }; @@ -331,6 +372,27 @@ document.addEventListener('DOMContentLoaded', () => { } } + function stripePipelineLabelOverride() { + if (currentStripeSubMode === 'charge') return pipelineLabels.stripe_charge; + if (currentStripeSubMode === 'payment_intent') return pipelineLabels.stripe_payment_intent; + if (currentStripeSubMode === 'subscription') return pipelineLabels.stripe_subscription; + if (currentStripeSubMode === 'cancel_subscription') return pipelineLabels.stripe_cancel_subscription; + if (currentStripeSubMode === 'refund') return pipelineLabels.stripe_refund; + return pipelineLabels.stripe_charge; + } + + function syncStripeActionForm() { + Object.values(stripeSections).forEach(sec => { + if (sec) sec.classList.add('hidden'); + }); + const activeSec = stripeSections[currentStripeSubMode] || stripeSections['charge']; + if (activeSec) activeSec.classList.remove('hidden'); + + if (stripeActionSelect) { + stripeActionSelect.value = currentStripeSubMode; + } + } + function setMode(mode) { currentMode = mode; @@ -339,6 +401,7 @@ document.addEventListener('DOMContentLoaded', () => { itopsPanel.classList.add('hidden'); cernerPanel.classList.add('hidden'); gdrivePanel.classList.add('hidden'); + stripePanel.classList.add('hidden'); if (mode === 'ehr') { ehrPanel.classList.remove('hidden'); @@ -364,10 +427,19 @@ document.addEventListener('DOMContentLoaded', () => { tagline.textContent = 'Secure Vault Orchestration'; document.documentElement.style.setProperty('--brand-accent', '#10b981'); log('Switched to Secure Document Archival mode (Google Drive)', 'system'); + } else if (mode === 'stripe') { + stripePanel.classList.remove('hidden'); + connectorStatus.textContent = 'Stripe Online'; + tagline.textContent = 'Financial Infrastructure'; + document.documentElement.style.setProperty('--brand-accent', '#635bff'); + log('Switched to Stripe Payment Orchestration mode', 'system'); } if (mode === 'gdrive') { syncGdriveActionForm(); resetUI(gdrivePipelineLabelOverride()); + } else if (mode === 'stripe') { + syncStripeActionForm(); + resetUI(stripePipelineLabelOverride()); } else { resetUI(); } @@ -645,6 +717,66 @@ document.addEventListener('DOMContentLoaded', () => { await handleSubmission(payload, '/scenarios/cerner-post-consultation', cernerRunBtn, cernerBtnText, cernerSpinner, 'Sync to Cerner Chart'); }); + stripeForm.addEventListener('submit', async (e) => { + e.preventDefault(); + const formData = new FormData(stripeForm); + const payload = Object.fromEntries(formData.entries()); + + let endpoint = '/scenarios/stripe-charge'; + let submitPayload = {}; + + if (currentStripeSubMode === 'charge' || !currentStripeSubMode) { + submitPayload = { + amount: parseInt(payload.charge_amount, 10), + currency: payload.charge_currency, + description: payload.charge_description + }; + endpoint = '/scenarios/stripe-charge'; + } else if (currentStripeSubMode === 'payment_intent') { + submitPayload = { + amount: parseInt(payload.pi_amount, 10), + currency: payload.pi_currency, + customer_id: payload.pi_customer || undefined, + payment_method: payload.pi_payment_method || undefined, + confirm: payload.pi_confirm === 'on' + }; + endpoint = '/scenarios/stripe-payment-intent'; + } else if (currentStripeSubMode === 'subscription') { + submitPayload = { + customer_id: payload.sub_customer, + price_id: payload.sub_price, + card_token: payload.sub_token || undefined + }; + endpoint = '/scenarios/stripe-subscription'; + } else if (currentStripeSubMode === 'cancel_subscription') { + submitPayload = { + subscription_id: payload.cancel_sub_id + }; + endpoint = '/scenarios/stripe-cancel-subscription'; + } else if (currentStripeSubMode === 'refund') { + const isPI = payload.refund_target_id.startsWith('pi_'); + submitPayload = { + charge_id: !isPI && payload.refund_target_id ? payload.refund_target_id : undefined, + payment_intent_id: isPI ? payload.refund_target_id : undefined, + amount: payload.refund_amount ? parseInt(payload.refund_amount, 10) : undefined + }; + endpoint = '/scenarios/stripe-refund'; + } + + await handleSubmission(submitPayload, endpoint, stripeRunBtn, stripeBtnText, stripeSpinner, 'Process Action'); + }); + + if (stripeActionSelect) { + stripeActionSelect.addEventListener('change', (e) => { + const mode = e.target.value; + if (mode === currentStripeSubMode) return; + currentStripeSubMode = mode; + syncStripeActionForm(); + resetUI(stripePipelineLabelOverride()); + log(`Switched to Stripe mode [${currentStripeSubMode}]`); + }); + } + // File Preview Logic if (gdriveFileInput && fileChosenPreview && previewName && fileDropZone) { gdriveFileInput.addEventListener('change', () => { diff --git a/playground/index.html b/playground/index.html index 52f24c0..9d2f331 100644 --- a/playground/index.html +++ b/playground/index.html @@ -205,8 +205,21 @@

Google Drive

Secure clinical document archival with IAM-governed access and encryption.

+ +
+
+ + + + +
+
+

Stripe

+

Financial transaction and subscription management infrastructure.

+
+
-
+
diff --git a/playground/scenarios.py b/playground/scenarios.py index f32c5be..558c72c 100644 --- a/playground/scenarios.py +++ b/playground/scenarios.py @@ -41,6 +41,7 @@ FilesListOperation, FilesUpdateOperation, ) +from node_wire_stripe.schema import ChargeInput logger = logging.getLogger("playground.scenarios") router = APIRouter(prefix="/scenarios", tags=["scenarios"]) @@ -61,6 +62,32 @@ class IncidentReportInput(BaseModel): description: str reported_by: str = "Demo User" +class StripeChargeInput(BaseModel): + amount: int + currency: str + description: Optional[str] = None + source: str = "tok_visa" + +class StripePaymentIntentInputPlayground(BaseModel): + amount: int + currency: str + customer_id: Optional[str] = None + payment_method: Optional[str] = None + confirm: bool = False + +class StripeSubscriptionInputPlayground(BaseModel): + customer_id: str + price_id: str + card_token: Optional[str] = None + +class StripeCancelSubscriptionInputPlayground(BaseModel): + subscription_id: str + +class StripeRefundInputPlayground(BaseModel): + charge_id: Optional[str] = None + payment_intent_id: Optional[str] = None + amount: Optional[int] = None + class CernerPostConsultationInput(BaseModel): patient_id: Optional[str] = None patient_family: Optional[str] = None @@ -239,6 +266,13 @@ def get_google_drive_connector(): return connector +def get_stripe_connector(): + connector = resolve_connector("stripe") + if not connector: + raise HTTPException(status_code=500, detail="Stripe connector not configured") + return connector + + @router.post("/post-consultation", response_model=ScenarioResponse) async def post_consultation_scenario( payload: PostConsultationInput, @@ -793,6 +827,262 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", except Exception as e: return _safe_error_return(e, steps, trace_id, "Step 3 failed") +@router.post("/stripe-charge", response_model=ScenarioResponse) +async def stripe_charge_scenario( + payload: StripeChargeInput, + connector: Any = Depends(get_stripe_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step(name: str, status: str, details: str = "", display_name: str = "", data: Any = None): + steps.append(ScenarioStep(name=name, status=status, details=details, display_name=display_name, data=data)) + + # STEP 1: Process Payment Intent + add_step("Process Payment Intent", "pending", display_name="Initialize Payment") + try: + steps[-1].status = "success" + steps[-1].details = "Payment initialization verified." + steps[-1].display_name = "Payment Initialized" + steps[-1].data = {"amount": payload.amount, "currency": payload.currency} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 1 failed") + + # STEP 2: Confirm Charge + add_step("Confirm Charge", "pending", display_name="Process Charge") + try: + from node_wire_stripe.schema import ChargeInput + charge_input = ChargeInput( + amount=payload.amount, + currency=payload.currency, + source=payload.source, + description=payload.description + ) + + charge_res = await execute_with_retry(connector, charge_input, trace_id, steps[-1]) + + steps[-1].status = "success" + steps[-1].details = f"Charge Processed: {charge_res.charge_id}" + steps[-1].display_name = "Charge Successful" + steps[-1].data = {"raw": charge_res.model_dump()} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 2 failed") + + # STEP 3: Verify Transaction + add_step("Verify Transaction", "pending", display_name="Verify Receipt") + try: + beautiful_data = { + "id": charge_res.charge_id, + "type": "Payment Receipt", + "date": datetime.now().isoformat(), + "status": charge_res.status, + "patient_name": "Demo User", + "author": "Stripe Gateway", + "category": "Financial", + "description": payload.description or "No description", + "content_text": f"Charge of {payload.amount/100:.2f} {payload.currency.upper()} processed successfully. Receipt: {charge_res.receipt_url or 'N/A'}" + } + steps[-1].status = "success" + steps[-1].details = "Transaction Verified" + steps[-1].display_name = "Transaction Verified" + steps[-1].data = {"beautiful_data": beautiful_data, "raw": {"status": "Verified"}} + + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=charge_res.charge_id, + human_summary=f"Successfully processed {payload.amount/100:.2f} {payload.currency.upper()} charge.", + trace_id=trace_id + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 3 failed") + +@router.post("/stripe-payment-intent", response_model=ScenarioResponse) +async def stripe_payment_intent_scenario( + payload: StripePaymentIntentInputPlayground, + connector: Any = Depends(get_stripe_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + def add_step(name: str, status: str, details: str = "", display_name: str = "", data: Any = None): + steps.append(ScenarioStep(name=name, status=status, details=details, display_name=display_name, data=data)) + + add_step("Initialize Session", "pending", display_name="Initialize PI") + try: + steps[-1].status = "success" + steps[-1].details = f"Initialized PI session for {payload.amount} {payload.currency}" + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 1 failed") + + add_step("Create Payment Intent", "pending", display_name="Create Intent") + try: + from node_wire_stripe.schema import CreatePaymentIntentInput + pi_input = CreatePaymentIntentInput( + amount=payload.amount, + currency=payload.currency, + customer_id=payload.customer_id, + payment_method=payload.payment_method, + confirm=payload.confirm + ) + res = await execute_with_retry(connector, pi_input, trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = f"Created Intent: {res.payment_intent_id}" + steps[-1].data = {"raw": res.model_dump()} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 2 failed") + + add_step("Verify Allocation", "pending", display_name="Verify Allocation") + try: + steps[-1].status = "success" + steps[-1].details = "Allocation verified" + steps[-1].display_name = "Allocation Verified" + + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=res.payment_intent_id, + human_summary=f"Successfully created payment intent {res.payment_intent_id}.", + trace_id=trace_id + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 3 failed") + +@router.post("/stripe-subscription", response_model=ScenarioResponse) +async def stripe_subscription_scenario( + payload: StripeSubscriptionInputPlayground, + connector: Any = Depends(get_stripe_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + def add_step(name: str, status: str, details: str = "", display_name: str = "", data: Any = None): + steps.append(ScenarioStep(name=name, status=status, details=details, display_name=display_name, data=data)) + + add_step("Validate Customer", "pending", display_name="Validate Params") + try: + steps[-1].status = "success" + steps[-1].details = f"Validated inputs for Customer: {payload.customer_id}" + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 1 failed") + + add_step("Create Subscription", "pending", display_name="Create Sub") + try: + from node_wire_stripe.schema import CreateSubscriptionInput + sub_input = CreateSubscriptionInput( + customer_id=payload.customer_id, + price_id=payload.price_id, + card_token=payload.card_token + ) + res = await execute_with_retry(connector, sub_input, trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = f"Subscription Created: {res.subscription_id}" + steps[-1].data = {"raw": res.model_dump()} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 2 failed") + + add_step("Verify Provisioning", "pending", display_name="Verify Sub") + try: + steps[-1].status = "success" + steps[-1].details = f"Subscription {res.subscription_id} is {res.status}" + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=res.subscription_id, + human_summary=f"Successfully provisioned subscription for customer.", + trace_id=trace_id + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 3 failed") + +@router.post("/stripe-cancel-subscription", response_model=ScenarioResponse) +async def stripe_cancel_subscription_scenario( + payload: StripeCancelSubscriptionInputPlayground, + connector: Any = Depends(get_stripe_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + def add_step(name: str, status: str, details: str = "", display_name: str = "", data: Any = None): + steps.append(ScenarioStep(name=name, status=status, details=details, display_name=display_name, data=data)) + + add_step("Locate Resource", "pending", display_name="Locate Sub") + try: + steps[-1].status = "success" + steps[-1].details = f"Targeting subscription: {payload.subscription_id}" + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 1 failed") + + add_step("Cancel Subscription", "pending", display_name="Cancel Sub") + try: + from node_wire_stripe.schema import CancelSubscriptionInput + can_input = CancelSubscriptionInput( + subscription_id=payload.subscription_id + ) + res = await execute_with_retry(connector, can_input, trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = f"Cancelled Sub: {res.subscription_id}" + steps[-1].data = {"raw": res.model_dump()} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 2 failed") + + add_step("Verify Termination", "pending", display_name="Verify Cancel") + try: + steps[-1].status = "success" + steps[-1].details = f"Cancellation verified. Status: {res.status}" + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=res.subscription_id, + human_summary=f"Successfully canceled subscription.", + trace_id=trace_id + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 3 failed") + +@router.post("/stripe-refund", response_model=ScenarioResponse) +async def stripe_refund_scenario( + payload: StripeRefundInputPlayground, + connector: Any = Depends(get_stripe_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + def add_step(name: str, status: str, details: str = "", display_name: str = "", data: Any = None): + steps.append(ScenarioStep(name=name, status=status, details=details, display_name=display_name, data=data)) + + add_step("Validate Charge", "pending", display_name="Validate Params") + try: + steps[-1].status = "success" + steps[-1].details = f"Refund targeted for ID: {payload.charge_id or payload.payment_intent_id}" + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 1 failed") + + add_step("Process Refund", "pending", display_name="Issue Refund") + try: + from node_wire_stripe.schema import IssueRefundInput + ref_input = IssueRefundInput( + charge_id=payload.charge_id, + payment_intent_id=payload.payment_intent_id, + amount=payload.amount + ) + res = await execute_with_retry(connector, ref_input, trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = f"Refund Processed: {res.refund_id}" + steps[-1].data = {"raw": res.model_dump()} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 2 failed") + + add_step("Verify Refund", "pending", display_name="Verify Receipt") + try: + steps[-1].status = "success" + steps[-1].details = f"Refund recorded properly. Status: {res.status}" + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=res.refund_id, + human_summary=f"Successfully issued refund.", + trace_id=trace_id + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 3 failed") + @router.post("/gdrive-archival", response_model=ScenarioResponse) async def gdrive_archival_scenario( payload: GoogleDriveArchivalInput, diff --git a/scripts/build-mcp-images.sh b/scripts/build-mcp-images.sh index a3844c7..5e5c0eb 100755 --- a/scripts/build-mcp-images.sh +++ b/scripts/build-mcp-images.sh @@ -16,6 +16,7 @@ Images: - nw-smartonfhir-epic - nw-smartonfhir-cerner - nw-smtp + - nw-stripe EOF } @@ -65,5 +66,10 @@ docker build -f docker/smtp/Dockerfile \ -t "nw-smtp:${VERSION}" \ . +docker build -f docker/stripe/Dockerfile \ + -t nw-stripe:latest \ + -t "nw-stripe:${VERSION}" \ + . + echo "Done." diff --git a/src/agents/stripe_mcp.py b/src/agents/stripe_mcp.py new file mode 100644 index 0000000..7e1de98 --- /dev/null +++ b/src/agents/stripe_mcp.py @@ -0,0 +1,27 @@ +"""MCP Server — Stripe connector only. Usage: python -m agents.stripe_mcp""" +from __future__ import annotations + +import logging +import os + +from dotenv import load_dotenv + +load_dotenv() +load_dotenv(os.path.join(os.path.dirname(__file__), ".env")) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("agents.stripe_mcp") + + +def main() -> None: + from bindings.mcp_server.server import McpServer + + logger.info("Starting nw-stripe MCP server (stdio, manifest-driven)") + McpServer( + server_name="nw-stripe", + connector_ids=["stripe"], + ).run_stdio() + + +if __name__ == "__main__": + main() diff --git a/src/node_wire_stripe/README.md b/src/node_wire_stripe/README.md new file mode 100644 index 0000000..45a2da4 --- /dev/null +++ b/src/node_wire_stripe/README.md @@ -0,0 +1,68 @@ +# Node Wire Connector — Stripe + +The Stripe connector provides a reliable, async adapter for processing payments and managing subscriptions using the Stripe Python SDK. It follows the Node Wire platform contract: consistent error handling, resilience (retries/circuit breaking), and standardized telemetry. + +## Supported Actions + +The connector exposes several actions through the `@nw_action` decorator. Each action is available via REST, gRPC, and MCP. + +| Action | Description | Key Parameters | +| :--- | :--- | :--- | +| `charge` | Legacy charge creation. | `amount`, `currency`, `source` | +| `create_payment_intent` | Modern payment flow for one-time payments. | `amount`, `currency`, `customer_id`, `confirm` | +| `create_subscription` | Create a recurring subscription. | `customer_id`, `price_id`, `card_token` | +| `cancel_subscription` | Terminate or schedule the end of a subscription. | `subscription_id`, `cancel_at_period_end` | +| `issue_refund` | Full or partial refund for a charge or payment intent. | `charge_id` or `payment_intent_id`, `amount` | + +## Setup & Configuration + +### Environment Variables + +The connector requires a Stripe secret API key. By default, the `EnvSecretProvider` looks for: + +- `STRIPE_API_KEY`: Your Stripe secret key (e.g., `sk_test_...` or `sk_live_...`). + +Add this to your `.env` or system environment: + +```bash +STRIPE_API_KEY=sk_test_your_secret_key +``` + +### Enabling the Connector + +In `config/connectors.yaml`, ensuring the connector is enabled and exposed: + +```yaml +connectors: + stripe: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] +``` + +## Detailed Action Reference + +### `create_subscription` + +This action supports multiple payment integration flows: + +1. **Saved Payment Method**: Pass `default_payment_method` with an existing `pm_xxx` ID. +2. **Card Token**: Pass `card_token` (e.g., `tok_visa`). The connector will automatically create a PaymentMethod and attach it to the customer before creating the subscription. +3. **SCA / Action Required**: If the subscription requires further action (like 3D Secure), the connector returns the `client_secret` from the associated Setup Intent or Payment Intent. + +### `cancel_subscription` + +- Set `cancel_at_period_end: true` to let the subscription finish its current cycle. +- Set `cancel_at_period_end: false` (default) to terminate the subscription immediately. + +## Error Handling + +Mapped Stripe exceptions to Node Wire error categories: + +- `RateLimitError` -> `RETRYABLE` (`STRIPE_RATE_LIMIT`) +- `CardError` -> `BUSINESS` (`STRIPE_CARD_ERROR`) +- `AuthenticationError` -> `AUTH` (`STRIPE_AUTH_ERROR`) +- `APIConnectionError` -> `RETRYABLE` (`STRIPE_API_CONNECTION`) +- `InvalidRequestError` -> `BUSINESS` (`STRIPE_INVALID_REQUEST`) +- `StripeError` -> `FATAL` (`STRIPE_ERROR`) + +Trace IDs are included in all error responses for easier troubleshooting in the Stripe Dashboard. diff --git a/src/node_wire_stripe/logic.py b/src/node_wire_stripe/logic.py index 76fd658..51d043e 100644 --- a/src/node_wire_stripe/logic.py +++ b/src/node_wire_stripe/logic.py @@ -2,23 +2,30 @@ import asyncio import logging +from typing import Any import stripe from node_wire_runtime import BaseConnector, nw_action from node_wire_runtime.models import ErrorCategory -from .schema import ChargeInput, ChargeOutput +from .schema import ( + CancelSubscriptionInput, + ChargeInput, + CreatePaymentIntentInput, + CreateSubscriptionInput, + IssueRefundInput, + StripeOperationOutput, +) logger = logging.getLogger("connectors.stripe") class StripeConnector(BaseConnector): - """Stripe connector: charges and future SDK operations as @nw_action methods.""" + """Stripe connector: payments and subscriptions as @nw_action methods.""" connector_id = "stripe" - action = "charge" - output_model = ChargeOutput + output_model = StripeOperationOutput error_map = { stripe.error.RateLimitError: (ErrorCategory.RETRYABLE, "STRIPE_RATE_LIMIT"), @@ -29,17 +36,12 @@ class StripeConnector(BaseConnector): stripe.error.StripeError: (ErrorCategory.FATAL, "STRIPE_ERROR"), } + def _get_api_key(self) -> str: + return self.secret_provider.get_secret("stripe_api_key") + @nw_action("charge") - async def charge(self, params: ChargeInput, *, trace_id: str) -> ChargeOutput: - # The factory injects a StaticTokenAuthProvider with the Stripe API key. - # We extract the raw key from the Authorization header value. - auth_headers = await self.get_auth_headers() - raw_auth = auth_headers.get("Authorization", "") - # Strip any prefix (e.g. "Bearer ") — Stripe expects the raw key. - api_key = raw_auth.split(" ", 1)[-1].strip() if raw_auth else "" - if not api_key: - # Backward-compatibility fallback: read directly from secret_provider. - api_key = self.secret_provider.get_secret("stripe_api_key") + async def charge(self, params: ChargeInput, *, trace_id: str) -> StripeOperationOutput: + api_key = self._get_api_key() logger.info( "Creating Stripe charge", @@ -53,12 +55,15 @@ async def charge(self, params: ChargeInput, *, trace_id: str) -> ChargeOutput: ) def _create() -> stripe.Charge: - stripe.api_key = api_key return stripe.Charge.create( + api_key=api_key, amount=params.amount, currency=params.currency, source=params.source, + customer=params.customer_id, description=params.description, + metadata=params.metadata, + idempotency_key=params.idempotency_key or trace_id, ) try: @@ -70,25 +75,238 @@ def _create() -> stripe.Charge: "trace_id": trace_id, "connector_id": self.connector_id, "action": "charge", - "amount": params.amount, - "currency": params.currency, "error_type": type(exc).__name__, "error_message": str(exc), }, ) raise + return StripeOperationOutput( + charge_id=getattr(charge, "id", None), + receipt_url=getattr(charge, "receipt_url", None), + status="succeeded" if getattr(charge, "paid", False) else "failed", + ) + + @nw_action("create_payment_intent") + async def create_payment_intent( + self, params: CreatePaymentIntentInput, *, trace_id: str + ) -> StripeOperationOutput: + api_key = self._get_api_key() + logger.info( - "Stripe charge created successfully", + "Creating Stripe Payment Intent", extra={ "trace_id": trace_id, "connector_id": self.connector_id, - "action": "charge", - "charge_id": charge.get("id"), + "action": "create_payment_intent", + "amount": params.amount, + "currency": params.currency, + }, + ) + + def _create() -> stripe.PaymentIntent: + return stripe.PaymentIntent.create( + api_key=api_key, + amount=params.amount, + currency=params.currency, + customer=params.customer_id, + payment_method=params.payment_method, + confirm=params.confirm, + description=params.description, + metadata=params.metadata, + idempotency_key=params.idempotency_key or trace_id, + ) + + try: + pi = await asyncio.to_thread(_create) + except Exception as exc: + logger.error( + "Stripe Payment Intent creation failed", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "create_payment_intent", + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + raise + + return StripeOperationOutput( + payment_intent_id=getattr(pi, "id", None), + client_secret=getattr(pi, "client_secret", None), + status=getattr(pi, "status", None), + ) + + @nw_action("create_subscription") + async def create_subscription(self, params: CreateSubscriptionInput, *, trace_id: str) -> StripeOperationOutput: + api_key = self._get_api_key() + + logger.info( + "Creating Stripe Subscription", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "create_subscription", + "customer_id": params.customer_id, + "price_id": params.price_id, }, ) - return ChargeOutput( - charge_id=charge.get("id"), - receipt_url=charge.get("receipt_url"), + def _create() -> stripe.Subscription: + payment_method_id = params.default_payment_method + + # If card_token is provided, create and attach PaymentMethod + if params.card_token: + pm = stripe.PaymentMethod.create( + api_key=api_key, + type="card", + card={"token": params.card_token}, + ) + stripe.PaymentMethod.attach( + pm.id, + api_key=api_key, + customer=params.customer_id, + ) + payment_method_id = pm.id + + return stripe.Subscription.create( + api_key=api_key, + customer=params.customer_id, + items=[{"price": params.price_id}] if params.price_id else None, + payment_behavior=params.payment_behavior, + default_payment_method=payment_method_id, + metadata=params.metadata, + idempotency_key=params.idempotency_key or trace_id, + ) + + try: + sub = await asyncio.to_thread(_create) + except Exception as exc: + logger.error( + "Stripe Subscription creation failed", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "create_subscription", + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + raise + + # Subscription might have a setup_intent or latest_invoice.payment_intent + client_secret = None + pending_setup_intent = getattr(sub, "pending_setup_intent", None) + latest_invoice_id = getattr(sub, "latest_invoice", None) + + if pending_setup_intent: + si = await asyncio.to_thread(stripe.SetupIntent.retrieve, pending_setup_intent, api_key=api_key) + client_secret = getattr(si, "client_secret", None) + elif latest_invoice_id: + inv = await asyncio.to_thread(stripe.Invoice.retrieve, latest_invoice_id, api_key=api_key) + pi_id = getattr(inv, "payment_intent", None) + if pi_id: + pi = await asyncio.to_thread(stripe.PaymentIntent.retrieve, pi_id, api_key=api_key) + client_secret = getattr(pi, "client_secret", None) + + return StripeOperationOutput( + subscription_id=getattr(sub, "id", None), + status=getattr(sub, "status", None), + client_secret=client_secret, + ) + + @nw_action("cancel_subscription") + async def cancel_subscription(self, params: CancelSubscriptionInput, *, trace_id: str) -> StripeOperationOutput: + api_key = self._get_api_key() + + logger.info( + "Cancelling Stripe Subscription", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "cancel_subscription", + "subscription_id": params.subscription_id, + }, + ) + + def _cancel() -> stripe.Subscription: + if params.cancel_at_period_end: + return stripe.Subscription.modify( + params.subscription_id, + api_key=api_key, + cancel_at_period_end=True, + idempotency_key=params.idempotency_key or trace_id, + ) + else: + return stripe.Subscription.cancel( + params.subscription_id, + api_key=api_key, + idempotency_key=params.idempotency_key or trace_id, + ) + + try: + sub = await asyncio.to_thread(_cancel) + except Exception as exc: + logger.error( + "Stripe Subscription cancellation failed", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "cancel_subscription", + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + raise + + return StripeOperationOutput( + subscription_id=getattr(sub, "id", None), + status=getattr(sub, "status", None), + ) + + @nw_action("issue_refund") + async def issue_refund(self, params: IssueRefundInput, *, trace_id: str) -> StripeOperationOutput: + api_key = self._get_api_key() + + logger.info( + "Issuing Stripe Refund", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "issue_refund", + "charge_id": params.charge_id, + "payment_intent_id": params.payment_intent_id, + }, + ) + + def _refund() -> stripe.Refund: + return stripe.Refund.create( + api_key=api_key, + charge=params.charge_id, + payment_intent=params.payment_intent_id, + amount=params.amount, + reason=params.reason, + metadata=params.metadata, + idempotency_key=params.idempotency_key or trace_id, + ) + + try: + refund = await asyncio.to_thread(_refund) + except Exception as exc: + logger.error( + "Stripe Refund issuance failed", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "issue_refund", + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + raise + + return StripeOperationOutput( + refund_id=getattr(refund, "id", None), + status=getattr(refund, "status", None), ) diff --git a/src/node_wire_stripe/schema.py b/src/node_wire_stripe/schema.py index e912829..8564cbb 100644 --- a/src/node_wire_stripe/schema.py +++ b/src/node_wire_stripe/schema.py @@ -1,18 +1,99 @@ -from __future__ import annotations - -from typing import Literal - -from pydantic import BaseModel - - -class ChargeInput(BaseModel): - action: Literal["charge"] = "charge" - amount: int - currency: str - source: str - description: str | None = None - - -class ChargeOutput(BaseModel): - charge_id: str - receipt_url: str | None = None +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class ChargeInput(BaseModel): + action: Literal["charge"] = "charge" + amount: int = Field(..., ge=1, le=99999999) + currency: str = Field(..., pattern=r"^[a-z]{3}$") + source: str + customer_id: str | None = None + description: str | None = None + metadata: dict | None = None + idempotency_key: str | None = Field(None, description="Optional unique key to prevent duplicate operations.") + + +class ChargeOutput(BaseModel): + charge_id: str + receipt_url: str | None = None + + +class CancelSubscriptionInput(BaseModel): + action: Literal["cancel_subscription"] = "cancel_subscription" + subscription_id: str + cancel_at_period_end: bool = False + idempotency_key: str | None = Field(None, description="Optional unique key to prevent duplicate operations.") + + +class CancelSubscriptionOutput(BaseModel): + subscription_id: str + status: str + + +class CreatePaymentIntentInput(BaseModel): + action: Literal["create_payment_intent"] = "create_payment_intent" + amount: int = Field(..., ge=1, le=99999999) + currency: str = Field(..., pattern=r"^[a-z]{3}$") + customer_id: str | None = None + payment_method: str | None = None + confirm: bool = False + description: str | None = None + metadata: dict | None = None + idempotency_key: str | None = Field(None, description="Optional unique key to prevent duplicate operations.") + + +class CreatePaymentIntentOutput(BaseModel): + payment_intent_id: str + client_secret: str | None = None + status: str + + +class CreateSubscriptionInput(BaseModel): + action: Literal["create_subscription"] = "create_subscription" + customer_id: str + price_id: str + payment_behavior: str = "default_incomplete" + default_payment_method: str | None = None + card_token: str | None = None + metadata: dict | None = None + idempotency_key: str | None = Field(None, description="Optional unique key to prevent duplicate operations.") + + +class CreateSubscriptionOutput(BaseModel): + subscription_id: str + client_secret: str | None = None + status: str + + +class IssueRefundInput(BaseModel): + action: Literal["issue_refund"] = "issue_refund" + charge_id: str | None = None + payment_intent_id: str | None = None + amount: int | None = Field(None, ge=1, le=99999999) + reason: str | None = None + metadata: dict | None = None + idempotency_key: str | None = Field(None, description="Optional unique key to prevent duplicate operations.") + + +class IssueRefundOutput(BaseModel): + refund_id: str + status: str + + +class StripeOperationOutput(BaseModel): + """ + Unified output model for all Stripe actions. + The actual result will be contained in one or more of these fields. + """ + + charge_id: str | None = None + receipt_url: str | None = None + subscription_id: str | None = None + status: str | None = None + payment_intent_id: str | None = None + client_secret: str | None = None + refund_id: str | None = None + raw: dict[str, Any] | None = None diff --git a/tests/test_base_connector_manifest.py b/tests/test_base_connector_manifest.py index 08314e7..bb77caa 100644 --- a/tests/test_base_connector_manifest.py +++ b/tests/test_base_connector_manifest.py @@ -39,14 +39,14 @@ def test_manifest_emits_per_action(): rest_actions = {(e["connector_id"], e["action"]) for e in rest_manifest} assert ("google_drive", "files.list") in rest_actions assert ("fhir_epic", "read_patient") in rest_actions - assert ("stripe", "charge") not in rest_actions # stripe is grpc/mcp only in config + assert ("stripe", "charge") in rest_actions mcp_manifest = build_manifest(factory.list_for_protocol("mcp")) mcp_actions = {(e["connector_id"], e["action"]) for e in mcp_manifest} assert ("stripe", "charge") in mcp_actions # Per-action input schema should expose that action's fields (not only a buried union) for entry in mcp_manifest: - if entry["connector_id"] == "stripe": + if entry["connector_id"] == "stripe" and entry["action"] == "charge": props = entry["input_schema"].get("properties", {}) assert "amount" in props diff --git a/tests/test_connectors_basic.py b/tests/test_connectors_basic.py index 7ab4dfa..e6a744a 100644 --- a/tests/test_connectors_basic.py +++ b/tests/test_connectors_basic.py @@ -41,4 +41,4 @@ def test_smtp_connector_instantiation_only(): def test_stripe_connector_instantiation_only(): connector = StripeConnector(secret_provider=DummySecretProvider()) assert connector.connector_id == "stripe" - assert connector.action == "charge" + assert connector.action == "execute" diff --git a/tests/test_connectors_io.py b/tests/test_connectors_io.py index a945d16..a788ba4 100644 --- a/tests/test_connectors_io.py +++ b/tests/test_connectors_io.py @@ -135,7 +135,7 @@ def test_stripe_charge_via_run() -> None: secrets = _MapSecrets({"stripe_api_key": "sk_test_dummy"}) with patch("node_wire_stripe.logic.stripe.Charge") as mock_charge: - mock_charge.create.return_value = {"id": "ch_123", "receipt_url": "https://pay.example/r"} + mock_charge.create.return_value = MagicMock(id="ch_123", receipt_url="https://pay.example/r", paid=True) c = StripeConnector(secret_provider=secrets) async def _run() -> None: diff --git a/tests/test_factory_and_rest.py b/tests/test_factory_and_rest.py index b567d60..5de59d3 100644 --- a/tests/test_factory_and_rest.py +++ b/tests/test_factory_and_rest.py @@ -19,7 +19,7 @@ def test_factory_loads_config(): assert http_connector is not None stripe_rest = factory.get_for_protocol("stripe", "rest") - assert stripe_rest is None # stripe not exposed via REST per config + assert stripe_rest is not None # stripe exposed via REST def test_health_endpoint(): diff --git a/tests/test_stripe.py b/tests/test_stripe.py new file mode 100644 index 0000000..90060a2 --- /dev/null +++ b/tests/test_stripe.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from node_wire_runtime import SecretProvider +from node_wire_stripe.logic import StripeConnector +from node_wire_stripe.schema import ( + CancelSubscriptionInput, + ChargeInput, + CreatePaymentIntentInput, + CreateSubscriptionInput, + IssueRefundInput, + StripeOperationOutput, +) + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + +class MockSecretProvider(SecretProvider): + def get_secret(self, key: str) -> str: + return { + "stripe_api_key": "sk_test_mock", + }[key] + + +def _connector() -> StripeConnector: + """Return a StripeConnector with mock secrets.""" + return StripeConnector(secret_provider=MockSecretProvider()) + + +# --------------------------------------------------------------------------- +# Charge +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_stripe_charge_happy_path(): + connector = _connector() + params = ChargeInput(amount=1000, currency="usd", source="tok_visa") + + mock_charge = MagicMock(id="ch_123", receipt_url="http://stripe.com/receipt", paid=True) + + with patch("stripe.Charge.create", return_value=mock_charge) as mock_create: + result = await connector.charge(params, trace_id="test-trace") + + assert result.charge_id == "ch_123" + assert result.receipt_url == "http://stripe.com/receipt" + assert result.status == "succeeded" + mock_create.assert_called_once_with( + api_key="sk_test_mock", + amount=1000, + currency="usd", + source="tok_visa", + customer=None, + description=None, + metadata=None, + idempotency_key="test-trace", + ) + + +# --------------------------------------------------------------------------- +# Create Payment Intent +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_stripe_create_payment_intent_happy_path(): + connector = _connector() + params = CreatePaymentIntentInput(amount=2000, currency="eur", confirm=True) + + mock_pi = MagicMock(id="pi_123", client_secret="secret_abc", status="requires_payment_method") + + with patch("stripe.PaymentIntent.create", return_value=mock_pi) as mock_create: + result = await connector.create_payment_intent(params, trace_id="test-trace") + + assert result.payment_intent_id == "pi_123" + assert result.client_secret == "secret_abc" + assert result.status == "requires_payment_method" + mock_create.assert_called_once_with( + api_key="sk_test_mock", + amount=2000, + currency="eur", + customer=None, + payment_method=None, + confirm=True, + description=None, + metadata=None, + idempotency_key="test-trace", + ) + + +# --------------------------------------------------------------------------- +# Create Subscription +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_stripe_create_subscription_with_card_token(): + connector = _connector() + params = CreateSubscriptionInput(customer_id="cus_123", price_id="price_abc", card_token="tok_visa") + + mock_pm = MagicMock(id="pm_123") + mock_sub = MagicMock(id="sub_123", status="active", pending_setup_intent=None, latest_invoice=None) + + with patch("stripe.PaymentMethod.create", return_value=mock_pm) as mock_pm_create, \ + patch("stripe.PaymentMethod.attach") as mock_pm_attach, \ + patch("stripe.Subscription.create", return_value=mock_sub) as mock_sub_create: + + result = await connector.create_subscription(params, trace_id="test-trace") + + assert result.subscription_id == "sub_123" + assert result.status == "active" + + mock_pm_create.assert_called_once() + mock_pm_attach.assert_called_once_with("pm_123", api_key="sk_test_mock", customer="cus_123") + mock_sub_create.assert_called_once() + assert mock_sub_create.call_args.kwargs["default_payment_method"] == "pm_123" + assert mock_sub_create.call_args.kwargs["idempotency_key"] == "test-trace" + + +# --------------------------------------------------------------------------- +# Cancel Subscription +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_stripe_cancel_subscription_immediate(): + connector = _connector() + params = CancelSubscriptionInput(subscription_id="sub_123", cancel_at_period_end=False) + + mock_sub = MagicMock(id="sub_123", status="canceled") + + with patch("stripe.Subscription.cancel", return_value=mock_sub) as mock_cancel: + result = await connector.cancel_subscription(params, trace_id="test-trace") + + assert result.subscription_id == "sub_123" + assert result.status == "canceled" + mock_cancel.assert_called_once_with("sub_123", api_key="sk_test_mock", idempotency_key="test-trace") + + +# --------------------------------------------------------------------------- +# Issue Refund +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_stripe_issue_refund_happy_path(): + connector = _connector() + params = IssueRefundInput(payment_intent_id="pi_123", amount=500) + + mock_refund = MagicMock(id="re_123", status="succeeded") + + with patch("stripe.Refund.create", return_value=mock_refund) as mock_refund_create: + result = await connector.issue_refund(params, trace_id="test-trace") + + assert result.refund_id == "re_123" + assert result.status == "succeeded" + mock_refund_create.assert_called_once_with( + api_key="sk_test_mock", + charge=None, + payment_intent="pi_123", + amount=500, + reason=None, + metadata=None, + idempotency_key="test-trace", + ) + + +# --------------------------------------------------------------------------- +# Schema Validation +# --------------------------------------------------------------------------- + +def test_stripe_schema_validation_bounds(): + """Verify that amount and currency bounds are enforced.""" + # Valid + ChargeInput(amount=1, currency="usd", source="tok_visa") + + # Invalid amount (too small) + with pytest.raises(ValidationError): + ChargeInput(amount=0, currency="usd", source="tok_visa") + + # Invalid currency (wrong length/format) + with pytest.raises(ValidationError): + ChargeInput(amount=100, currency="us", source="tok_visa") + + with pytest.raises(ValidationError): + ChargeInput(amount=100, currency="USDT", source="tok_visa") + + +# --------------------------------------------------------------------------- +# Error Mapping +# --------------------------------------------------------------------------- + +def test_stripe_error_mapping(): + """Verify that Stripe exceptions are correctly mapped to ErrorCategory.""" + import stripe + connector = _connector() + from node_wire_runtime.models import ErrorCategory + + # Check specific mappings from StripeConnector.error_map + assert connector.error_map[stripe.error.CardError] == (ErrorCategory.BUSINESS, "STRIPE_CARD_ERROR") + assert connector.error_map[stripe.error.RateLimitError] == (ErrorCategory.RETRYABLE, "STRIPE_RATE_LIMIT") + assert connector.error_map[stripe.error.AuthenticationError] == (ErrorCategory.AUTH, "STRIPE_AUTH_ERROR") From 7e50d5cfae0e0f309aeea6a46e7ba04db4bc0209 Mon Sep 17 00:00:00 2001 From: kesav-aot Date: Tue, 5 May 2026 05:40:45 -0700 Subject: [PATCH 23/60] Add configurable timeouts, breakers, and sanitizers (#13) * Add configurable timeouts, breakers, and sanitizers Make connector timeouts and resilience configurable and sanitize telemetry. Replace hardcoded httpx/aiosmtplib 30s timeouts with AOT_CONNECTOR_TIMEOUT (default 30.0) across Cerner, Epic, HTTP generic, and SMTP connectors and add needed os imports. Update BaseConnector to use per-tenant CircuitBreaker instances with AOT_CIRCUIT_BREAKER_FAIL_MAX and AOT_CIRCUIT_BREAKER_RESET_TIMEOUT (defaults: 5 and 30) and add audit metadata on policy denials. Add sanitizing wrappers for span and log exporters to redact sensitive attributes before exporting. Also add commented sample.env entries documenting the new environment variables. * Manage tenant circuit breakers; FHIR connector fixes Initialize and use a per-tenant circuit breaker cache in BaseConnector (self._breakers) and guard against a missing attribute by creating it on demand. Use a local breaker_cache reference when creating/looking up tenant-specific CircuitBreaker instances and wrap execution with resilience. Add two tests to verify that tenant breakers are cached and that a missing cache is rebuilt. In FHIR Cerner logic, add FHIR JSON headers and read Cerner-specific secrets (private key, kid, client_id) to prepare for token/JWT construction. In FHIR Epic logic, add a codecs import and ensure the returned headers include the Bearer access token. * Delegate FHIR auth to AuthProvider Replace in-file JWT/token exchange logic for Cerner and Epic connectors with calls to the injected AuthProvider (get_auth_headers()). This removes duplicated client_assertion/JWT construction and HTTP token exchange code and lets providers handle token acquisition, scope resolution and caching. Also ensure FHIR Content-Type/Accept headers are present when providers omit them. Cerner-specific: guard retrieval of cerner_token_url and raise a clear ValueError if the URL contains the known malformed '/hosts/' sandbox path. --- sample.env | 5 ++ src/node_wire_fhir_cerner/logic.py | 47 ++++++++++++------- src/node_wire_fhir_epic/logic.py | 40 ++++++++++------ src/node_wire_http_generic/logic.py | 5 +- src/node_wire_runtime/base_connector.py | 26 ++++++++++- src/node_wire_runtime/observability.py | 62 ++++++++++++++++++++++--- src/node_wire_smtp/logic.py | 3 +- tests/test_aot_runtime_basic.py | 25 ++++++++++ 8 files changed, 172 insertions(+), 41 deletions(-) diff --git a/sample.env b/sample.env index beeb748..0cdbc72 100644 --- a/sample.env +++ b/sample.env @@ -93,4 +93,9 @@ NW_REST_LOAD_DOTENV=true # MCP contract (optional; Google Drive legacy payload `action: "upload"`) # NODE_WIRE_LEGACY_GDRIVE_ACTION_UPLOAD=warn # NODE_WIRE_LEGACY_GDRIVE_ACTION_UPLOAD=reject + +# Resilience & Timeout Configurations +# AOT_CONNECTOR_TIMEOUT=30.0 +# AOT_CIRCUIT_BREAKER_FAIL_MAX=5 +# AOT_CIRCUIT_BREAKER_RESET_TIMEOUT=30 NW_ALLOWED_CONNECTORS=fhir_cerner,fhir_epic,google_drive,http_generic,smtp,stripe diff --git a/src/node_wire_fhir_cerner/logic.py b/src/node_wire_fhir_cerner/logic.py index 7ac758c..58054c9 100644 --- a/src/node_wire_fhir_cerner/logic.py +++ b/src/node_wire_fhir_cerner/logic.py @@ -4,6 +4,8 @@ import base64 import json import logging +import os +import uuid from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional @@ -111,19 +113,32 @@ async def _get_auth_header(self) -> Dict[str, str]: Returns ready-to-use FHIR request headers including the Bearer token. Token acquisition, JWT construction, scope resolution and caching are - all handled by the provider — no duplication with fhir_epic. + all handled by the provider. """ # Cerner-specific safety check: if a token URL contains '/hosts/', # it is often a malformed sandbox URL that will return 401. - token_url = self._secret_provider.get_secret("cerner_token_url") - if "/hosts/" in token_url: + try: + token_url = self._secret_provider.get_secret("cerner_token_url") + except Exception: + token_url = None + + if token_url and "/hosts/" in token_url: raise ValueError( "Cerner token_url must not contain '/hosts/' (found in secret). " "Ensure you are using the 'smart-v1/token' endpoint, e.g. " "https://authorization.cerner.com/tenants/{tenant}/protocols/oauth2/profiles/smart-v1/token" ) - return await self.get_auth_headers() + + headers = await self.get_auth_headers() + # Ensure FHIR content types are present if the provider didn't include them (e.g. StaticTokenAuthProvider). + if "Content-Type" not in headers: + headers["Content-Type"] = "application/fhir+json" + if "Accept" not in headers: + headers["Accept"] = "application/fhir+json" + + return headers + # ------------------------------------------------------------------ # Internal name-field helpers @@ -202,8 +217,8 @@ async def _read_patient( ) try: - async with httpx.AsyncClient() as client: - response = await client.get(url, headers=auth_header, params=query_params, timeout=30.0) + async with httpx.AsyncClient(timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0"))) as client: + response = await client.get(url, headers=auth_header, params=query_params, timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0"))) response.raise_for_status() except Exception as exc: logger.error("FHIR Patient read failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) @@ -246,11 +261,11 @@ async def _search_patients( async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[str]]: """Return (rid, resource_or_None, error_or_None).""" try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0"))) as client: resp = await client.get( f"{base_url}/Patient/{rid}", headers=auth_header, - timeout=30.0, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) resp.raise_for_status() return rid, resp.json(), None @@ -297,12 +312,12 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ ) try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0"))) as client: response = await client.get( f"{base_url}/Patient", headers=auth_header, params=name_params, - timeout=30.0, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) response.raise_for_status() except httpx.HTTPStatusError as exc: @@ -358,9 +373,9 @@ async def _search_encounter( auth_header = await self._get_auth_header() try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0"))) as client: response = await client.get( - f"{base_url}/Encounter", headers=auth_header, params=query_params, timeout=30.0, + f"{base_url}/Encounter", headers=auth_header, params=query_params, timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) response.raise_for_status() except httpx.HTTPStatusError as exc: @@ -506,9 +521,9 @@ async def _create_document_reference( logger.info("FHIR DocumentReference create", extra={"trace_id": trace_id}) try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0"))) as client: response = await client.post( - f"{base_url}/DocumentReference", json=doc_ref, headers=auth_header, timeout=30.0, + f"{base_url}/DocumentReference", json=doc_ref, headers=auth_header, timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) response.raise_for_status() except httpx.HTTPStatusError as exc: @@ -578,9 +593,9 @@ async def _search_document_reference( logger.info("FHIR DocumentReference search", extra={"trace_id": trace_id, "search_params": params.search_params}) try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0"))) as client: response = await client.get( - f"{base_url}/DocumentReference", headers=auth_header, params=params.search_params, timeout=30.0, + f"{base_url}/DocumentReference", headers=auth_header, params=params.search_params, timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) response.raise_for_status() except httpx.HTTPStatusError as exc: diff --git a/src/node_wire_fhir_epic/logic.py b/src/node_wire_fhir_epic/logic.py index 94e452d..9132acb 100644 --- a/src/node_wire_fhir_epic/logic.py +++ b/src/node_wire_fhir_epic/logic.py @@ -1,7 +1,10 @@ from __future__ import annotations import asyncio +import codecs import logging +import os +import uuid from datetime import datetime, timezone from typing import Any, Dict, List, Optional @@ -103,9 +106,18 @@ async def _get_auth_header(self) -> Dict[str, str]: """Delegate to the runtime AuthProvider injected by the factory. Returns ready-to-use FHIR request headers including the Bearer token. - Token acquisition and caching are handled by the provider. + Token acquisition, JWT construction, scope resolution and caching are + all handled by the provider. """ - return await self.get_auth_headers() + headers = await self.get_auth_headers() + # Ensure FHIR content types are present if the provider didn't include them (e.g. StaticTokenAuthProvider). + if "Content-Type" not in headers: + headers["Content-Type"] = "application/fhir+json" + if "Accept" not in headers: + headers["Accept"] = "application/fhir+json" + + return headers + @staticmethod def _build_name_search_params( @@ -186,9 +198,9 @@ async def _read_patient( ) try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0"))) as client: response = await client.get( - url, headers=auth_header, params=query_params, timeout=30.0 + url, headers=auth_header, params=query_params, timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")) ) response.raise_for_status() except Exception as exc: @@ -234,11 +246,11 @@ async def _search_patients( async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[str]]: try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0"))) as client: resp = await client.get( f"{base_url}/Patient/{rid}", headers=auth_header, - timeout=30.0, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) resp.raise_for_status() return rid, resp.json(), None @@ -291,12 +303,12 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ ) try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0"))) as client: response = await client.get( f"{base_url}/Patient", headers=auth_header, params=name_params, - timeout=30.0, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) response.raise_for_status() except httpx.HTTPStatusError as exc: @@ -357,12 +369,12 @@ async def _search_encounter( auth_header = await self._get_auth_header() try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0"))) as client: response = await client.get( f"{base_url}/Encounter", headers=auth_header, params=query_params, - timeout=30.0, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) response.raise_for_status() except httpx.HTTPStatusError as exc: @@ -431,12 +443,12 @@ async def _create_document_reference( logger.info("FHIR DocumentReference create", extra={"trace_id": trace_id}) try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0"))) as client: response = await client.post( f"{base_url}/DocumentReference", json=doc_ref, headers=auth_header, - timeout=30.0, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) response.raise_for_status() except httpx.HTTPStatusError as exc: @@ -516,12 +528,12 @@ async def _search_document_reference( ) try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0"))) as client: response = await client.get( f"{base_url}/DocumentReference", headers=auth_header, params=params.search_params, - timeout=30.0, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) response.raise_for_status() except httpx.HTTPStatusError as exc: diff --git a/src/node_wire_http_generic/logic.py b/src/node_wire_http_generic/logic.py index 1df5c7d..922dc98 100644 --- a/src/node_wire_http_generic/logic.py +++ b/src/node_wire_http_generic/logic.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import os from typing import Any import httpx @@ -40,7 +41,7 @@ async def request(self, params: HttpRequestInput, *, trace_id: str) -> HttpRespo ) try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0"))) as client: response = await client.request( method=params.method, url=str(params.url), @@ -48,7 +49,7 @@ async def request(self, params: HttpRequestInput, *, trace_id: str) -> HttpRespo params=params.params, json=params.body if isinstance(params.body, (dict, list)) else None, content=None if isinstance(params.body, (dict, list)) else params.body, - timeout=30.0, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) except Exception as exc: # noqa: BLE001 # Let ErrorMapper classify the exception, but log clear context here. diff --git a/src/node_wire_runtime/base_connector.py b/src/node_wire_runtime/base_connector.py index f57a9f4..d3f8f9c 100644 --- a/src/node_wire_runtime/base_connector.py +++ b/src/node_wire_runtime/base_connector.py @@ -2,6 +2,7 @@ import inspect import logging +import os import uuid from abc import ABC from dataclasses import dataclass @@ -300,6 +301,7 @@ def __init__( reset_timeout=30, name=f"{cls.__name__}_breaker", ) + self._breakers: Dict[str, CircuitBreaker] = {} self._client: Any = None @property @@ -411,13 +413,17 @@ async def run( self._policy_hook.check(context) except PolicyDenied as exc: logger.warning( - "Execution blocked by policy hook", + "AUDIT: Execution blocked by policy hook", extra={ "trace_id": trace_id, "connector_id": self.connector_id, "action": self.action, "error_type": type(exc).__name__, "error_message": str(exc), + "audit": True, + "audit_event": "policy_denial", + "tenant_id": tenant_id, + "principal": principal, }, ) mapped = ErrorMapper.resolve(exc) @@ -429,7 +435,23 @@ async def run( trace_id=trace_id, ) - execute_with_resilience = with_resilience(self._breaker) + tenant_key = tenant_id or "default" + breaker_cache = getattr(self, "_breakers", None) + if breaker_cache is None: + breaker_cache = {} + self._breakers = breaker_cache + + if tenant_key not in breaker_cache: + fail_max = int(os.environ.get("AOT_CIRCUIT_BREAKER_FAIL_MAX", "5")) + reset_timeout = int(os.environ.get("AOT_CIRCUIT_BREAKER_RESET_TIMEOUT", "30")) + breaker_cache[tenant_key] = CircuitBreaker( + fail_max=fail_max, + reset_timeout=reset_timeout, + name=f"{self.connector_id}_breaker_{tenant_key}", + ) + + breaker = breaker_cache[tenant_key] + execute_with_resilience = with_resilience(breaker) @execute_with_resilience async def _do_execute(*, trace_id: str) -> Any: diff --git a/src/node_wire_runtime/observability.py b/src/node_wire_runtime/observability.py index e5f6eff..0c8ab57 100644 --- a/src/node_wire_runtime/observability.py +++ b/src/node_wire_runtime/observability.py @@ -10,9 +10,9 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.resources import Resource from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler -from opentelemetry.sdk._logs.export import BatchLogRecordProcessor +from opentelemetry.sdk._logs.export import BatchLogRecordProcessor, LogExporter from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter from opentelemetry.sdk.trace.sampling import ParentBased, TraceIdRatioBased logger = logging.getLogger("runtime.observability") @@ -33,6 +33,56 @@ def filter(self, record: logging.LogRecord) -> bool: # noqa: A003 return True +_SENSITIVE_KEYS = {"patient", "ssn", "secret", "password", "email", "phone", "dob", "encounter", "resourceid"} + +def _is_sensitive(key: str) -> bool: + k = key.lower().replace("_", "").replace("-", "").replace(" ", "") + for s in _SENSITIVE_KEYS: + if s in k: + return True + return False + +class SanitizingSpanExporter(SpanExporter): + def __init__(self, delegate: SpanExporter): + self._delegate = delegate + + def export(self, spans): + for span in spans: + if hasattr(span, "_attributes") and span._attributes: + for k in list(span._attributes.keys()): + if _is_sensitive(k): + span._attributes[k] = "***REDACTED***" + return self._delegate.export(spans) + + def shutdown(self): + return self._delegate.shutdown() + + def force_flush(self, timeout_millis: int = 30000): + if hasattr(self._delegate, "force_flush"): + return self._delegate.force_flush(timeout_millis) + return True + +class SanitizingLogExporter(LogExporter): + def __init__(self, delegate: LogExporter): + self._delegate = delegate + + def export(self, batch): + for record in batch: + if hasattr(record, "attributes") and record.attributes: + for k in list(record.attributes.keys()): + if _is_sensitive(k): + record.attributes[k] = "***REDACTED***" + return self._delegate.export(batch) + + def shutdown(self): + return self._delegate.shutdown() + + def force_flush(self, timeout_millis: int = 30000): + if hasattr(self._delegate, "force_flush"): + return self._delegate.force_flush(timeout_millis) + return True + + def init_observability(app_name: str = "node_wire") -> None: """ Initialize OpenTelemetry + OpenLLMetry/Traceloop for the process. @@ -67,11 +117,11 @@ def init_observability(app_name: str = "node_wire") -> None: otlp_headers: Optional[str] = os.getenv("OTEL_EXPORTER_OTLP_HEADERS") - span_exporter = OTLPSpanExporter( + span_exporter = SanitizingSpanExporter(OTLPSpanExporter( headers=dict( header.split("=", 1) for header in otlp_headers.split(",") ) if otlp_headers else None, - ) + )) span_processor = BatchSpanProcessor(span_exporter) tracer_provider.add_span_processor(span_processor) @@ -79,11 +129,11 @@ def init_observability(app_name: str = "node_wire") -> None: # Logs: export Python logging records via OTLP/HTTP to the local collector. # This enables Loki ingestion when using grafana/otel-lgtm. - log_exporter = OTLPLogExporter( + log_exporter = SanitizingLogExporter(OTLPLogExporter( headers=dict( header.split("=", 1) for header in otlp_headers.split(",") ) if otlp_headers else None, - ) + )) logger_provider = LoggerProvider(resource=resource) logger_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter)) set_logger_provider(logger_provider) diff --git a/src/node_wire_smtp/logic.py b/src/node_wire_smtp/logic.py index 9bbf150..888a66f 100644 --- a/src/node_wire_smtp/logic.py +++ b/src/node_wire_smtp/logic.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import os from email.message import EmailMessage import aiosmtplib @@ -74,7 +75,7 @@ async def send_email(self, params: SmtpSendInput, *, trace_id: str) -> SmtpSendO password=password, use_tls=use_implicit, start_tls=params.use_tls and not use_implicit, - timeout=30.0, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) except Exception as exc: # noqa: BLE001 logger.error( diff --git a/tests/test_aot_runtime_basic.py b/tests/test_aot_runtime_basic.py index 1d956ba..8285046 100644 --- a/tests/test_aot_runtime_basic.py +++ b/tests/test_aot_runtime_basic.py @@ -38,6 +38,31 @@ def test_successful_execution(): assert isinstance(response.trace_id, str) +def test_successful_execution_uses_tenant_breaker_cache(): + connector = DoubleConnector() + + response: ConnectorResponse = asyncio.run( + connector.run({"action": "double", "value": 3}, tenant_id="tenant-a") + ) + + assert response.success is True + assert response.data == {"doubled": 6} + assert "tenant-a" in connector._breakers + + +def test_successful_execution_rebuilds_missing_breaker_cache(): + connector = DoubleConnector() + del connector._breakers + + response: ConnectorResponse = asyncio.run( + connector.run({"action": "double", "value": 4}, tenant_id="tenant-b") + ) + + assert response.success is True + assert response.data == {"doubled": 8} + assert "tenant-b" in connector._breakers + + class CustomError(Exception): pass From 470fc9ee4527deaf0490b6d8b403d9cf4c6b90ca Mon Sep 17 00:00:00 2001 From: Rahul Ap Date: Wed, 6 May 2026 13:23:47 +0530 Subject: [PATCH 24/60] CNP 47 migrate salesforce connector to node wire (#36) * feat: add Salesforce connector for managing CRM records - Introduced a new Salesforce connector to handle CRUD operations for Leads and Contacts. - Implemented OAuth2 authentication with refresh token support. - Added Dockerfile for the Salesforce connector. - Updated pyproject.toml and sample.env to include Salesforce configuration. - Enhanced build scripts to include Salesforce image building. - Created comprehensive tests for Salesforce connector actions. - Documented the Salesforce connector capabilities and usage in salesforce_connector.md. * feat: update Salesforce connector to use model_dump and ConfigDict for field population * feat: add public access to global connector registry and update test for Salesforce connector --- config/connectors.yaml | 162 +++++----- docker-compose.mcp.yml | 8 + docker/salesforce/Dockerfile | 35 +++ docs/connectors.md | 4 + docs/mcp-servers.md | 21 ++ docs/salesforce_connector.md | 95 ++++++ packages/connectors/salesforce/pyproject.toml | 22 ++ packages/connectors/salesforce/setup.py | 16 + playground/app.js | 174 +++++++++++ playground/index.html | 112 +++++++ playground/scenarios.py | 277 ++++++++++++++++++ pyproject.toml | 2 + sample.env | 6 + scripts/build-mcp-images.sh | 6 + scripts/build-packages.sh | 2 + src/agents/salesforce_mcp.py | 27 ++ src/bindings/factory.py | 2 + src/node_wire_runtime/auth/oauth2.py | 45 ++- src/node_wire_runtime/base_connector.py | 5 + src/node_wire_salesforce/logic.py | 168 +++++++++++ src/node_wire_salesforce/registration.py | 6 + src/node_wire_salesforce/schema.py | 132 +++++++++ tests/test_connectors_basic.py | 10 + tests/test_salesforce.py | 251 ++++++++++++++++ 24 files changed, 1511 insertions(+), 77 deletions(-) create mode 100644 docker/salesforce/Dockerfile create mode 100644 docs/salesforce_connector.md create mode 100644 packages/connectors/salesforce/pyproject.toml create mode 100644 packages/connectors/salesforce/setup.py create mode 100644 src/agents/salesforce_mcp.py create mode 100644 src/node_wire_salesforce/logic.py create mode 100644 src/node_wire_salesforce/registration.py create mode 100644 src/node_wire_salesforce/schema.py create mode 100644 tests/test_salesforce.py diff --git a/config/connectors.yaml b/config/connectors.yaml index 2d0f069..7e5940f 100644 --- a/config/connectors.yaml +++ b/config/connectors.yaml @@ -1,75 +1,87 @@ -# connectors.yaml — Node Wire connector configuration -# -# REST API auth (not stored here; set in environment): -# NW_REST_API_KEY — required for /connectors, /playground, /scenarios unless NW_REST_AUTH_DISABLED=true -# -# SECURITY RULE: This file must never contain secrets. -# - Non-sensitive config (base_url, host, port) → safe in YAML -# - Secrets (client_id, private_key, api_key) → environment variables (or cloud backend) -# -connectors: - http_generic: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - # auth: not set — defaults to NoAuthProvider - - smtp: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - host: "smtp.example.com" - port: 587 - from_email: "noreply@example.com" - auth: - provider: static_credentials - username_secret: SMTP_USERNAME - password_secret: SMTP_PASSWORD - - stripe: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - auth: - provider: static_token - secret_key: stripe_api_key - header_name: Authorization - prefix: "" # Stripe expects raw key - - google_drive: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - auth: - provider: service_account - sa_json_secret: GOOGLE_DRIVE_SA_JSON - scopes: - - https://www.googleapis.com/auth/drive - - fhir_epic: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - base_url: "${EPIC_FHIR_BASE_URL:https://fhir.epic.sandbox/api/FHIR/R4}" - auth: - provider: oauth2 - grant_method: private_key_jwt - token_url_secret: EPIC_TOKEN_URL - client_id_secret: EPIC_CLIENT_ID - private_key_secret: EPIC_PRIVATE_KEY - kid_secret: EPIC_KID - algorithm: RS384 - - fhir_cerner: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - base_url: "${CERNER_FHIR_BASE_URL:https://fhir-ehr-code.cerner.com/r4/your-tenant-id}" - auth: - provider: oauth2 - grant_method: private_key_jwt - token_url_secret: CERNER_TOKEN_URL - client_id_secret: CERNER_CLIENT_ID - private_key_secret: CERNER_PRIVATE_KEY - kid_secret: CERNER_KID - algorithm: RS384 - scopes_secret: CERNER_SCOPES - scopes: - - system/Patient.read - - system/Encounter.read - - system/DocumentReference.read - - system/DocumentReference.write \ No newline at end of file +# connectors.yaml — Node Wire connector configuration +# +# REST API auth (not stored here; set in environment): +# NW_REST_API_KEY — required for /connectors, /playground, /scenarios unless NW_REST_AUTH_DISABLED=true +# +# SECURITY RULE: This file must never contain secrets. +# - Non-sensitive config (base_url, host, port) → safe in YAML +# - Secrets (client_id, private_key, api_key) → environment variables (or cloud backend) +# +connectors: + http_generic: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + # auth: not set — defaults to NoAuthProvider + + smtp: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + host: "smtp.example.com" + port: 587 + from_email: "noreply@example.com" + auth: + provider: static_credentials + username_secret: SMTP_USERNAME + password_secret: SMTP_PASSWORD + + stripe: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + auth: + provider: static_token + secret_key: stripe_api_key + header_name: Authorization + prefix: "" # Stripe expects raw key + + google_drive: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + auth: + provider: service_account + sa_json_secret: GOOGLE_DRIVE_SA_JSON + scopes: + - https://www.googleapis.com/auth/drive + + fhir_epic: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + base_url: "${EPIC_FHIR_BASE_URL:https://fhir.epic.sandbox/api/FHIR/R4}" + auth: + provider: oauth2 + grant_method: private_key_jwt + token_url_secret: EPIC_TOKEN_URL + client_id_secret: EPIC_CLIENT_ID + private_key_secret: EPIC_PRIVATE_KEY + kid_secret: EPIC_KID + algorithm: RS384 + + fhir_cerner: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + base_url: "${CERNER_FHIR_BASE_URL:https://fhir-ehr-code.cerner.com/r4/your-tenant-id}" + auth: + provider: oauth2 + grant_method: private_key_jwt + token_url_secret: CERNER_TOKEN_URL + client_id_secret: CERNER_CLIENT_ID + private_key_secret: CERNER_PRIVATE_KEY + kid_secret: CERNER_KID + algorithm: RS384 + scopes_secret: CERNER_SCOPES + scopes: + - system/Patient.read + - system/Encounter.read + - system/DocumentReference.read + - system/DocumentReference.write + + salesforce: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + # instance_url is typically https://yourdomain.my.salesforce.com + auth: + provider: oauth2 + grant_method: refresh_token + token_url_secret: SALESFORCE_TOKEN_URL + client_id_secret: SALESFORCE_CLIENT_ID + client_secret_secret: SALESFORCE_CLIENT_SECRET + refresh_token_secret: SALESFORCE_REFRESH_TOKEN \ No newline at end of file diff --git a/docker-compose.mcp.yml b/docker-compose.mcp.yml index e4024df..f76b6f5 100644 --- a/docker-compose.mcp.yml +++ b/docker-compose.mcp.yml @@ -33,3 +33,11 @@ services: stdin_open: true tty: true restart: unless-stopped + + nw-salesforce: + image: nw-salesforce:latest + env_file: .env + stdin_open: true + tty: true + restart: unless-stopped + diff --git a/docker/salesforce/Dockerfile b/docker/salesforce/Dockerfile new file mode 100644 index 0000000..e255d6a --- /dev/null +++ b/docker/salesforce/Dockerfile @@ -0,0 +1,35 @@ +FROM python:3.12-slim@sha256:3d5ed973e45820f5ba5e46bd065bd88b3a504ff0724d85980dcd05eab361fcf4 + +LABEL org.opencontainers.image.title="nw-salesforce" \ + org.opencontainers.image.description="Node Wire — Salesforce MCP server" \ + org.opencontainers.image.source="https://github.com/AOT-Technologies/node-wire" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY src/ ./src/ +COPY config/ ./config/ +# Wheels are optional for local dev builds; build-mcp-images.sh populates them +COPY packages/runtime/dist/*.whl /wheels/ +COPY packages/connectors/salesforce/dist/*.whl /wheels/ + +ENV PYTHONPATH=/app/src \ + NW_ALLOWED_CONNECTORS=salesforce + +RUN pip install --no-cache-dir --find-links=/wheels \ + node-wire-runtime node-wire-salesforce "mcp>=1.6.0" httpx \ + || pip install --no-cache-dir "mcp>=1.6.0" httpx # Fallback if wheels missing + +RUN groupadd --system --gid 1000 app \ + && useradd --system --uid 1000 --gid app --home /app app \ + && chown -R app:app /app + +USER app + +HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ + python -c "from agents.salesforce_mcp import main; assert callable(main); print('ok')" || exit 1 + +CMD ["python", "-m", "agents.salesforce_mcp"] diff --git a/docs/connectors.md b/docs/connectors.md index 30b027a..f14bcf7 100644 --- a/docs/connectors.md +++ b/docs/connectors.md @@ -471,7 +471,9 @@ Published **`input_schema` omits the `action` property** (manifest contract v2+) | `http_generic` | `request` | | `smtp` | `send_email` | | `stripe` | `charge` | +| `salesforce` | `create_lead`, `read_lead`, `update_lead`, `delete_lead`, `create_contact`, `read_contact`, `update_contact`, `delete_contact` | | `google_drive` | `files.list`, `files.upload`, … (see `action_specs`) | + | `fhir_epic` | `read_patient`, `search_patients`, `search_encounter`, `create_document_reference`, `search_document_reference` | | `fhir_cerner` | Same family as Epic with Cerner-specific schemas | @@ -530,4 +532,6 @@ connectors: - [packaging.md](packaging.md) — Wheel build lifecycle, PyPI publish flow, client install model, secrets config, and pre-publish checklist. - [mcp-servers.md](mcp-servers.md) — MCP images, ToolHive, env vars. - [google_drive_connector.md](google_drive_connector.md) — Drive REST API and setup. +- [salesforce_connector.md](salesforce_connector.md) — Salesforce CRM operations and playground. - Per-connector READMEs under `src/node_wire_*/README.md` where present. + diff --git a/docs/mcp-servers.md b/docs/mcp-servers.md index d54ffc7..36b731a 100644 --- a/docs/mcp-servers.md +++ b/docs/mcp-servers.md @@ -50,6 +50,8 @@ flowchart TD | SMART on FHIR (Cerner) | `python -m agents.fhir_cerner_mcp` | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner` | All manifest actions for `fhir_cerner` (e.g. `fhir_cerner.read_patient`) | | SMTP | `python -m agents.smtp_mcp` | `nw-smtp` | `nw-smtp` | `smtp.send_email` | | Stripe | `python -m agents.stripe_mcp` | `nw-stripe` | `nw-stripe` | All manifest actions for `stripe` (e.g., `stripe.charge`) | +| Salesforce | `python -m agents.salesforce_mcp` | `nw-salesforce` | `nw-salesforce` | All manifest actions for `salesforce` (e.g., `salesforce.create_lead`) | + The unified server (`python -m agents.mcp_entrypoint`) exposes **every** connector enabled for MCP in `config/connectors.yaml` (e.g. `http_generic.request`, `stripe.charge`, `stripe.create_payment_intent`, `stripe.create_subscription`, `stripe.cancel_subscription`, `stripe.issue_refund`, plus the rows above). @@ -307,6 +309,25 @@ FROM_EMAIL=your-email@gmail.com STRIPE_API_KEY=sk_test_4eC39HqLyjWDarjtT1zdp7dc ``` +#### `nw-salesforce` + +| Variable | Description | +|---|---| +| `SALESFORCE_INSTANCE_URL` | Your Salesforce instance URL (e.g., `https://domain.my.salesforce.com`) | +| `SALESFORCE_TOKEN_URL` | OAuth2 token endpoint (usually `https://login.salesforce.com/services/oauth2/token`) | +| `SALESFORCE_CLIENT_ID` | Connected App Client ID | +| `SALESFORCE_CLIENT_SECRET` | Connected App Client Secret | +| `SALESFORCE_REFRESH_TOKEN` | Refresh token with `refresh_token` and `api` scopes | + +```env +SALESFORCE_INSTANCE_URL=https://nodenet.my.salesforce.com +SALESFORCE_TOKEN_URL=https://login.salesforce.com/services/oauth2/token +SALESFORCE_CLIENT_ID=your-client-id +SALESFORCE_CLIENT_SECRET=your-client-secret +SALESFORCE_REFRESH_TOKEN=your-refresh-token +``` + + ### ToolHive / Agent settings | Variable | Description | diff --git a/docs/salesforce_connector.md b/docs/salesforce_connector.md new file mode 100644 index 0000000..264c107 --- /dev/null +++ b/docs/salesforce_connector.md @@ -0,0 +1,95 @@ +# Salesforce Connector (`src/node_wire_salesforce`) + +The Salesforce connector provides a secure, asynchronous interface for managing CRM records (Leads and Contacts). It leverages Node Wire's `OAuth2AuthProvider` to handle token refresh automatically, allowing for seamless integration into agentic workflows and medical-to-CRM pipelines. + +## Capabilities + +The connector exposes full CRUD (Create, Read, Update, Delete) operations for the two most common Salesforce objects used in healthcare and enterprise outreach: + +| Action | Description | +|---|---| +| `create_lead` | Create a new Lead record. Requires `LastName` and `Company`. | +| `read_lead` | Fetch a single Lead record by ID. | +| `update_lead` | Update specific fields on an existing Lead. | +| `delete_lead` | Remove a Lead record. | +| `create_contact` | Create a new Contact record. Requires `LastName`. | +| `read_contact` | Fetch a single Contact record by ID. | +| `update_contact` | Update specific fields on an existing Contact. | +| `delete_contact` | Remove a Contact record. | + +## Configuration + +Add the following to your `config/connectors.yaml`: + +```yaml +connectors: + salesforce: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + auth: + provider: oauth2 + grant_method: refresh_token + token_url_secret: SALESFORCE_TOKEN_URL + client_id_secret: SALESFORCE_CLIENT_ID + client_secret_secret: SALESFORCE_CLIENT_SECRET + refresh_token_secret: SALESFORCE_REFRESH_TOKEN +``` + +## Environment Variables + +The following secrets must be provided (e.g., in `.env` or via your secret manager): + +| Variable | Example | +|---|---| +| `SALESFORCE_INSTANCE_URL` | `https://your-domain.my.salesforce.com` | +| `SALESFORCE_TOKEN_URL` | `https://login.salesforce.com/services/oauth2/token` | +| `SALESFORCE_CLIENT_ID` | `3MVG9...` | +| `SALESFORCE_CLIENT_SECRET` | `A1B2...` | +| `SALESFORCE_REFRESH_TOKEN` | `5Aep...` | + +## Example Usage + +### REST API + +```bash +curl -X POST http://localhost:8000/connectors/salesforce/create_lead \ + -H "X-API-Key: your-key" \ + -H "Content-Type: application/json" \ + -d '{ + "LastName": "Doe", + "Company": "Acme Corp", + "Email": "john.doe@example.com", + "Status": "Open - Not Contacted" + }' +``` + +### Agentic (MCP) + +If registered via MCP, the agent can call `salesforce.create_lead` with the following arguments: + +```json +{ + "LastName": "Smith", + "Company": "HealthTech", + "Email": "jane@smith.com" +} +``` + +## Playground Interface + +The Node Wire playground includes a **CRM Synchronization** panel specifically for Salesforce. This interface allows you to: + +1. **Toggle between Lead and Contact management**: Use the action dropdown to switch contexts. +2. **Execute full CRUD operations**: The form dynamically adjusts based on whether you are creating, reading, updating, or deleting a record. +3. **Real-time Pipeline Visualization**: Watch the synchronization steps (Authentication → Fetch/Update → Verification) in real-time. +4. **Instant Record Validation**: See the exact Salesforce resource IDs and data returned by the API. + +Access the playground at `http://localhost:8000/playground` (when running locally). + +## Security Note + +- **OAuth2**: Tokens are never stored in plain text in logs. Node Wire's `AuthProvider` handles encryption and secure memory storage. +- **Refresh Token Support**: The connector is configured to use `grant_method: refresh_token`, ensuring it can stay authenticated for long-running agentic tasks. +- **Traceability**: All actions are logged with a `trace_id` for auditing and idempotency tracking. +- **PII Protection**: Ensure your logging levels are set correctly; by default, the connector logs the metadata of the transaction but not the full PII payload. + diff --git a/packages/connectors/salesforce/pyproject.toml b/packages/connectors/salesforce/pyproject.toml new file mode 100644 index 0000000..d13034b --- /dev/null +++ b/packages/connectors/salesforce/pyproject.toml @@ -0,0 +1,22 @@ +[project] +name = "node-wire-salesforce" +version = "0.1.0" +description = "Node Wire connector — Salesforce CRM" +requires-python = ">=3.11" +authors = [{ name = "AOT", email = "dev@aot.local" }] + +dependencies = [ + "node-wire-runtime>=0.1.0", + "httpx>=0.27.0", +] + +[project.entry-points."node_wire.connectors"] +salesforce = "node_wire_salesforce.logic" + +[build-system] +requires = ["setuptools>=69.0.0", "cython>=3.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["../../../src"] +include = ["node_wire_salesforce*"] diff --git a/packages/connectors/salesforce/setup.py b/packages/connectors/salesforce/setup.py new file mode 100644 index 0000000..49a3fbf --- /dev/null +++ b/packages/connectors/salesforce/setup.py @@ -0,0 +1,16 @@ +import glob, os +from Cython.Build import cythonize +from setuptools import setup +from setuptools.command.build_py import build_py as _BuildPy + +class NoPyBuild(_BuildPy): + def find_package_modules(self, package, package_dir): + return [] + +src_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../src/node_wire_salesforce")) +py_files = glob.glob(os.path.join(src_root, "**", "*.py"), recursive=True) + +setup( + cmdclass={"build_py": NoPyBuild}, + ext_modules=cythonize(py_files, compiler_directives={"language_level": "3"}, build_dir="build"), +) diff --git a/playground/app.js b/playground/app.js index 4c17acf..87cc564 100644 --- a/playground/app.js +++ b/playground/app.js @@ -64,10 +64,31 @@ document.addEventListener('DOMContentLoaded', () => { cancel_subscription: document.getElementById('stripe-section-cancel'), refund: document.getElementById('stripe-section-refund') }; + + const salesforceForm = document.getElementById('salesforce-form'); + const salesforceRunBtn = document.getElementById('salesforce-run-btn'); + const salesforceSpinner = salesforceRunBtn.querySelector('.loading-spinner'); + const salesforceBtnText = salesforceRunBtn.querySelector('.btn-lbl'); + const salesforcePanel = document.getElementById('salesforce-panel'); + const salesforceActionSelect = document.getElementById('salesforce-action-select'); + const salesforceSections = { + create_lead: document.getElementById('salesforce-section-lead'), + update_lead: document.getElementById('salesforce-section-lead'), + create_contact: document.getElementById('salesforce-section-contact'), + update_contact: document.getElementById('salesforce-section-contact'), + read_lead: document.getElementById('salesforce-section-id-only'), + delete_lead: document.getElementById('salesforce-section-id-only'), + read_contact: document.getElementById('salesforce-section-id-only'), + delete_contact: document.getElementById('salesforce-section-id-only') + }; + + let currentSubMode = 'file'; let currentStripeSubMode = 'charge'; + let currentSalesforceSubMode = 'create_lead'; const connectorStatus = document.getElementById('connector-status'); + const brandLabel = document.querySelector('.brand-text h1 span.accent'); const tagline = document.querySelector('.tagline'); const layoutMain = document.querySelector('.layout-main'); @@ -145,7 +166,34 @@ document.addEventListener('DOMContentLoaded', () => { "Validate Charge", "Process Refund", "Verify Refund" + ], + salesforce_create_lead: [ + "Initialize CRM Sync", + "Create Lead Record", + "Verify Lead Status" + ], + salesforce_create_contact: [ + "Initialize CRM Sync", + "Create Contact Record", + "Verify Contact Status" + ], + salesforce_read: [ + "Authenticate CRM", + "Fetch Record Metadata", + "Verify Data Integrity" + ], + salesforce_update: [ + "Authenticate CRM", + "Apply Partial Update", + "Verify State Change" + ], + salesforce_delete: [ + "Authenticate CRM", + "Execute Soft Delete", + "Verify Termination" ] + + }; const nodes = [ @@ -381,6 +429,48 @@ document.addEventListener('DOMContentLoaded', () => { return pipelineLabels.stripe_charge; } + function salesforcePipelineLabelOverride() { + if (currentSalesforceSubMode.startsWith('create')) return pipelineLabels.salesforce_create_lead; + if (currentSalesforceSubMode.startsWith('read')) return pipelineLabels.salesforce_read; + if (currentSalesforceSubMode.startsWith('update')) return pipelineLabels.salesforce_update; + if (currentSalesforceSubMode.startsWith('delete')) return pipelineLabels.salesforce_delete; + return pipelineLabels.salesforce_create_lead; + } + + function syncSalesforceActionForm() { + Object.values(salesforceSections).forEach(sec => { + if (sec) sec.classList.add('hidden'); + }); + const activeSec = salesforceSections[currentSalesforceSubMode] || salesforceSections['create_lead']; + if (activeSec) activeSec.classList.remove('hidden'); + + // Handle record ID field visibility in Lead/Contact sections + const idFields = document.querySelectorAll('#salesforce-form .id-field'); + idFields.forEach(f => { + if (currentSalesforceSubMode.startsWith('update')) { + f.classList.remove('hidden'); + } else { + f.classList.add('hidden'); + } + }); + + // Handle generic ID label for read/delete + const idLabel = document.getElementById('sf-resource-id-label'); + if (idLabel) { + if (currentSalesforceSubMode.includes('lead')) { + idLabel.textContent = 'Lead Record ID'; + } else { + idLabel.textContent = 'Contact Record ID'; + } + } + + if (salesforceActionSelect) { + salesforceActionSelect.value = currentSalesforceSubMode; + } + } + + + function syncStripeActionForm() { Object.values(stripeSections).forEach(sec => { if (sec) sec.classList.add('hidden'); @@ -402,8 +492,10 @@ document.addEventListener('DOMContentLoaded', () => { cernerPanel.classList.add('hidden'); gdrivePanel.classList.add('hidden'); stripePanel.classList.add('hidden'); + salesforcePanel.classList.add('hidden'); if (mode === 'ehr') { + ehrPanel.classList.remove('hidden'); connectorStatus.textContent = 'Epic R4 Online'; tagline.textContent = 'Enterprise EHR Orchestration'; @@ -433,6 +525,12 @@ document.addEventListener('DOMContentLoaded', () => { tagline.textContent = 'Financial Infrastructure'; document.documentElement.style.setProperty('--brand-accent', '#635bff'); log('Switched to Stripe Payment Orchestration mode', 'system'); + } else if (mode === 'salesforce') { + salesforcePanel.classList.remove('hidden'); + connectorStatus.textContent = 'Salesforce Online'; + tagline.textContent = 'CRM Orchestration'; + document.documentElement.style.setProperty('--brand-accent', '#00A1E0'); + log('Switched to Salesforce CRM Orchestration mode', 'system'); } if (mode === 'gdrive') { syncGdriveActionForm(); @@ -440,9 +538,13 @@ document.addEventListener('DOMContentLoaded', () => { } else if (mode === 'stripe') { syncStripeActionForm(); resetUI(stripePipelineLabelOverride()); + } else if (mode === 'salesforce') { + syncSalesforceActionForm(); + resetUI(salesforcePipelineLabelOverride()); } else { resetUI(); } + } // Root Tab Switching (MCP Orchestration vs Connectors) @@ -766,6 +868,78 @@ document.addEventListener('DOMContentLoaded', () => { await handleSubmission(submitPayload, endpoint, stripeRunBtn, stripeBtnText, stripeSpinner, 'Process Action'); }); + salesforceForm.addEventListener('submit', async (e) => { + e.preventDefault(); + const formData = new FormData(salesforceForm); + const payload = Object.fromEntries(formData.entries()); + + let endpoint = '/scenarios/salesforce-create-lead'; + let submitPayload = {}; + + if (currentSalesforceSubMode === 'create_lead') { + submitPayload = { + first_name: payload.lead_first_name || undefined, + last_name: payload.lead_last_name, + company: payload.lead_company, + email: payload.lead_email || undefined + }; + endpoint = '/scenarios/salesforce-create-lead'; + } else if (currentSalesforceSubMode === 'update_lead') { + submitPayload = { + record_id: payload.lead_id, + first_name: payload.lead_first_name || undefined, + last_name: payload.lead_last_name || undefined, + company: payload.lead_company || undefined, + email: payload.lead_email || undefined + }; + endpoint = '/scenarios/salesforce-update-lead'; + } else if (currentSalesforceSubMode === 'read_lead') { + submitPayload = { record_id: payload.generic_record_id }; + endpoint = '/scenarios/salesforce-read-lead'; + } else if (currentSalesforceSubMode === 'delete_lead') { + submitPayload = { record_id: payload.generic_record_id }; + endpoint = '/scenarios/salesforce-delete-lead'; + } else if (currentSalesforceSubMode === 'create_contact') { + submitPayload = { + first_name: payload.contact_first_name || undefined, + last_name: payload.contact_last_name, + email: payload.contact_email || undefined, + account_id: payload.contact_account_id || undefined + }; + endpoint = '/scenarios/salesforce-create-contact'; + } else if (currentSalesforceSubMode === 'update_contact') { + submitPayload = { + record_id: payload.contact_id, + first_name: payload.contact_first_name || undefined, + last_name: payload.contact_last_name || undefined, + email: payload.contact_email || undefined, + account_id: payload.contact_account_id || undefined + }; + endpoint = '/scenarios/salesforce-update-contact'; + } else if (currentSalesforceSubMode === 'read_contact') { + submitPayload = { record_id: payload.generic_record_id }; + endpoint = '/scenarios/salesforce-read-contact'; + } else if (currentSalesforceSubMode === 'delete_contact') { + submitPayload = { record_id: payload.generic_record_id }; + endpoint = '/scenarios/salesforce-delete-contact'; + } + + await handleSubmission(submitPayload, endpoint, salesforceRunBtn, salesforceBtnText, salesforceSpinner, 'Execute Action'); + }); + + + if (salesforceActionSelect) { + salesforceActionSelect.addEventListener('change', (e) => { + const mode = e.target.value; + if (mode === currentSalesforceSubMode) return; + currentSalesforceSubMode = mode; + syncSalesforceActionForm(); + resetUI(salesforcePipelineLabelOverride()); + log(`Switched to Salesforce mode [${currentSalesforceSubMode}]`); + }); + } + + if (stripeActionSelect) { stripeActionSelect.addEventListener('change', (e) => { const mode = e.target.value; diff --git a/playground/index.html b/playground/index.html index 9d2f331..3d39c0f 100644 --- a/playground/index.html +++ b/playground/index.html @@ -218,6 +218,20 @@

Stripe

Financial transaction and subscription management infrastructure.

+ +
+
+ + + + +
+
+

Salesforce

+

Lead and contact management for CRM-driven enterprise workflows.

+
+
+ @@ -627,8 +641,106 @@ + + +

Smart Pipeline

diff --git a/playground/scenarios.py b/playground/scenarios.py index 558c72c..5f69bd7 100644 --- a/playground/scenarios.py +++ b/playground/scenarios.py @@ -42,6 +42,20 @@ FilesUpdateOperation, ) from node_wire_stripe.schema import ChargeInput +from node_wire_salesforce.logic import SalesforceConnector +from node_wire_salesforce.schema import ( + CreateLeadInput, + ReadLeadInput, + UpdateLeadInput, + DeleteLeadInput, + CreateContactInput, + ReadContactInput, + UpdateContactInput, + DeleteContactInput, + SalesforceOperationOutput, +) + + logger = logging.getLogger("playground.scenarios") router = APIRouter(prefix="/scenarios", tags=["scenarios"]) @@ -144,6 +158,36 @@ def require_upload_fields_when_not_list(self) -> "GoogleDriveArchivalInput": raise ValueError("document_name and recipient_email are required for archival upload actions") return self +class SalesforceLeadInputPlayground(BaseModel): + last_name: str + company: str + first_name: Optional[str] = None + email: Optional[str] = None + status: str = "Open - Not Contacted" + +class SalesforceContactInputPlayground(BaseModel): + last_name: str + first_name: Optional[str] = None + email: Optional[str] = None + account_id: Optional[str] = None + +class SalesforceGenericIdInputPlayground(BaseModel): + record_id: str + +class SalesforceUpdateLeadInputPlayground(BaseModel): + record_id: str + first_name: Optional[str] = None + last_name: Optional[str] = None + company: Optional[str] = None + email: Optional[str] = None + +class SalesforceUpdateContactInputPlayground(BaseModel): + record_id: str + first_name: Optional[str] = None + last_name: Optional[str] = None + email: Optional[str] = None + account_id: Optional[str] = None + class ScenarioStep(BaseModel): name: str status: str # "pending", "success", "error" @@ -273,6 +317,14 @@ def get_stripe_connector(): return connector +def get_salesforce_connector(): + connector = resolve_connector("salesforce") + if not connector: + raise HTTPException(status_code=500, detail="Salesforce connector not configured") + return connector + + + @router.post("/post-consultation", response_model=ScenarioResponse) async def post_consultation_scenario( payload: PostConsultationInput, @@ -1695,3 +1747,228 @@ async def stream_events(): }) + "\n" return StreamingResponse(stream_events(), media_type="application/x-ndjson") +@router.post("/salesforce-create-lead", response_model=ScenarioResponse) +async def salesforce_create_lead_scenario( + payload: SalesforceLeadInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step(name: str, status: str, details: str = "", display_name: str = "", data: Any = None): + steps.append(ScenarioStep(name=name, status=status, details=details, display_name=display_name, data=data)) + + add_step("Create Lead", "pending", display_name="Create Salesforce Lead") + + sf_input = CreateLeadInput( + LastName=payload.last_name, + Company=payload.company, + FirstName=payload.first_name, + Email=payload.email, + Status=payload.status + ) + + try: + res = await execute_with_retry(connector, sf_input, trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = "Lead record created" + steps[-1].data = {"resource_id": res.resource_id, "raw": res.data} + return ScenarioResponse( + success=True, + trace_id=trace_id, + steps=steps, + final_resource_id=res.resource_id, + human_summary=f"Salesforce Lead created successfully with ID: {res.resource_id}" + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Lead creation failed") + +@router.post("/salesforce-create-contact", response_model=ScenarioResponse) +async def salesforce_create_contact_scenario( + payload: SalesforceContactInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step(name: str, status: str, details: str = "", display_name: str = "", data: Any = None): + steps.append(ScenarioStep(name=name, status=status, details=details, display_name=display_name, data=data)) + + add_step("Create Contact", "pending", display_name="Create Salesforce Contact") + + sf_input = CreateContactInput( + LastName=payload.last_name, + FirstName=payload.first_name, + Email=payload.email, + AccountId=payload.account_id + ) + + try: + res = await execute_with_retry(connector, sf_input, trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = "Contact record created" + steps[-1].data = {"resource_id": res.resource_id, "raw": res.data} + return ScenarioResponse( + success=True, + trace_id=trace_id, + steps=steps, + final_resource_id=res.resource_id, + human_summary=f"Salesforce Contact created successfully with ID: {res.resource_id}" + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Contact creation failed") + +@router.post("/salesforce-read-lead", response_model=ScenarioResponse) +async def salesforce_read_lead_scenario( + payload: SalesforceGenericIdInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + def add_step(name, status, display_name): + steps.append(ScenarioStep(name=name, status=status, display_name=display_name)) + add_step("Read Lead", "pending", "Fetching Lead Details") + try: + res = await execute_with_retry(connector, ReadLeadInput(record_id=payload.record_id), trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = "Lead data retrieved" + steps[-1].data = res.data + return ScenarioResponse(success=True, trace_id=trace_id, steps=steps, human_summary=f"Lead data retrieved for {payload.record_id}", final_resource_id=payload.record_id) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Read failed") + +@router.post("/salesforce-update-lead", response_model=ScenarioResponse) +async def salesforce_update_lead_scenario( + payload: SalesforceUpdateLeadInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + def add_step(name, status, display_name): + steps.append(ScenarioStep(name=name, status=status, display_name=display_name)) + add_step("Update Lead", "pending", "Updating Lead Record") + fields = {k: v for k, v in payload.model_dump().items() if v is not None and k != "record_id"} + # Map to SF internal names + sf_fields = {} + if "first_name" in fields: sf_fields["FirstName"] = fields["first_name"] + if "last_name" in fields: sf_fields["LastName"] = fields["last_name"] + if "company" in fields: sf_fields["Company"] = fields["company"] + if "email" in fields: sf_fields["Email"] = fields["email"] + + try: + res = await execute_with_retry(connector, UpdateLeadInput(record_id=payload.record_id, fields=sf_fields), trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = "Lead updated" + # Salesforce PATCH returns 204 No Content, so we show the sent fields as confirmation + steps[-1].data = {"record_id": payload.record_id, "updated_fields": sf_fields, "raw": res.data} + return ScenarioResponse( + success=True, + trace_id=trace_id, + steps=steps, + final_resource_id=payload.record_id, + human_summary=f"Lead {payload.record_id} updated successfully." + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Update failed") + +@router.post("/salesforce-delete-lead", response_model=ScenarioResponse) +async def salesforce_delete_lead_scenario( + payload: SalesforceGenericIdInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + def add_step(name, status, display_name): + steps.append(ScenarioStep(name=name, status=status, display_name=display_name)) + add_step("Delete Lead", "pending", "Removing Lead Record") + try: + res = await execute_with_retry(connector, DeleteLeadInput(record_id=payload.record_id), trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = "Lead deleted" + return ScenarioResponse( + success=True, + trace_id=trace_id, + steps=steps, + final_resource_id=payload.record_id, + human_summary=f"Lead {payload.record_id} deleted." + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Delete failed") + +@router.post("/salesforce-read-contact", response_model=ScenarioResponse) +async def salesforce_read_contact_scenario( + payload: SalesforceGenericIdInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + def add_step(name, status, display_name): + steps.append(ScenarioStep(name=name, status=status, display_name=display_name)) + add_step("Read Contact", "pending", "Fetching Contact Details") + try: + res = await execute_with_retry(connector, ReadContactInput(record_id=payload.record_id), trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = "Contact data retrieved" + steps[-1].data = res.data + return ScenarioResponse(success=True, trace_id=trace_id, steps=steps, human_summary=f"Contact data retrieved for {payload.record_id}", final_resource_id=payload.record_id) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Read failed") + +@router.post("/salesforce-update-contact", response_model=ScenarioResponse) +async def salesforce_update_contact_scenario( + payload: SalesforceUpdateContactInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + def add_step(name, status, display_name): + steps.append(ScenarioStep(name=name, status=status, display_name=display_name)) + add_step("Update Contact", "pending", "Updating Contact Record") + fields = {k: v for k, v in payload.model_dump().items() if v is not None and k != "record_id"} + sf_fields = {} + if "first_name" in fields: sf_fields["FirstName"] = fields["first_name"] + if "last_name" in fields: sf_fields["LastName"] = fields["last_name"] + if "email" in fields: sf_fields["Email"] = fields["email"] + if "account_id" in fields: sf_fields["AccountId"] = fields["account_id"] + + try: + res = await execute_with_retry(connector, UpdateContactInput(record_id=payload.record_id, fields=sf_fields), trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = "Contact updated" + # Salesforce PATCH returns 204 No Content, so we show the sent fields as confirmation + steps[-1].data = {"record_id": payload.record_id, "updated_fields": sf_fields, "raw": res.data} + return ScenarioResponse( + success=True, + trace_id=trace_id, + steps=steps, + final_resource_id=payload.record_id, + human_summary=f"Contact {payload.record_id} updated successfully." + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Update failed") + +@router.post("/salesforce-delete-contact", response_model=ScenarioResponse) +async def salesforce_delete_contact_scenario( + payload: SalesforceGenericIdInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + def add_step(name, status, display_name): + steps.append(ScenarioStep(name=name, status=status, display_name=display_name)) + add_step("Delete Contact", "pending", "Removing Contact Record") + try: + res = await execute_with_retry(connector, DeleteContactInput(record_id=payload.record_id), trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = "Contact deleted" + return ScenarioResponse( + success=True, + trace_id=trace_id, + steps=steps, + final_resource_id=payload.record_id, + human_summary=f"Contact {payload.record_id} deleted." + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Delete failed") + + diff --git a/pyproject.toml b/pyproject.toml index c87ec5b..cb4043d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,8 @@ stripe = "node_wire_stripe.logic" google_drive = "node_wire_google_drive.logic" fhir_epic = "node_wire_fhir_epic.logic" fhir_cerner = "node_wire_fhir_cerner.logic" +salesforce = "node_wire_salesforce.logic" + [tool.setuptools.packages.find] where = ["src"] diff --git a/sample.env b/sample.env index 0cdbc72..8540c6a 100644 --- a/sample.env +++ b/sample.env @@ -99,3 +99,9 @@ NW_REST_LOAD_DOTENV=true # AOT_CIRCUIT_BREAKER_FAIL_MAX=5 # AOT_CIRCUIT_BREAKER_RESET_TIMEOUT=30 NW_ALLOWED_CONNECTORS=fhir_cerner,fhir_epic,google_drive,http_generic,smtp,stripe +# Salesforce CRM +SALESFORCE_INSTANCE_URL=https://your-instance.my.salesforce.com +SALESFORCE_TOKEN_URL=https://login.salesforce.com/services/oauth2/token +SALESFORCE_CLIENT_ID=your-client-id +SALESFORCE_CLIENT_SECRET=your-client-secret +SALESFORCE_REFRESH_TOKEN=your-refresh-token diff --git a/scripts/build-mcp-images.sh b/scripts/build-mcp-images.sh index 5e5c0eb..714a3c1 100755 --- a/scripts/build-mcp-images.sh +++ b/scripts/build-mcp-images.sh @@ -71,5 +71,11 @@ docker build -f docker/stripe/Dockerfile \ -t "nw-stripe:${VERSION}" \ . +docker build -f docker/salesforce/Dockerfile \ + -t nw-salesforce:latest \ + -t "nw-salesforce:${VERSION}" \ + . + + echo "Done." diff --git a/scripts/build-packages.sh b/scripts/build-packages.sh index b8ff256..92b525c 100755 --- a/scripts/build-packages.sh +++ b/scripts/build-packages.sh @@ -32,9 +32,11 @@ ALL_PACKAGES=( packages/connectors/fhir_cerner packages/connectors/smtp packages/connectors/stripe + packages/connectors/salesforce packages/connectors/http_generic ) + usage() { cat <<'USAGE' Usage: diff --git a/src/agents/salesforce_mcp.py b/src/agents/salesforce_mcp.py new file mode 100644 index 0000000..d0da72b --- /dev/null +++ b/src/agents/salesforce_mcp.py @@ -0,0 +1,27 @@ +"""MCP Server — Salesforce connector only. Usage: python -m agents.salesforce_mcp""" +from __future__ import annotations + +import logging +import os + +from dotenv import load_dotenv + +load_dotenv() +load_dotenv(os.path.join(os.path.dirname(__file__), ".env")) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("agents.salesforce_mcp") + + +def main() -> None: + from bindings.mcp_server.server import McpServer + + logger.info("Starting nw-salesforce MCP server (stdio, manifest-driven)") + McpServer( + server_name="nw-salesforce", + connector_ids=["salesforce"], + ).run_stdio() + + +if __name__ == "__main__": + main() diff --git a/src/bindings/factory.py b/src/bindings/factory.py index 63267bd..c931144 100644 --- a/src/bindings/factory.py +++ b/src/bindings/factory.py @@ -178,8 +178,10 @@ def _build_auth_provider(self, connector_id: str, cfg: dict) -> Any: private_key_secret=auth_cfg.get("private_key_secret"), kid_secret=auth_cfg.get("kid_secret"), client_secret_secret=auth_cfg.get("client_secret_secret"), + refresh_token_secret=auth_cfg.get("refresh_token_secret"), scopes=auth_cfg.get("scopes"), scopes_secret=auth_cfg.get("scopes_secret"), + extra_content_type_headers=auth_cfg.get("extra_headers"), buffer_secs=int(auth_cfg.get("buffer_secs", 60)), jwt_ttl_secs=int(auth_cfg.get("jwt_ttl_secs", 300)), diff --git a/src/node_wire_runtime/auth/oauth2.py b/src/node_wire_runtime/auth/oauth2.py index 6ee52ce..505950e 100644 --- a/src/node_wire_runtime/auth/oauth2.py +++ b/src/node_wire_runtime/auth/oauth2.py @@ -88,16 +88,17 @@ def __init__( private_key_secret: Optional[str] = None, kid_secret: Optional[str] = None, client_secret_secret: Optional[str] = None, + refresh_token_secret: Optional[str] = None, scopes: Optional[List[str]] = None, scopes_secret: Optional[str] = None, extra_content_type_headers: Optional[Dict[str, str]] = None, buffer_secs: int = _DEFAULT_BUFFER_SECS, jwt_ttl_secs: int = 300, ) -> None: - if grant_method not in ("private_key_jwt", "client_secret_post"): + if grant_method not in ("private_key_jwt", "client_secret_post", "refresh_token"): raise ValueError( f"Unsupported grant_method {grant_method!r}. " - "Use 'private_key_jwt' or 'client_secret_post'." + "Use 'private_key_jwt', 'client_secret_post', or 'refresh_token'." ) self._sp = secret_provider self._grant_method = grant_method @@ -107,6 +108,8 @@ def __init__( self._private_key_secret = private_key_secret self._kid_secret = kid_secret self._client_secret_secret = client_secret_secret + self._refresh_token_secret = refresh_token_secret + self._static_scopes = scopes self._scopes_secret = scopes_secret self._extra_headers: Dict[str, str] = ( @@ -224,8 +227,46 @@ async def _fetch_token(self) -> Dict[str, Any]: """Dispatch to the appropriate grant method implementation.""" if self._grant_method == "private_key_jwt": return await self._fetch_private_key_jwt() + if self._grant_method == "refresh_token": + return await self._fetch_refresh_token() return await self._fetch_client_secret_post() + async def _fetch_refresh_token(self) -> Dict[str, Any]: + """Exchange refresh_token for a new access token.""" + if not self._refresh_token_secret: + raise ValueError( + "OAuth2AuthProvider (refresh_token): " + "'refresh_token_secret' must be configured." + ) + + client_id = self._sp.get_secret(self._client_id_secret) + client_secret = ( + self._sp.get_secret(self._client_secret_secret) + if self._client_secret_secret + else None + ) + refresh_token = self._sp.get_secret(self._refresh_token_secret) + token_url = self._sp.get_secret(self._token_url_secret) + + post_data: Dict[str, str] = { + "grant_type": "refresh_token", + "client_id": client_id, + "refresh_token": refresh_token, + } + if client_secret: + post_data["client_secret"] = client_secret + + scope = self._resolve_scopes() + if scope: + post_data["scope"] = scope + + logger.debug( + "OAuth2AuthProvider: refresh_token token request", + extra={"token_url": token_url}, + ) + return await self._post_token(token_url, post_data) + + async def _fetch_private_key_jwt(self) -> Dict[str, Any]: """ Exchange a signed JWT assertion for an access token. diff --git a/src/node_wire_runtime/base_connector.py b/src/node_wire_runtime/base_connector.py index d3f8f9c..01ca344 100644 --- a/src/node_wire_runtime/base_connector.py +++ b/src/node_wire_runtime/base_connector.py @@ -495,6 +495,11 @@ async def _do_execute(*, trace_id: str) -> Any: trace_id=trace_id, ) + @classmethod + def get_registry(cls) -> Dict[str, Type[BaseConnector]]: + """Public access to the global connector registry.""" + return dict(_CONNECTOR_REGISTRY) + @classmethod def sdk_action_metas(cls) -> Dict[str, NwActionMeta]: """Registry of action name -> metadata (for manifest/ingress).""" diff --git a/src/node_wire_salesforce/logic.py b/src/node_wire_salesforce/logic.py new file mode 100644 index 0000000..6c2dcf5 --- /dev/null +++ b/src/node_wire_salesforce/logic.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional, Union, Tuple, Type +import httpx + +from node_wire_runtime import BaseConnector, nw_action +from node_wire_runtime.models import ErrorCategory +from .schema import ( + CreateLeadInput, + CreateContactInput, + ReadLeadInput, + UpdateLeadInput, + DeleteLeadInput, + ReadContactInput, + UpdateContactInput, + DeleteContactInput, + SalesforceOperationOutput, + SalesforceError, +) + +logger = logging.getLogger("connectors.salesforce") + +class SalesforceTransientError(httpx.HTTPStatusError): + """Exception for transient Salesforce errors that should be retried.""" + pass + +class SalesforceConnector(BaseConnector): + """Salesforce connector for managing Leads and Contacts.""" + + connector_id = "salesforce" + action = "execute" # Multi-action dispatcher + output_model = SalesforceOperationOutput + + error_map: Dict[Type[BaseException], Tuple[ErrorCategory, str]] = { + httpx.ConnectError: (ErrorCategory.RETRYABLE, "SALESFORCE_CONNECT_ERROR"), + httpx.TimeoutException: (ErrorCategory.RETRYABLE, "SALESFORCE_TIMEOUT"), + SalesforceTransientError: (ErrorCategory.RETRYABLE, "SALESFORCE_TRANSIENT_ERROR"), + httpx.HTTPStatusError: (ErrorCategory.BUSINESS, "SALESFORCE_API_ERROR"), + } + + def _get_base_url(self) -> str: + return self.secret_provider.get_secret("salesforce_instance_url").rstrip("/") + + def _get_api_version(self) -> str: + return "v58.0" + + async def _get_auth_headers(self) -> Dict[str, str]: + return await self.get_auth_headers() + + @nw_action("create_lead") + async def create_lead( + self, params: CreateLeadInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest("POST", "Lead", params.model_dump(by_alias=True, exclude={"action"}), trace_id) + + @nw_action("read_lead") + async def read_lead( + self, params: ReadLeadInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest("GET", f"Lead/{params.record_id}", None, trace_id) + + @nw_action("update_lead") + async def update_lead( + self, params: UpdateLeadInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest("PATCH", f"Lead/{params.record_id}", params.fields, trace_id) + + @nw_action("delete_lead") + async def delete_lead( + self, params: DeleteLeadInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest("DELETE", f"Lead/{params.record_id}", None, trace_id) + + @nw_action("create_contact") + async def create_contact( + self, params: CreateContactInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest("POST", "Contact", params.model_dump(by_alias=True, exclude={"action"}), trace_id) + + @nw_action("read_contact") + async def read_contact( + self, params: ReadContactInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest("GET", f"Contact/{params.record_id}", None, trace_id) + + @nw_action("update_contact") + async def update_contact( + self, params: UpdateContactInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest("PATCH", f"Contact/{params.record_id}", params.fields, trace_id) + + @nw_action("delete_contact") + async def delete_contact( + self, params: DeleteContactInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest("DELETE", f"Contact/{params.record_id}", None, trace_id) + + async def _execute_rest( + self, method: str, path: str, payload: Optional[Dict[str, Any]], trace_id: str + ) -> SalesforceOperationOutput: + base_url = self._get_base_url() + api_version = self._get_api_version() + url = f"{base_url}/services/data/{api_version}/sobjects/{path}" + + headers = await self._get_auth_headers() + if payload: + headers["Content-Type"] = "application/json" + if isinstance(payload, dict): + payload = {k: v for k, v in payload.items() if v is not None} + + logger.info( + "Executing Salesforce REST call", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "method": method, + "path": path, + } + ) + + async with httpx.AsyncClient() as client: + try: + response = await client.request(method, url, headers=headers, json=payload, timeout=30.0) + + # Handle transient errors (5xx) by raising a retryable exception + if response.status_code >= 500: + raise SalesforceTransientError( + message=f"Salesforce server error: {response.status_code}", + request=response.request, + response=response + ) + + response.raise_for_status() + + data = {} + if response.content: + try: + data = response.json() + except Exception: + data = {"text": response.text} + + obj_type = path.split("/")[0] + res_id = data.get("id") or data.get("Id") if isinstance(data, dict) else None + + if not res_id and "/" in path: + res_id = path.split("/")[1] + + return SalesforceOperationOutput( + success=True, + resource_id=res_id, + resource_type=obj_type, + data=data + ) + except Exception as exc: + # We log and re-raise to let the platform (ErrorMapper + Resilience) handle it + logger.error( + "Salesforce REST call failed", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "method": method, + "path": path, + "error_type": type(exc).__name__, + "error_message": str(exc), + } + ) + raise diff --git a/src/node_wire_salesforce/registration.py b/src/node_wire_salesforce/registration.py new file mode 100644 index 0000000..86de1be --- /dev/null +++ b/src/node_wire_salesforce/registration.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +# Salesforce registration module. +# Mappings for httpx errors are now handled directly in SalesforceConnector.error_map +# in logic.py, which BaseConnector registers automatically. +# This module remains for package-level registration/discovery side effects. diff --git a/src/node_wire_salesforce/schema.py b/src/node_wire_salesforce/schema.py new file mode 100644 index 0000000..fc7ee0b --- /dev/null +++ b/src/node_wire_salesforce/schema.py @@ -0,0 +1,132 @@ +from typing import Any, Dict, List, Literal, Optional, Union +from pydantic import BaseModel, Field, field_validator, ConfigDict +import re + +SALESFORCE_ID_REGEX = re.compile(r"^[a-zA-Z0-9]{15,18}$") + +class SalesforceError(BaseModel): + message: str + code: Optional[str] = None + fields: Optional[List[str]] = None + +class SalesforceOperationOutput(BaseModel): + success: bool = True + resource_id: Optional[str] = None + resource_type: Optional[str] = None + data: Optional[Dict[str, Any]] = None + errors: Optional[List[SalesforceError]] = None + +# Creation Models +class CreateLeadInput(BaseModel): + action: Literal["create_lead"] = "create_lead" + last_name: str = Field(..., alias="LastName") + company: str = Field(..., alias="Company") + first_name: Optional[str] = Field(None, alias="FirstName") + title: Optional[str] = Field(None, alias="Title") + email: Optional[str] = Field(None, alias="Email") + phone: Optional[str] = Field(None, alias="Phone") + mobile_phone: Optional[str] = Field(None, alias="MobilePhone") + street: Optional[str] = Field(None, alias="Street") + city: Optional[str] = Field(None, alias="City") + state: Optional[str] = Field(None, alias="State") + postal_code: Optional[str] = Field(None, alias="PostalCode") + country: Optional[str] = Field(None, alias="Country") + description: Optional[str] = Field(None, alias="Description") + lead_source: Optional[str] = Field(None, alias="LeadSource") + status: Optional[str] = Field(None, alias="Status") + rating: Optional[str] = Field(None, alias="Rating") + website: Optional[str] = Field(None, alias="Website") + number_of_employees: Optional[int] = Field(None, alias="NumberOfEmployees") + industry: Optional[str] = Field(None, alias="Industry") + annual_revenue: Optional[float] = Field(None, alias="AnnualRevenue") + + model_config = ConfigDict(populate_by_name=True) + +class CreateContactInput(BaseModel): + action: Literal["create_contact"] = "create_contact" + last_name: str = Field(..., alias="LastName") + first_name: Optional[str] = Field(None, alias="FirstName") + account_id: Optional[str] = Field(None, alias="AccountId") + title: Optional[str] = Field(None, alias="Title") + + @field_validator("account_id") + @classmethod + def validate_account_id(cls, v: Optional[str]) -> Optional[str]: + if v and not SALESFORCE_ID_REGEX.match(v): + raise ValueError("Invalid Salesforce AccountId format (must be 15 or 18 alphanumeric characters)") + return v + + email: Optional[str] = Field(None, alias="Email") + phone: Optional[str] = Field(None, alias="Phone") + mobile_phone: Optional[str] = Field(None, alias="MobilePhone") + mailing_street: Optional[str] = Field(None, alias="MailingStreet") + mailing_city: Optional[str] = Field(None, alias="MailingCity") + mailing_state: Optional[str] = Field(None, alias="MailingState") + mailing_postal_code: Optional[str] = Field(None, alias="MailingPostalCode") + mailing_country: Optional[str] = Field(None, alias="MailingCountry") + description: Optional[str] = Field(None, alias="Description") + lead_source: Optional[str] = Field(None, alias="LeadSource") + department: Optional[str] = Field(None, alias="Department") + + model_config = ConfigDict(populate_by_name=True) + + +# Read/Delete Models +class SalesforceResourceInput(BaseModel): + action: Literal["read_lead", "delete_lead", "read_contact", "delete_contact"] + record_id: str + + @field_validator("record_id") + @classmethod + def validate_id(cls, v: str) -> str: + if not SALESFORCE_ID_REGEX.match(v): + raise ValueError("Invalid Salesforce record_id format (must be 15 or 18 alphanumeric characters)") + return v + +class ReadLeadInput(SalesforceResourceInput): + action: Literal["read_lead"] = "read_lead" + +class DeleteLeadInput(SalesforceResourceInput): + action: Literal["delete_lead"] = "delete_lead" + +class ReadContactInput(SalesforceResourceInput): + action: Literal["read_contact"] = "read_contact" + +class DeleteContactInput(SalesforceResourceInput): + action: Literal["delete_contact"] = "delete_contact" + +# Update Models +class UpdateLeadInput(BaseModel): + action: Literal["update_lead"] = "update_lead" + record_id: str + fields: Dict[str, Any] + + @field_validator("record_id") + @classmethod + def validate_id(cls, v: str) -> str: + if not SALESFORCE_ID_REGEX.match(v): + raise ValueError("Invalid Salesforce record_id format (must be 15 or 18 alphanumeric characters)") + return v + +class UpdateContactInput(BaseModel): + action: Literal["update_contact"] = "update_contact" + record_id: str + fields: Dict[str, Any] + + @field_validator("record_id") + @classmethod + def validate_id(cls, v: str) -> str: + if not SALESFORCE_ID_REGEX.match(v): + raise ValueError("Invalid Salesforce record_id format (must be 15 or 18 alphanumeric characters)") + return v + +SalesforceInput = Union[ + CreateLeadInput, + CreateContactInput, + ReadLeadInput, + UpdateLeadInput, + DeleteLeadInput, + ReadContactInput, + UpdateContactInput, + DeleteContactInput +] diff --git a/tests/test_connectors_basic.py b/tests/test_connectors_basic.py index e6a744a..1dc35e5 100644 --- a/tests/test_connectors_basic.py +++ b/tests/test_connectors_basic.py @@ -7,6 +7,7 @@ from node_wire_http_generic.logic import HttpGenericConnector from node_wire_smtp.logic import SmtpConnector from node_wire_stripe.logic import StripeConnector +from node_wire_salesforce.logic import SalesforceConnector from node_wire_runtime import ConnectorResponse, ErrorCategory, BaseConnector, SecretProvider from node_wire_runtime.connector_registry import auto_register @@ -42,3 +43,12 @@ def test_stripe_connector_instantiation_only(): connector = StripeConnector(secret_provider=DummySecretProvider()) assert connector.connector_id == "stripe" assert connector.action == "execute" + + +def test_salesforce_connector_instantiation_only(): + store = {"salesforce_instance_url": "https://test.salesforce.com"} + provider = type("Mock", (), {"get_secret": lambda s, k: store[k]})() + connector = BaseConnector.get_registry()["salesforce"](secret_provider=provider) + assert connector.connector_id == "salesforce" + assert "create_lead" in connector._action_registry + diff --git a/tests/test_salesforce.py b/tests/test_salesforce.py new file mode 100644 index 0000000..36bcbba --- /dev/null +++ b/tests/test_salesforce.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from pydantic import ValidationError + +from node_wire_runtime import SecretProvider +from node_wire_salesforce.logic import SalesforceConnector, SalesforceTransientError +from node_wire_salesforce.schema import ( + CreateLeadInput, + ReadLeadInput, + UpdateLeadInput, + DeleteLeadInput, + CreateContactInput, + ReadContactInput, + UpdateContactInput, + DeleteContactInput, + SalesforceOperationOutput, +) + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + +class MockSecretProvider(SecretProvider): + def get_secret(self, key: str) -> str: + return { + "salesforce_instance_url": "https://test.salesforce.com", + }[key] + + +def _connector() -> SalesforceConnector: + """Return a SalesforceConnector with mock secrets.""" + conn = SalesforceConnector(secret_provider=MockSecretProvider()) + # Mock auth headers + conn.get_auth_headers = AsyncMock(return_value={"Authorization": "Bearer mock_token"}) + return conn + + +# --------------------------------------------------------------------------- +# Create Contact +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_salesforce_create_contact_happy_path(): + connector = _connector() + params = CreateContactInput(LastName="Doe", FirstName="John", Email="john@example.com") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 201 + mock_response.content = b'{"id": "003123456789012", "success": true}' + mock_response.json.return_value = {"id": "003123456789012", "success": True} + mock_response.text = '{"id": "003123456789012", "success": true}' + + with patch("httpx.AsyncClient.request", return_value=mock_response) as mock_request: + result = await connector.create_contact(params, trace_id="test-trace") + + assert result.success is True + assert result.resource_id == "003123456789012" + assert result.data["id"] == "003123456789012" + + +@pytest.mark.asyncio +async def test_salesforce_create_contact_validation_error(): + # Invalid AccountId (too short) + with pytest.raises(ValidationError) as excinfo: + CreateContactInput(LastName="Doe", AccountId="short") + assert "Invalid Salesforce AccountId format" in str(excinfo.value) + + +# --------------------------------------------------------------------------- +# Update Contact (204 No Content) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_salesforce_update_contact_204_path(): + connector = _connector() + params = UpdateContactInput(record_id="003123456789012", fields={"FirstName": "Jane"}) + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 204 + mock_response.content = b"" + mock_response.text = "" + + with patch("httpx.AsyncClient.request", return_value=mock_response) as mock_request: + result = await connector.update_contact(params, trace_id="test-trace") + + assert result.success is True + assert result.resource_id == "003123456789012" + assert result.data == {} + + +# --------------------------------------------------------------------------- +# Error Handling (Raises Exception) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_salesforce_error_raises_exception(): + connector = _connector() + params = ReadContactInput(record_id="003123456789012") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 400 + mock_response.text = 'Bad Request' + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + message="Bad Request", request=MagicMock(), response=mock_response + ) + + with patch("httpx.AsyncClient.request", return_value=mock_response): + with pytest.raises(httpx.HTTPStatusError): + await connector.read_contact(params, trace_id="test-trace") + + +# --------------------------------------------------------------------------- +# Transient Error (Raises SalesforceTransientError) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_salesforce_transient_error_raises(): + connector = _connector() + params = ReadContactInput(record_id="003123456789012") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 503 + mock_response.text = 'Service Unavailable' + + with patch("httpx.AsyncClient.request", return_value=mock_response): + with pytest.raises(SalesforceTransientError): + await connector.read_contact(params, trace_id="test-trace") + + +# --------------------------------------------------------------------------- +# End-to-End internal_execute logic (checks mapping) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_salesforce_internal_execute_mapping(): + connector = _connector() + # Mocking internal_execute because BaseConnector handles the exception wrapping + + params = ReadContactInput(record_id="003123456789012") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 503 + mock_response.text = "Transient Error" + + with patch("httpx.AsyncClient.request", return_value=mock_response): + # We call internal_execute directly to bypass BaseConnector.run's retry logic for now + # but check that it raises the expected transient error + with pytest.raises(SalesforceTransientError): + await connector.internal_execute(params, trace_id="test-trace") + + +# --------------------------------------------------------------------------- +# Delete Contact +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_salesforce_delete_contact_happy_path(): + connector = _connector() + params = DeleteContactInput(record_id="003123456789012") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 204 + mock_response.content = b"" + + with patch("httpx.AsyncClient.request", return_value=mock_response) as mock_request: + result = await connector.delete_contact(params, trace_id="test-trace") + + assert result.success is True + assert result.resource_id == "003123456789012" + mock_request.assert_called_once() + assert mock_request.call_args[0][0] == "DELETE" + + +# --------------------------------------------------------------------------- +# Lead Operations +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_salesforce_create_lead_happy_path(): + connector = _connector() + params = CreateLeadInput(LastName="Smith", Company="Acme Corp", Email="smith@acme.com") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 201 + mock_response.content = b'{"id": "00Q123456789012", "success": true}' + mock_response.json.return_value = {"id": "00Q123456789012", "success": True} + mock_response.text = '{"id": "00Q123456789012", "success": true}' + + with patch("httpx.AsyncClient.request", return_value=mock_response) as mock_request: + result = await connector.create_lead(params, trace_id="test-trace") + + assert result.success is True + assert result.resource_id == "00Q123456789012" + assert "LastName" in mock_request.call_args[1]["json"] + assert mock_request.call_args[1]["json"]["LastName"] == "Smith" + +@pytest.mark.asyncio +async def test_salesforce_read_lead_happy_path(): + connector = _connector() + params = ReadLeadInput(record_id="00Q123456789012") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.content = b'{"Id": "00Q123456789012", "LastName": "Smith"}' + mock_response.json.return_value = {"Id": "00Q123456789012", "LastName": "Smith"} + + with patch("httpx.AsyncClient.request", return_value=mock_response) as mock_request: + result = await connector.read_lead(params, trace_id="test-trace") + + assert result.success is True + assert result.resource_id == "00Q123456789012" + assert result.data["LastName"] == "Smith" + +@pytest.mark.asyncio +async def test_salesforce_update_lead_happy_path(): + connector = _connector() + params = UpdateLeadInput(record_id="00Q123456789012", fields={"Company": "New Acme"}) + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 204 + mock_response.content = b"" + + with patch("httpx.AsyncClient.request", return_value=mock_response) as mock_request: + result = await connector.update_lead(params, trace_id="test-trace") + + assert result.success is True + assert result.resource_id == "00Q123456789012" + assert mock_request.call_args[0][0] == "PATCH" + assert mock_request.call_args[1]["json"]["Company"] == "New Acme" + +@pytest.mark.asyncio +async def test_salesforce_delete_lead_happy_path(): + connector = _connector() + params = DeleteLeadInput(record_id="00Q123456789012") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 204 + mock_response.content = b"" + + with patch("httpx.AsyncClient.request", return_value=mock_response) as mock_request: + result = await connector.delete_lead(params, trace_id="test-trace") + + assert result.success is True + assert result.resource_id == "00Q123456789012" + assert mock_request.call_args[0][0] == "DELETE" From 2b8af07706af471594824c93517a89dbba8f5d8b Mon Sep 17 00:00:00 2001 From: kesav-aot Date: Wed, 6 May 2026 01:03:50 -0700 Subject: [PATCH 25/60] Propagate HTTP headers and implement streamable agent chat support (#32) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Propagate HTTP headers via ContextVar Introduce a ContextVar (_http_request_headers) to capture ASGI request headers and pass them into authenticate_mcp_request. The McpServer now retrieves headers from the context when authenticating, and the ASGI wrapper sets/resets the ContextVar around request handling. Tests updated to set the context var in the test ASGI app and new tests verify authentication accepts Authorization and X-API-Key headers (tools/list flow). * Add streamable agent chat support Implement streaming agent chat over a 'streamable-http' transport and UI improvements. Frontend: add transport status UI, NDJSON stream reader, streaming chat bubble/loader/tracing UI, and logic to fetch /scenarios/agent-transport and /scenarios/agent-chat-stream. Backend (playground/scenarios.py): add /agent-transport endpoint, refactor task builder and transport detection, and add /agent-chat-stream StreamingResponse that proxies ToolHiveAgent events. Agent core (src/agents/toolhive.py): add helpers to chunk text and emit stream events via ToolHiveAgent.run_events (meta/status/step/final_chunk/error/done). Tests: add test to verify run_events emits final done event with trace_id. Styles: add CSS for transport pill, streaming bubble, and end messages. * Add streaming utilities and buffered iterator Introduce a new streaming helper module (node_wire_runtime.streaming) that provides StreamSignal, stream_completion_log, resolve_stream_buffer_ms, and an async BufferedStreamIterator to optionally buffer streamed events. Wire these helpers into ToolHive and MCP server: ToolHive now emits completion logs and supports configurable buffering for run_events, and the MCP server logs stream completion on success/failure. Documentation and samples updated to document NW_STREAM_BUFFER_MS, completion signals, and playground fallback handling. The buffering default is 0 (disabled) and values are clamped (0–30000 ms). * Update test_scope_policy_transport.py --- README.md | 4 +- Setup.md | 6 +- playground/app.js | 93 +++++++++++++++++--- playground/scenarios.py | 45 +++++++--- playground/style.css | 71 ++++++++++++++++ sample.env | 4 + src/agents/toolhive.py | 57 ++++++++++--- src/bindings/mcp_server/server.py | 42 +++++++-- src/node_wire_runtime/__init__.py | 5 ++ src/node_wire_runtime/streaming.py | 75 ++++++++++++++++ tests/test_mcp_transport.py | 122 ++++++++++++++++++++++++++- tests/test_scope_policy_transport.py | 4 +- tests/test_toolhive_agent.py | 28 ++++++ 13 files changed, 503 insertions(+), 53 deletions(-) create mode 100644 src/node_wire_runtime/streaming.py diff --git a/README.md b/README.md index 93cea35..c93c0b3 100644 --- a/README.md +++ b/README.md @@ -130,7 +130,9 @@ Examples: Google Drive has a full doc at `src/node_wire_google_drive/README.md`; - **MCP:** `MODE=MCP` starts a minimal placeholder. For production agentic workflows, `python -m agents.mcp_entrypoint` supports: - **stdio** (Default): Legacy transport for ToolHive/subprocesses. - **streamable-http**: Native HTTP/SSE transport for first-class HTTP citizen integration. - (Configured via `NW_MCP_TRANSPORT` and `NW_MCP_PORT`). + (Configured via `NW_MCP_TRANSPORT`, `NW_MCP_PORT`, and optionally `NW_STREAM_BUFFER_MS` for stream buffering). + +The core runtime also emits structured completion logs when streaming ends so headless consumers can easily detect completion. The playground reads `NW_MCP_TRANSPORT` through `/scenarios/agent-transport` and displays the active mode in the Agentic Workflow panel. In `stdio` mode, chat responses are buffered until the backend agent run completes. In `streamable-http` mode, tool cards and final-answer chunks render progressively. diff --git a/Setup.md b/Setup.md index 82000fa..c6fdd5c 100644 --- a/Setup.md +++ b/Setup.md @@ -84,7 +84,7 @@ You only need to fill in the sections for the connectors you plan to use. The pl | **Google Drive** | `GOOGLE_DRIVE_SA_JSON`, `GOOGLE_DRIVE_FOLDER_ID` | Google Drive connector | | **SMTP** | `SMTP_HOST`, `SMTP_PORT`, `SMTP_USERNAME`, `SMTP_PASSWORD` | Sending emails | | **LLM / Agent** | `LLM_PROVIDER`, `GROQ_API_KEY` (or other provider key) | AI agent / ToolHive | -| **ToolHive / MCP**| `TOOLHIVE_MCP_URLS` (multi-server), `NW_MCP_TRANSPORT`, `NW_MCP_PORT` | AI agent / ToolHive | +| **ToolHive / MCP**| `TOOLHIVE_MCP_URLS` (multi-server), `NW_MCP_TRANSPORT`, `NW_MCP_PORT`, `NW_STREAM_BUFFER_MS` | AI agent / ToolHive | See `sample.env` for the full list with example values. @@ -299,6 +299,10 @@ Node Wire supports two transport modes for AI agents. Switch between them using - **`stdio`** (Default): Communicates via standard I/O. Best for local development, subprocess-based clients, and ToolHive-style wrapping. The playground uses buffered agent responses in this mode. - **`streamable-http`**: Native HTTP MCP server. Exposes a direct endpoint on `NW_MCP_HOST`, `NW_MCP_PORT`, and `NW_MCP_PATH`. The playground streams tool progress and final answer chunks in this mode. +### Streaming Features +- **Configurable Buffering (`NW_STREAM_BUFFER_MS`)**: When streaming, output can be buffered to reduce event spam. Set to the duration in milliseconds (e.g. `2000` for a 2-second batching window). Default is `0` (no buffering). +- **Completion Signals**: The core runtime emits structured "done" signals (`stream_completion_log`) via Python logging when streaming ends, allowing package consumers to easily detect when a stream finishes. + **Example: stdio mode** ```powershell diff --git a/playground/app.js b/playground/app.js index 87cc564..f473a86 100644 --- a/playground/app.js +++ b/playground/app.js @@ -1121,13 +1121,28 @@ document.addEventListener('DOMContentLoaded', () => { return bubble; } - function appendStreamingBubble(label = 'Agent') { + function appendStreamingBubble(label = 'Agent Streaming') { const bubble = document.createElement('div'); bubble.className = 'chat-bubble assistant streaming-bubble'; - bubble.innerHTML = `
${escapeHTML(label)}

`; + bubble.innerHTML = ` +
+ ${escapeHTML(label)} +

+
+ + + + Streaming response... +
+
+ `; agentChatHistory.appendChild(bubble); agentChatHistory.scrollTop = agentChatHistory.scrollHeight; - return bubble.querySelector('p'); + return { + bubble, + text: bubble.querySelector('.streaming-text'), + loader: bubble.querySelector('.stream-tail-loader'), + }; } function appendTraceBadge(traceId, transportLabel = '') { @@ -1140,6 +1155,14 @@ document.addEventListener('DOMContentLoaded', () => { agentChatHistory.scrollTop = agentChatHistory.scrollHeight; } + function appendStreamEndMessage(message, success = true) { + const end = document.createElement('div'); + end.className = `stream-end-message ${success ? 'success' : 'error'}`; + end.textContent = message || (success ? 'Streaming completed.' : 'Streaming ended with an error.'); + agentChatHistory.appendChild(end); + agentChatHistory.scrollTop = agentChatHistory.scrollHeight; + } + function updateAgentTransportStatus() { if (!agentTransportStatus) return; const label = agentTransportMode === 'streamable-http' ? 'Streamable HTTP' : 'stdio'; @@ -1159,6 +1182,33 @@ document.addEventListener('DOMContentLoaded', () => { updateAgentTransportStatus(); } + async function readNdjsonStream(response, handlers) { + if (!response.body) throw new Error('Browser did not expose a readable response stream'); + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let pending = ''; + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + + pending += decoder.decode(value, { stream: true }); + const lines = pending.split('\n'); + pending = lines.pop() || ''; + + for (const line of lines) { + if (!line.trim()) continue; + const event = JSON.parse(line); + if (handlers[event.type]) handlers[event.type](event); + } + } + + if (pending.trim()) { + const event = JSON.parse(pending); + if (handlers[event.type]) handlers[event.type](event); + } + } + function appendStepCard(step) { const card = document.createElement('div'); card.className = 'agent-step-card'; @@ -1247,7 +1297,8 @@ document.addEventListener('DOMContentLoaded', () => { let finalText = ''; let traceId = ''; let success = true; - let streamedText = null; + let doneMessage = ''; + let streamView = null; await readNdjsonStream(response, { meta: (event) => { @@ -1263,36 +1314,52 @@ document.addEventListener('DOMContentLoaded', () => { args: event.args || {}, result: event.result || '' }); + if (!streamView) { + streamView = appendStreamingBubble(); + } else { + agentChatHistory.appendChild(streamView.bubble); + agentChatHistory.scrollTop = agentChatHistory.scrollHeight; + } }, final_chunk: (event) => { agentTyping.classList.add('hidden'); - if (!streamedText) streamedText = appendStreamingBubble('Agent Streaming'); + if (!streamView) streamView = appendStreamingBubble(); finalText += event.content || ''; - streamedText.textContent = finalText; + streamView.text.textContent = finalText; agentChatHistory.scrollTop = agentChatHistory.scrollHeight; }, error: (event) => { success = false; agentTyping.classList.add('hidden'); - if (!streamedText) streamedText = appendStreamingBubble('Agent Streaming'); + if (!streamView) streamView = appendStreamingBubble(); finalText += event.message || ''; - streamedText.textContent = finalText; + streamView.text.textContent = finalText; }, done: (event) => { traceId = event.trace_id || traceId; success = Boolean(event.success); + doneMessage = event.message || `Streaming ${success ? 'completed' : 'failed'}. trace_id=${traceId}`; + if (!streamView) streamView = appendStreamingBubble(); + streamView.loader.classList.add('hidden'); + appendStreamEndMessage(doneMessage, success); } }); agentTyping.classList.add('hidden'); - if (!streamedText && !finalText) { - streamedText = appendStreamingBubble('Agent Streaming'); - finalText = success ? 'Completed.' : 'The streamed run ended before a final answer was returned.'; - streamedText.textContent = finalText; + if (!doneMessage) { + if (!streamView) streamView = appendStreamingBubble(); + streamView.loader.classList.add('hidden'); + doneMessage = `Streaming connection closed before done event. trace_id=${traceId || 'unknown'}`; + appendStreamEndMessage(doneMessage, false); + success = false; + } + if (!finalText) { + finalText = success ? 'Completed.' : 'The stream ended before a final answer was returned.'; + if (streamView) streamView.text.textContent = finalText; } agentConversationHistory.push({ role: 'assistant', content: finalText }); appendTraceBadge(traceId, 'streamable-http'); - log(`Agent Chat: ${success ? 'Stream complete' : 'Stream failed'}`, success ? 'success' : 'error'); + log(`Agent Chat: ${success ? 'Stream complete' : 'Stream failed'} | ${doneMessage}`, success ? 'success' : 'error'); return; } diff --git a/playground/scenarios.py b/playground/scenarios.py index 5f69bd7..e2209e1 100644 --- a/playground/scenarios.py +++ b/playground/scenarios.py @@ -1449,6 +1449,27 @@ class AgentChatResponse(BaseModel): success: bool +def _current_agent_transport() -> str: + transport = os.environ.get("NW_MCP_TRANSPORT", "stdio").strip().lower() or "stdio" + return transport if transport in {"stdio", "streamable-http"} else "stdio" + + +def _build_agent_chat_task(payload: AgentChatInput) -> str: + history_text_parts = [] + for msg in payload.history: + role = msg.get("role", "user") + content = msg.get("content", "") + history_text_parts.append(f"{role.upper()}: {content}") + + if history_text_parts: + return ( + "Previous conversation:\n" + + "\n".join(history_text_parts) + + f"\n\nUSER (latest): {payload.message}" + ) + return payload.message + + @router.get("/agent-transport") async def agent_transport() -> Dict[str, str]: transport = _current_agent_transport() @@ -1561,15 +1582,11 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: task = _build_agent_chat_task(payload) - # Determine MCP transport — try proxy first, optionally fallback to local stdio. - # Default behavior surfaces proxy/auth errors directly in the UI so demos can - # show MCP failures (instead of silently falling back to stdio). - fallback_to_stdio = ( - (os.environ.get("PLAYGROUND_AGENT_PROXY_FALLBACK_TO_STDIO", "false").strip().lower()) - in {"1", "true", "yes", "on"} - ) - urls = resolve_mcp_urls() + # Determine MCP transport — try proxy first, fallback to local stdio + transport = _current_agent_transport() + urls = resolve_mcp_urls() if transport == "streamable-http" else [] run_result = None + fallback_to_stdio = os.environ.get("PLAYGROUND_AGENT_PROXY_FALLBACK_TO_STDIO", "false").lower() == "true" if urls: logger.info("Agent Chat | trying ToolHive proxy URL(s): %s", ",".join(urls)) @@ -1667,11 +1684,10 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: @router.post("/agent-chat-stream") async def agent_chat_stream(payload: AgentChatInput) -> Any: """ - Stream agent progress and final answer chunks to the playground UI. + Stream agent progress and final-answer chunks to web clients. - Tool steps are emitted as each tool finishes. The final assistant answer is - emitted as chunks instead of waiting for the browser to receive one large - buffered JSON payload. + The terminal ``done`` event includes ``trace_id`` and ``message``. Clients + should stop their streaming loader only when that event arrives. """ async def stream_events(): @@ -1689,14 +1705,16 @@ async def stream_events(): ) if not payload.message.strip(): + trace_id = str(uuid.uuid4()) yield json.dumps({ "type": "final_chunk", "content": "Please type a message to get started.", }) + "\n" yield json.dumps({ "type": "done", - "trace_id": str(uuid.uuid4()), + "trace_id": trace_id, "success": False, + "message": f"Streaming failed. trace_id={trace_id}", }) + "\n" return @@ -1744,6 +1762,7 @@ async def stream_events(): "type": "done", "trace_id": trace_id, "success": False, + "message": f"Streaming failed. trace_id={trace_id}", }) + "\n" return StreamingResponse(stream_events(), media_type="application/x-ndjson") diff --git a/playground/style.css b/playground/style.css index 3b62fe3..9294a58 100644 --- a/playground/style.css +++ b/playground/style.css @@ -1269,6 +1269,77 @@ textarea:focus { margin-top: 0.5rem; } +.transport-status-bar { + display: flex; + justify-content: flex-start; + align-items: center; + margin: -0.75rem 0 1rem; + padding: 0.7rem 0.75rem; + background: #f8fafc; + border: 1px solid #e2e8f0; + border-radius: 0.875rem; +} + +.transport-status-pill { + display: flex; + align-items: center; + gap: 0.55rem; + color: var(--brand-accent); + background: white; + border: 1px solid #e2e8f0; + border-radius: 999px; + padding: 0.5rem 0.85rem; + font-weight: 700; + font-size: 0.78rem; + box-shadow: 0 8px 18px rgba(15, 23, 42, 0.04); +} + +.transport-status-dot { + width: 0.55rem; + height: 0.55rem; + border-radius: 999px; + background: var(--success); + box-shadow: 0 0 0 4px rgba(16, 185, 129, 0.14); +} + +.streaming-bubble p { + white-space: pre-wrap; +} + +.stream-tail-loader { + display: inline-flex; + align-items: center; + gap: 0.45rem; + margin-top: 0.65rem; + color: var(--text-muted); + font-size: 0.78rem; + font-weight: 600; +} + +.stream-end-message { + align-self: flex-start; + max-width: 85%; + font-family: monospace; + font-size: 0.68rem; + padding: 0.35rem 0.55rem; + border-radius: 0.45rem; + border: 1px solid rgba(59, 130, 246, 0.12); + background: rgba(59, 130, 246, 0.06); + color: var(--text-muted); +} + +.stream-end-message.success { + border-color: rgba(16, 185, 129, 0.18); + background: rgba(16, 185, 129, 0.08); + color: #047857; +} + +.stream-end-message.error { + border-color: rgba(244, 63, 94, 0.18); + background: rgba(244, 63, 94, 0.08); + color: var(--error); +} + /* Final Responsive Overrides */ @media (max-width: 1100px) { .playground-layout { diff --git a/sample.env b/sample.env index 8540c6a..acf5042 100644 --- a/sample.env +++ b/sample.env @@ -39,6 +39,10 @@ TOOLHIVE_MCP_URLS= PLAYGROUND_AGENT_PROXY_FALLBACK_TO_STDIO=false # Cap MCP tool JSON size sent back to the LLM (Groq on-demand TPM); default 12000 # TOOLHIVE_MAX_TOOL_RESULT_CHARS=12000 + +# Stream buffering window in milliseconds (default: 0 = no buffering). +# Set to e.g. 2000 for a 2-second batching window on streamed results. +NW_STREAM_BUFFER_MS=0 # Native MCP Transport (for agents.mcp_entrypoint and per-connector MCP servers) # ----------------------------------------------------------------------------- # NW_MCP_TRANSPORT: Selects the communication layer. diff --git a/src/agents/toolhive.py b/src/agents/toolhive.py index 5ce86ad..2c2b33f 100644 --- a/src/agents/toolhive.py +++ b/src/agents/toolhive.py @@ -158,10 +158,9 @@ def _tool_failure_abort_message(tool_name: str, max_failures: int) -> str: def _chunk_agent_text(text: str, chunk_size: int = 180) -> List[str]: - """Split final assistant text into UI-friendly chunks.""" + """Split final assistant text into UI-friendly chunks for stream consumers.""" if not text: return [""] - chunks: List[str] = [] current = "" for part in text.split(" "): @@ -176,6 +175,17 @@ def _chunk_agent_text(text: str, chunk_size: int = 180) -> List[str]: return chunks +def _stream_done_event(trace_id: str, *, success: bool) -> Dict[str, Any]: + from node_wire_runtime.streaming import stream_completion_log + stream_completion_log(trace_id, success, connector_id="agent", action="run_events") + return { + "type": "done", + "trace_id": trace_id, + "success": success, + "message": f"Streaming completed. trace_id={trace_id}", + } + + # --------------------------------------------------------------------------- # Result model # --------------------------------------------------------------------------- @@ -631,17 +641,42 @@ async def run(self, task: str) -> AgentRunResult: result.error = f"Agent reached max_steps ({self._max_steps}) without completing the task." logger.warning(result.error) + from node_wire_runtime.streaming import stream_completion_log + stream_completion_log(trace_id, result.success, connector_id="agent", action="run") return result async def run_events(self, task: str) -> AsyncIterator[Dict[str, Any]]: + trace_id = str(uuid.uuid4()) + from node_wire_runtime.streaming import resolve_stream_buffer_ms, BufferedStreamIterator + + buffer_ms = resolve_stream_buffer_ms() + iterator = self._run_events_inner(task, trace_id) + + if buffer_ms > 0: + async for item in BufferedStreamIterator(iterator, buffer_ms, trace_id, connector_id="agent", action="run_events"): + yield item + else: + async for item in iterator: + yield item + + async def _run_events_inner(self, task: str, trace_id: str) -> AsyncIterator[Dict[str, Any]]: """ + Stream agent progress events for web clients. + + Contract: + - ``meta``: emitted once with ``trace_id``. + - ``status``: informational progress text. + - ``step``: emitted after each MCP tool call completes. + - ``final_chunk``: chunks of the final assistant answer. + - ``error``: recoverable terminal error text. + - ``done``: always emitted at terminal completion; clients should stop + loaders when this event arrives. Stream agent progress events as the ReAct loop runs. The LLM providers currently return complete assistant messages, so final answer chunks begin after the final LLM call completes. Tool-step events are emitted immediately after each MCP tool call completes. """ - trace_id = str(uuid.uuid4()) logger.info("Streaming agent run started | trace_id=%s", trace_id) logger.info("Task: %s", task) @@ -657,7 +692,7 @@ async def run_events(self, task: str) -> AsyncIterator[Dict[str, Any]]: error = f"Failed to list MCP tools: {exc}" logger.error(error) yield {"type": "error", "trace_id": trace_id, "message": error} - yield {"type": "done", "trace_id": trace_id, "success": False} + yield _stream_done_event(trace_id, success=False) return messages: List[LLMMessage] = [ @@ -676,7 +711,7 @@ async def run_events(self, task: str) -> AsyncIterator[Dict[str, Any]]: error = f"LLM error at step {step_num}: {exc}" logger.error(error) yield {"type": "error", "trace_id": trace_id, "message": error} - yield {"type": "done", "trace_id": trace_id, "success": False} + yield _stream_done_event(trace_id, success=False) return messages.append(LLMMessage( @@ -686,10 +721,9 @@ async def run_events(self, task: str) -> AsyncIterator[Dict[str, Any]]: )) if not llm_resp.wants_tool_call: - final_answer = llm_resp.content or "" - for chunk in _chunk_agent_text(final_answer): + for chunk in _chunk_agent_text(llm_resp.content or ""): yield {"type": "final_chunk", "content": chunk} - yield {"type": "done", "trace_id": trace_id, "success": True} + yield _stream_done_event(trace_id, success=True) return abort_message: Optional[str] = None @@ -712,10 +746,9 @@ async def run_events(self, task: str) -> AsyncIterator[Dict[str, Any]]: "result": tool_result_str, } - llm_tool_content = truncate_tool_result_for_llm(tool_result_str) messages.append(LLMMessage( role="tool", - content=llm_tool_content, + content=truncate_tool_result_for_llm(tool_result_str), tool_call_id=tc.id, name=tc.name, )) @@ -730,14 +763,14 @@ async def run_events(self, task: str) -> AsyncIterator[Dict[str, Any]]: if abort_message: for chunk in _chunk_agent_text(abort_message): yield {"type": "final_chunk", "content": chunk} - yield {"type": "done", "trace_id": trace_id, "success": False} + yield _stream_done_event(trace_id, success=False) return error = f"Agent reached max_steps ({self._max_steps}) without completing the task." logger.warning(error) for chunk in _chunk_agent_text(error): yield {"type": "final_chunk", "content": chunk} - yield {"type": "done", "trace_id": trace_id, "success": False} + yield _stream_done_event(trace_id, success=False) # --------------------------------------------------------------------------- diff --git a/src/bindings/mcp_server/server.py b/src/bindings/mcp_server/server.py index 7dd165c..2de01d8 100644 --- a/src/bindings/mcp_server/server.py +++ b/src/bindings/mcp_server/server.py @@ -3,6 +3,7 @@ import json import logging import uuid +from contextvars import ContextVar from typing import Any, Dict, List, Mapping, Optional from bindings.factory import ConnectorFactory @@ -12,9 +13,15 @@ from node_wire_runtime.manifest import MCP_MANIFEST_CONTRACT_VERSION, build_manifest from node_wire_runtime import BaseConnector, ConnectorResponse, ErrorCategory from node_wire_runtime.ingress import enforce_authoritative_action, normalize_mcp_tool_arguments +from node_wire_runtime.streaming import stream_completion_log logger = logging.getLogger("bindings.mcp_server") +_http_request_headers: ContextVar[Mapping[str, str] | None] = ContextVar( + "mcp_http_request_headers", + default=None, +) + class McpServer: """ @@ -89,7 +96,10 @@ def _ensure_identity( ) -> CallerIdentity | None: if identity is not None: return identity - return authenticate_mcp_request(meta=meta) + return authenticate_mcp_request( + headers=_http_request_headers.get(), + meta=meta, + ) def _request_meta_from_context(self) -> Mapping[str, Any] | None: try: @@ -135,13 +145,19 @@ async def invoke_tool( enforce_authoritative_action(run_args, action) run_args["action"] = action - response = await connector.run( - run_args, - principal=identity.principal if identity else None, - tenant_id=identity.tenant_id if identity else None, - scopes=identity.scopes if identity else None, - ) - return response.model_dump() + trace_id = run_args.get("trace_id") or str(uuid.uuid4()) + try: + response = await connector.run( + run_args, + principal=identity.principal if identity else None, + tenant_id=identity.tenant_id if identity else None, + scopes=identity.scopes if identity else None, + ) + stream_completion_log(trace_id, True, connector_id=connector_id, action=action) + return response.model_dump() + except Exception as exc: + stream_completion_log(trace_id, False, connector_id=connector_id, action=action) + raise def _setup_lowlevel_server(self) -> Any: from mcp.server import NotificationOptions, Server as LowLevelServer @@ -268,7 +284,15 @@ def __init__(self, handler): self.handler = handler async def __call__(self, scope, receive, send): - await self.handler(scope, receive, send) + headers = { + key.decode("latin-1"): value.decode("latin-1") + for key, value in scope.get("headers", []) + } + token = _http_request_headers.set(headers) + try: + await self.handler(scope, receive, send) + finally: + _http_request_headers.reset(token) starlette_app = Starlette( lifespan=lifespan, diff --git a/src/node_wire_runtime/__init__.py b/src/node_wire_runtime/__init__.py index 99af6ba..b05fea7 100644 --- a/src/node_wire_runtime/__init__.py +++ b/src/node_wire_runtime/__init__.py @@ -11,6 +11,7 @@ execute_spec_in_thread, navigate_resource, ) +from .streaming import StreamSignal, stream_completion_log, resolve_stream_buffer_ms, BufferedStreamIterator __all__ = [ "ConnectorResponse", @@ -37,4 +38,8 @@ "default_build_kwargs", "execute_spec_in_thread", "navigate_resource", + "StreamSignal", + "stream_completion_log", + "resolve_stream_buffer_ms", + "BufferedStreamIterator", ] diff --git a/src/node_wire_runtime/streaming.py b/src/node_wire_runtime/streaming.py new file mode 100644 index 0000000..83db856 --- /dev/null +++ b/src/node_wire_runtime/streaming.py @@ -0,0 +1,75 @@ +import os +import time +import logging +from enum import Enum +from typing import AsyncIterator, Dict, Any, Optional + +logger = logging.getLogger("runtime.streaming") + +class StreamSignal(str, Enum): + STARTED = "started" + CHUNK = "chunk" + COMPLETED = "completed" + FAILED = "failed" + +def stream_completion_log(trace_id: str, success: bool, *, connector_id: str, action: str) -> None: + status = StreamSignal.COMPLETED.value if success else StreamSignal.FAILED.value + msg = "Stream completed" if success else "Stream failed" + extra = { + "trace_id": trace_id, + "connector_id": connector_id, + "action": action, + "stream_status": status, + } + if success: + logger.info("%s | trace_id=%s | connector_id=%s | action=%s | status=%s", + msg, trace_id, connector_id, action, status, extra=extra) + else: + logger.warning("%s | trace_id=%s | connector_id=%s | action=%s | status=%s", + msg, trace_id, connector_id, action, status, extra=extra) + +def resolve_stream_buffer_ms(override: Optional[int] = None) -> int: + if override is not None: + return max(0, min(int(override), 30000)) + val = os.environ.get("NW_STREAM_BUFFER_MS", "0").strip() + try: + n = int(val) + except ValueError: + n = 0 + return max(0, min(n, 30000)) + +async def BufferedStreamIterator( + iterator: AsyncIterator[Dict[str, Any]], + buffer_ms: int, + trace_id: str, + connector_id: str = "agent", + action: str = "stream" +) -> AsyncIterator[Dict[str, Any]]: + success = True + try: + if buffer_ms <= 0: + async for item in iterator: + yield item + return + + buffer_sec = buffer_ms / 1000.0 + buffer = [] + last_flush = time.monotonic() + + async for item in iterator: + buffer.append(item) + now = time.monotonic() + if now - last_flush >= buffer_sec: + for b_item in buffer: + yield b_item + buffer.clear() + last_flush = now + + for b_item in buffer: + yield b_item + except Exception: + success = False + raise + finally: + # Automatically emit completion log when stream ends + stream_completion_log(trace_id, success, connector_id=connector_id, action=action) diff --git a/tests/test_mcp_transport.py b/tests/test_mcp_transport.py index 57efd89..b083cf5 100644 --- a/tests/test_mcp_transport.py +++ b/tests/test_mcp_transport.py @@ -2,7 +2,7 @@ import anyio import httpx from unittest.mock import MagicMock, patch -from bindings.mcp_server.server import McpServer +from bindings.mcp_server.server import McpServer, _http_request_headers from starlette.applications import Starlette from starlette.routing import Route from mcp.server.streamable_http_manager import StreamableHTTPSessionManager @@ -11,7 +11,15 @@ class _ASGIApp: def __init__(self, handler): self.handler = handler async def __call__(self, scope, receive, send): - await self.handler(scope, receive, send) + headers = { + key.decode("latin-1"): value.decode("latin-1") + for key, value in scope.get("headers", []) + } + token = _http_request_headers.set(headers) + try: + await self.handler(scope, receive, send) + finally: + _http_request_headers.reset(token) @pytest.fixture(autouse=True) def allow_only_standard_connectors(monkeypatch): @@ -117,3 +125,113 @@ async def test_mcp_http_tools_list_success(): assert list_resp.status_code == 200 data = list_resp.json() assert "tools" in data["result"] + + +@pytest.mark.anyio +async def test_mcp_http_tools_list_accepts_authorization_header(monkeypatch): + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + server = McpServer(server_name="test-server", connector_ids=["smtp"]) + low = server._setup_lowlevel_server() + session_manager = StreamableHTTPSessionManager(low, json_response=True) + + starlette_app = Starlette( + routes=[ + Route("/mcp", endpoint=_ASGIApp(session_manager.handle_request), methods=["GET", "POST"]) + ] + ) + + common_headers = { + "Accept": "application/json, text/event-stream", + "Authorization": "Bearer unit-test-secret", + } + + async with session_manager.run(): + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=starlette_app), base_url="http://testserver") as client: + init_resp = await client.post("/mcp", json={ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0"} + } + }, headers=common_headers) + assert init_resp.status_code == 200 + session_id = init_resp.headers.get("Mcp-Session-Id") + + headers = common_headers.copy() + if session_id: + headers["Mcp-Session-Id"] = session_id + + list_resp = await client.post("/mcp", + json={ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": {} + }, + headers=headers + ) + assert list_resp.status_code == 200 + data = list_resp.json() + assert "tools" in data["result"] + assert any(t["name"] == "smtp.send_email" for t in data["result"]["tools"]) + + +@pytest.mark.anyio +async def test_mcp_http_tools_list_accepts_x_api_key_header(monkeypatch): + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + server = McpServer(server_name="test-server", connector_ids=["smtp"]) + low = server._setup_lowlevel_server() + session_manager = StreamableHTTPSessionManager(low, json_response=True) + + starlette_app = Starlette( + routes=[ + Route("/mcp", endpoint=_ASGIApp(session_manager.handle_request), methods=["GET", "POST"]) + ] + ) + + common_headers = { + "Accept": "application/json, text/event-stream", + "X-API-Key": "unit-test-secret", + } + + async with session_manager.run(): + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=starlette_app), base_url="http://testserver") as client: + init_resp = await client.post("/mcp", json={ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0"} + } + }, headers=common_headers) + assert init_resp.status_code == 200 + session_id = init_resp.headers.get("Mcp-Session-Id") + + headers = common_headers.copy() + if session_id: + headers["Mcp-Session-Id"] = session_id + + list_resp = await client.post("/mcp", + json={ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": {} + }, + headers=headers + ) + assert list_resp.status_code == 200 + data = list_resp.json() + assert "tools" in data["result"] + assert any(t["name"] == "smtp.send_email" for t in data["result"]["tools"]) diff --git a/tests/test_scope_policy_transport.py b/tests/test_scope_policy_transport.py index 972c654..a2238cd 100644 --- a/tests/test_scope_policy_transport.py +++ b/tests/test_scope_policy_transport.py @@ -19,7 +19,7 @@ class _Output(BaseModel): class _PolicyTestConnector(BaseConnector): - connector_id = "fhir_epic" + connector_id = "policy_test_fhir_epic" output_model = _Output @nw_action("read_patient") @@ -29,7 +29,7 @@ async def read_patient(self, params: _Input, *, trace_id: str) -> _Output: def _connector_with_scope_map() -> _PolicyTestConnector: return _PolicyTestConnector( - policy_hook=ScopePolicyHook({"fhir_epic.read_patient": "mcp:fhir.read_patient"}) + policy_hook=ScopePolicyHook({"policy_test_fhir_epic.read_patient": "mcp:fhir.read_patient"}) ) diff --git a/tests/test_toolhive_agent.py b/tests/test_toolhive_agent.py index a2628a0..b1346c7 100644 --- a/tests/test_toolhive_agent.py +++ b/tests/test_toolhive_agent.py @@ -226,6 +226,34 @@ async def test_agent_runs_three_tool_sequence() -> None: assert mock_mcp.call_tool.await_count == 3 +@pytest.mark.asyncio +async def test_agent_run_events_emits_done_message_with_trace_id() -> None: + responses = [ + LLMResponse( + content=None, + tool_calls=[_tool_call("fhir_cerner.read_patient", {"resource_id": "12724066"})], + stop_reason="tool_calls", + ), + LLMResponse(content="All done.", tool_calls=[], stop_reason="stop"), + ] + provider = _MockLLMProvider(responses) + + mock_mcp = AsyncMock(spec=ToolHiveMcpClient) + mock_mcp.list_tools.return_value = SAMPLE_TOOLS + mock_mcp.call_tool.return_value = '{"status": "ok"}' + + agent = ToolHiveAgent(mcp_client=mock_mcp, llm_provider=provider, max_steps=5) + events = [event async for event in agent.run_events("Fetch patient 12724066")] + + assert events[0]["type"] == "meta" + assert any(event["type"] == "step" for event in events) + assert any(event["type"] == "final_chunk" for event in events) + assert events[-1]["type"] == "done" + assert events[-1]["success"] is True + assert events[-1]["trace_id"] == events[0]["trace_id"] + assert events[-1]["message"] == f"Streaming completed. trace_id={events[0]['trace_id']}" + + @pytest.mark.asyncio async def test_agent_id_first_turn_calls_read_patient_with_resource_id() -> None: """Document ID-first flow: Cerner read uses canonical resource_id (not search_patients).""" From ddc311bafeee258e99eb55afb2f18445415c4a73 Mon Sep 17 00:00:00 2001 From: kesav-aot Date: Thu, 7 May 2026 03:07:56 -0700 Subject: [PATCH 26/60] feat(slack-connector): add Slack connector with messaging and file upload capabilities (#17) - Implemented Slack connector with actions for posting messages, sending direct messages, and uploading files. - Added Pydantic models for input and output validation. - Integrated error handling with domain-specific exceptions for authentication, permission, rate limits, and upload errors. - Created a README file detailing the connector's operations and error taxonomy. - Developed unit tests to ensure functionality and security, including token leakage prevention in logs. --- .gitignore | 3 +- README.md | 2 + Setup.md | 24 ++ config/connectors.yaml | 180 +++++---- docker/slack/Dockerfile | 34 ++ docs/connectors.md | 2 + docs/mcp-servers.md | 24 ++ docs/slack_connector.md | 154 ++++++++ packages/connectors/slack/pyproject.toml | 22 ++ packages/connectors/slack/setup.py | 16 + playground/app.js | 139 ++++++- playground/index.html | 87 +++- playground/scenarios.py | 116 ++++++ playground/style.css | 6 +- pyproject.toml | 2 +- sample.env | 18 +- scripts/build-mcp-images.sh | 6 + scripts/build-packages.sh | 1 + src/agents/slack_mcp.py | 27 ++ src/node_wire_slack/README.md | 125 ++++++ src/node_wire_slack/__init__.py | 1 + src/node_wire_slack/exceptions.py | 27 ++ src/node_wire_slack/logic.py | 482 +++++++++++++++++++++++ src/node_wire_slack/registration.py | 33 ++ src/node_wire_slack/schema.py | 118 ++++++ tests/test_slack_connector.py | 411 +++++++++++++++++++ 26 files changed, 1954 insertions(+), 106 deletions(-) create mode 100644 docker/slack/Dockerfile create mode 100644 docs/slack_connector.md create mode 100644 packages/connectors/slack/pyproject.toml create mode 100644 packages/connectors/slack/setup.py create mode 100644 src/agents/slack_mcp.py create mode 100644 src/node_wire_slack/README.md create mode 100644 src/node_wire_slack/__init__.py create mode 100644 src/node_wire_slack/exceptions.py create mode 100644 src/node_wire_slack/logic.py create mode 100644 src/node_wire_slack/registration.py create mode 100644 src/node_wire_slack/schema.py create mode 100644 tests/test_slack_connector.py diff --git a/.gitignore b/.gitignore index f2f8975..da7f5fd 100644 --- a/.gitignore +++ b/.gitignore @@ -11,11 +11,12 @@ dist/ .coverage .coverage.* htmlcov/ - +*.gitattributes # GCP / cloud credentials connectorplatform-*.json *-service-account.json *credentials*.json +service_account.json # Grafana exports (auto-generated) grafana/*.json diff --git a/README.md b/README.md index c93c0b3..bced5f8 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ Each connector can run as its own independent MCP server (Docker image). | `nw-smartonfhir-epic` | All `fhir_epic.` (e.g. `fhir_epic.read_patient`) | `docker/fhir-epic/Dockerfile` | | `nw-smartonfhir-cerner` | All `fhir_cerner.` (e.g. `fhir_cerner.read_patient`) | `docker/fhir-cerner/Dockerfile` | | `nw-smtp` | `smtp.send_email` | `docker/smtp/Dockerfile` | +| `nw-slack` | All `slack.` (e.g. `slack.post_message`) | `docker/slack/Dockerfile` | See [docs/mcp-servers.md](docs/mcp-servers.md) for build, env config, docker-compose, and ToolHive registration. @@ -95,6 +96,7 @@ The platform is split into three layers: | **google_drive**| Google Drive (list, create, get, update, upload, delete, permissions) | `execute` (payload discriminator) | rest, grpc, mcp | | **fhir_epic** | FHIR R4 integration for Epic (multi-action) | `read_patient`, `search_patients`, `search_encounter`, `create_document_reference`, `search_document_reference` | rest, grpc, mcp | | **fhir_cerner** | FHIR R4 integration for Cerner (multi-action) | `read_patient`, `search_patients`, `search_encounter`, `create_document_reference`, `search_document_reference` | rest, grpc, mcp | +| **slack** | Slack (messaging, files) | `post_message`, `send_direct_message`, `upload_file` | rest, mcp | ### Connector-specific documentation diff --git a/Setup.md b/Setup.md index c6fdd5c..647509a 100644 --- a/Setup.md +++ b/Setup.md @@ -83,6 +83,7 @@ You only need to fill in the sections for the connectors you plan to use. The pl | **FHIR Cerner** | `CERNER_FHIR_BASE_URL`, `CERNER_TOKEN_URL`, `CERNER_CLIENT_ID`, `CERNER_KID`, `CERNER_PRIVATE_KEY`, `CERNER_SCOPES` | Cerner EHR integration | | **Google Drive** | `GOOGLE_DRIVE_SA_JSON`, `GOOGLE_DRIVE_FOLDER_ID` | Google Drive connector | | **SMTP** | `SMTP_HOST`, `SMTP_PORT`, `SMTP_USERNAME`, `SMTP_PASSWORD` | Sending emails | +| **Slack** | `SLACK_BOT_TOKEN` | Sending Slack messages | | **LLM / Agent** | `LLM_PROVIDER`, `GROQ_API_KEY` (or other provider key) | AI agent / ToolHive | | **ToolHive / MCP**| `TOOLHIVE_MCP_URLS` (multi-server), `NW_MCP_TRANSPORT`, `NW_MCP_PORT`, `NW_STREAM_BUFFER_MS` | AI agent / ToolHive | @@ -175,6 +176,7 @@ All responses use the same standard shape: | **google_drive** | List, upload, download, manage Drive files | GCP service account JSON | [Google Drive setup & API](docs/google_drive_connector.md#google-drive-service-account-setup) | | **fhir_epic** | Read/write patient data from Epic EHR | Epic SMART credentials + private key | [FHIR Epic Setup](#fhir-epic) | | **fhir_cerner** | Read/write patient data from Cerner EHR | Cerner SMART credentials + private key | [FHIR Cerner Setup](#fhir-cerner) | +| **slack** | Send messages and files to Slack channels | Slack Bot OAuth Token | [Slack Setup](#slack) | --- @@ -365,6 +367,27 @@ npx @modelcontextprotocol/inspector In Inspector, choose `Streamable HTTP`, enter `http://127.0.0.1:8081/mcp`, connect, then use `Tools -> List Tools` and run a safe tool call with valid JSON arguments. --- +### Slack + +Add the bot token to your `.env`: + +```env +SLACK_BOT_TOKEN=xoxb-your-slack-bot-token +``` + +1. Create a Slack App at [api.slack.com/apps](https://api.slack.com/apps). +2. Go to **OAuth & Permissions** and add **Bot Token Scopes**: + - `chat:write` (to post messages) + - `files:write` (to upload files) + - `im:write` (to send direct messages) + - `conversations:open` (to resolve User IDs) +3. Install the app to your workspace. +4. Copy the **Bot User OAuth Token** (`xoxb-...`). +5. Invite the bot to any private channels you want it to access. + +--- + +## MCP Server & ToolHive The platform exposes connector tools for AI agents via the MCP (Model Context Protocol). There are two deployment modes: @@ -379,6 +402,7 @@ Each connector runs as its own independent MCP server. This is the preferred app | `nw-smartonfhir-epic` | All `fhir_epic.` (e.g. `fhir_epic.read_patient`) | `docker/fhir-epic/Dockerfile` | | `nw-smartonfhir-cerner` | All `fhir_cerner.` (e.g. `fhir_cerner.read_patient`) | `docker/fhir-cerner/Dockerfile` | | `nw-smtp` | `smtp.send_email` | `docker/smtp/Dockerfile` | +| `nw-slack` | All `slack.` (e.g. `slack.post_message`) | `docker/slack/Dockerfile` | **Full guide (build, env config, ToolHive registration, multi-server agent usage):** [docs/mcp-servers.md](docs/mcp-servers.md) diff --git a/config/connectors.yaml b/config/connectors.yaml index 7e5940f..dca2771 100644 --- a/config/connectors.yaml +++ b/config/connectors.yaml @@ -1,87 +1,93 @@ -# connectors.yaml — Node Wire connector configuration -# -# REST API auth (not stored here; set in environment): -# NW_REST_API_KEY — required for /connectors, /playground, /scenarios unless NW_REST_AUTH_DISABLED=true -# -# SECURITY RULE: This file must never contain secrets. -# - Non-sensitive config (base_url, host, port) → safe in YAML -# - Secrets (client_id, private_key, api_key) → environment variables (or cloud backend) -# -connectors: - http_generic: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - # auth: not set — defaults to NoAuthProvider - - smtp: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - host: "smtp.example.com" - port: 587 - from_email: "noreply@example.com" - auth: - provider: static_credentials - username_secret: SMTP_USERNAME - password_secret: SMTP_PASSWORD - - stripe: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - auth: - provider: static_token - secret_key: stripe_api_key - header_name: Authorization - prefix: "" # Stripe expects raw key - - google_drive: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - auth: - provider: service_account - sa_json_secret: GOOGLE_DRIVE_SA_JSON - scopes: - - https://www.googleapis.com/auth/drive - - fhir_epic: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - base_url: "${EPIC_FHIR_BASE_URL:https://fhir.epic.sandbox/api/FHIR/R4}" - auth: - provider: oauth2 - grant_method: private_key_jwt - token_url_secret: EPIC_TOKEN_URL - client_id_secret: EPIC_CLIENT_ID - private_key_secret: EPIC_PRIVATE_KEY - kid_secret: EPIC_KID - algorithm: RS384 - - fhir_cerner: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - base_url: "${CERNER_FHIR_BASE_URL:https://fhir-ehr-code.cerner.com/r4/your-tenant-id}" - auth: - provider: oauth2 - grant_method: private_key_jwt - token_url_secret: CERNER_TOKEN_URL - client_id_secret: CERNER_CLIENT_ID - private_key_secret: CERNER_PRIVATE_KEY - kid_secret: CERNER_KID - algorithm: RS384 - scopes_secret: CERNER_SCOPES - scopes: - - system/Patient.read - - system/Encounter.read - - system/DocumentReference.read - - system/DocumentReference.write - - salesforce: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - # instance_url is typically https://yourdomain.my.salesforce.com - auth: - provider: oauth2 - grant_method: refresh_token - token_url_secret: SALESFORCE_TOKEN_URL - client_id_secret: SALESFORCE_CLIENT_ID - client_secret_secret: SALESFORCE_CLIENT_SECRET - refresh_token_secret: SALESFORCE_REFRESH_TOKEN \ No newline at end of file +# connectors.yaml — Node Wire connector configuration +# +# REST API auth (not stored here; set in environment): +# NW_REST_API_KEY — required for /connectors, /playground, /scenarios unless NW_REST_AUTH_DISABLED=true +# +# SECURITY RULE: This file must never contain secrets. +# - Non-sensitive config (base_url, host, port) → safe in YAML +# - Secrets (client_id, private_key, api_key) → environment variables (or cloud backend) +# +connectors: + http_generic: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + # auth: not set — defaults to NoAuthProvider + + smtp: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + host: "smtp.example.com" + port: 587 + from_email: "noreply@example.com" + auth: + provider: static_credentials + username_secret: SMTP_USERNAME + password_secret: SMTP_PASSWORD + + stripe: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] # keeping REST unless you explicitly want to remove it + auth: + provider: static_token + secret_key: stripe_api_key + header_name: Authorization + prefix: "" # Stripe expects raw key + + google_drive: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + auth: + provider: service_account + sa_json_secret: GOOGLE_DRIVE_SA_JSON + scopes: + - https://www.googleapis.com/auth/drive + + fhir_epic: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + base_url: "${EPIC_FHIR_BASE_URL:https://fhir.epic.sandbox/api/FHIR/R4}" + auth: + provider: oauth2 + grant_method: private_key_jwt + token_url_secret: EPIC_TOKEN_URL + client_id_secret: EPIC_CLIENT_ID + private_key_secret: EPIC_PRIVATE_KEY + kid_secret: EPIC_KID + algorithm: RS384 + + fhir_cerner: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + base_url: "${CERNER_FHIR_BASE_URL:https://fhir-ehr-code.cerner.com/r4/your-tenant-id}" + auth: + provider: oauth2 + grant_method: private_key_jwt + token_url_secret: CERNER_TOKEN_URL + client_id_secret: CERNER_CLIENT_ID + private_key_secret: CERNER_PRIVATE_KEY + kid_secret: CERNER_KID + algorithm: RS384 + scopes_secret: CERNER_SCOPES + scopes: + - system/Patient.read + - system/Encounter.read + - system/DocumentReference.read + - system/DocumentReference.write + + salesforce: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + auth: + provider: oauth2 + grant_method: refresh_token + token_url_secret: SALESFORCE_TOKEN_URL + client_id_secret: SALESFORCE_CLIENT_ID + client_secret_secret: SALESFORCE_CLIENT_SECRET + refresh_token_secret: SALESFORCE_REFRESH_TOKEN + + slack: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + auth: + provider: static_token + secret_key: SLACK_BOT_TOKEN \ No newline at end of file diff --git a/docker/slack/Dockerfile b/docker/slack/Dockerfile new file mode 100644 index 0000000..8b4b3c4 --- /dev/null +++ b/docker/slack/Dockerfile @@ -0,0 +1,34 @@ +FROM python:3.12-slim@sha256:3d5ed973e45820f5ba5e46bd065bd88b3a504ff0724d85980dcd05eab361fcf4 + +LABEL org.opencontainers.image.title="nw-slack" \ + org.opencontainers.image.description="Node Wire — Slack MCP server" \ + org.opencontainers.image.source="https://github.com/AOT-Technologies/node-wire" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY src/ ./src/ +COPY config/ ./config/ +COPY packages/runtime/dist/*.whl /wheels/ +COPY packages/connectors/slack/dist/*.whl /wheels/ + +ENV PYTHONPATH=/app/src \ + NW_ALLOWED_CONNECTORS=slack + +RUN pip install --no-cache-dir --find-links=/wheels \ + node-wire-runtime node-wire-slack "mcp>=1.6.0" \ + && rm -rf /wheels + +RUN groupadd --system --gid 1000 app \ + && useradd --system --uid 1000 --gid app --home /app app \ + && chown -R app:app /app + +USER app + +HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ + python -c "from agents.slack_mcp import main; assert callable(main); print('ok')" || exit 1 + +CMD ["python", "-m", "agents.slack_mcp"] diff --git a/docs/connectors.md b/docs/connectors.md index f14bcf7..05c3176 100644 --- a/docs/connectors.md +++ b/docs/connectors.md @@ -476,6 +476,7 @@ Published **`input_schema` omits the `action` property** (manifest contract v2+) | `fhir_epic` | `read_patient`, `search_patients`, `search_encounter`, `create_document_reference`, `search_document_reference` | | `fhir_cerner` | Same family as Epic with Cerner-specific schemas | +| `slack` | `post_message`, `send_direct_message`, `upload_file` | MCP tool names: **`.`** (e.g. `fhir_epic.read_patient`). See [`docs/mcp-servers.md`](mcp-servers.md). @@ -533,5 +534,6 @@ connectors: - [mcp-servers.md](mcp-servers.md) — MCP images, ToolHive, env vars. - [google_drive_connector.md](google_drive_connector.md) — Drive REST API and setup. - [salesforce_connector.md](salesforce_connector.md) — Salesforce CRM operations and playground. +- [slack_connector.md](slack_connector.md) — Slack bot token and setup. - Per-connector READMEs under `src/node_wire_*/README.md` where present. diff --git a/docs/mcp-servers.md b/docs/mcp-servers.md index 36b731a..b635e12 100644 --- a/docs/mcp-servers.md +++ b/docs/mcp-servers.md @@ -32,11 +32,13 @@ flowchart TD Epic[nw-smartonfhir-epic] Cerner[nw-smartonfhir-cerner] SMTP[nw-smtp] + Slack[nw-slack] end Agent -->|"TOOLHIVE_MCP_URLS"| GDrive Agent -->|"TOOLHIVE_MCP_URLS"| Epic Agent -->|"TOOLHIVE_MCP_URLS"| Cerner Agent -->|"TOOLHIVE_MCP_URLS"| SMTP + Agent -->|"TOOLHIVE_MCP_URLS"| Slack ``` --- @@ -51,6 +53,7 @@ flowchart TD | SMTP | `python -m agents.smtp_mcp` | `nw-smtp` | `nw-smtp` | `smtp.send_email` | | Stripe | `python -m agents.stripe_mcp` | `nw-stripe` | `nw-stripe` | All manifest actions for `stripe` (e.g., `stripe.charge`) | | Salesforce | `python -m agents.salesforce_mcp` | `nw-salesforce` | `nw-salesforce` | All manifest actions for `salesforce` (e.g., `salesforce.create_lead`) | +| Slack | `python -m agents.slack_mcp` | `nw-slack` | `nw-slack` | All manifest actions for `slack` (e.g. `slack.post_message`) | The unified server (`python -m agents.mcp_entrypoint`) exposes **every** connector enabled for MCP in `config/connectors.yaml` (e.g. `http_generic.request`, `stripe.charge`, `stripe.create_payment_intent`, `stripe.create_subscription`, `stripe.cancel_subscription`, `stripe.issue_refund`, plus the rows above). @@ -328,6 +331,18 @@ SALESFORCE_REFRESH_TOKEN=your-refresh-token ``` +#### `nw-slack` + +| Variable | Description | +|---|---| +| `SLACK_BOT_TOKEN` | Slack Bot User OAuth Token (`xoxb-...`) | +| `NW_SLACK_ATTACHMENTS_DIR` | Optional: sandboxed directory for uploads (default: `/slack_attachments`) | + +```env +SLACK_BOT_TOKEN=xoxb-your-bot-token +NW_SLACK_ATTACHMENTS_DIR=/slack_attachments +``` + ### ToolHive / Agent settings | Variable | Description | @@ -377,6 +392,7 @@ This produces images tagged as both `latest` and the version string: | `nw-smartonfhir-epic` | `nw-smartonfhir-epic:latest`, `nw-smartonfhir-epic:0.1.0` | | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner:latest`, `nw-smartonfhir-cerner:0.1.0` | | `nw-smtp` | `nw-smtp:latest`, `nw-smtp:0.1.0` | +| `nw-slack` | `nw-slack:latest`, `nw-slack:0.1.0` | To build a single image manually from the repo root: @@ -392,6 +408,9 @@ docker build -f docker/fhir-cerner/Dockerfile -t nw-smartonfhir-cerner:latest . # SMTP only docker build -f docker/smtp/Dockerfile -t nw-smtp:latest . + +# Slack only +docker build -f docker/slack/Dockerfile -t nw-slack:latest . ``` > **Note:** The build context must be the repository root (`.`) so the `COPY src/` and `COPY config/` instructions resolve correctly. @@ -454,6 +473,11 @@ thv run --name nw-smtp --transport stdio \ --secret SMTP_PASSWORD,target=SMTP_PASSWORD \ --secret FROM_EMAIL,target=FROM_EMAIL \ nw-smtp:latest + +# Slack +thv run --name nw-slack --transport stdio \ + --secret SLACK_BOT_TOKEN,target=SLACK_BOT_TOKEN \ + nw-slack:latest ``` > **Google Drive + ToolHive:** Set `GOOGLE_DRIVE_SA_JSON` to the JSON *contents* (not a file path) when storing in ToolHive secrets, because ToolHive injects secrets as string values. diff --git a/docs/slack_connector.md b/docs/slack_connector.md new file mode 100644 index 0000000..d5ac828 --- /dev/null +++ b/docs/slack_connector.md @@ -0,0 +1,154 @@ +# Slack Connector + +This document covers the Slack connector under `src/node_wire_slack` in two parts: + +1. **[Slack Bot Setup](#slack-bot-setup)** — Create a Slack app, configure OAuth scopes, and obtain your bot token. +2. **[REST API Reference](#rest-api-reference)** — Connector actions, request/response shapes, and flexible channel resolution. + +For **MCP** (e.g. ToolHive), tools are named `slack.` from the connector manifest (e.g. `slack.post_message`). + +--- + +## Slack Bot Setup + +The Slack connector uses a **Bot User OAuth Token** to interact with your workspace. + +### Prerequisites + +- A Slack workspace where you have permission to install apps. +- [Slack API Dashboard](https://api.slack.com/apps) access. + +### Step 1: Create a Slack App + +1. Go to [api.slack.com/apps](https://api.slack.com/apps) and click **Create New App**. +2. Select **From scratch**. +3. Give your app a name (e.g., `Node-Wire Connector`) and select your workspace. +4. Click **Create App**. + +### Step 2: Configure Scopes + +1. In the left sidebar, go to **OAuth & Permissions**. +2. Scroll down to **Scopes > Bot Token Scopes**. +3. Add the following scopes: + - `chat:write` — Send messages to channels and DMs. + - `files:write` — Upload and share files. + - `im:write` — Start direct messages with users. + - `conversations:open` — Resolve User IDs to DM channel IDs. + - `groups:read` (optional) — If you need to post to private channels the bot is invited to. + - `channels:read` (optional) — If you need to resolve channel names. + +### Step 3: Install and Get Token + +1. Scroll back up to the top of the **OAuth & Permissions** page. +2. Click **Install to Workspace**. +3. Click **Allow** to authorize the bot. +4. Copy the **Bot User OAuth Token** (it starts with `xoxb-`). + +### Step 4: Configure the Connector + +Add the token to your `.env` file: + +```env +SLACK_BOT_TOKEN=xoxb-your-token-here +``` + +### Step 5: Invite the Bot (Important) + +Slack bots cannot "see" private channels unless they are explicitly invited. + +1. Go to the Slack channel you want the bot to use. +2. Type `/invite @YourAppName` and press Enter. + +--- + +## REST API Reference + +The connector exposes actions as standard REST endpoints. Channel identifiers are flexible and automatically resolved. + +### Operations overview + +- Connector ID: `slack` +- Base REST path: `POST /connectors/slack/{action}` + +### Actions + +#### `post_message` + +Send a message to a channel, group, or user. + +**Request body:** + +```json +{ + "channel": "#general", + "message": "Clinical alert: Patient summary available.", + "blocks": [ + { + "type": "section", + "text": { "type": "mrkdwn", "text": "*Emergency Update*: BP 180/110" } + } + ] +} +``` + +**Channel Resolution:** +- **Channel Name**: Starts with `#` (e.g., `#general`). +- **Channel ID**: Starts with `C` or `G` (e.g., `C12345`). +- **User ID**: Starts with `U` or `W` (e.g., `U12345`). Automatically resolved to a DM channel. + +#### `send_direct_message` + +A specialized action for DMs. If targeted at a User ID, the connector ensures the DM channel is open before posting. + +**Request body:** + +```json +{ + "channel": "U12345678", + "message": "You have a new lab result to review." +} +``` + +#### `upload_file` + +Uploads a file to a Slack channel or DM. + +**Request body (Base64):** + +```json +{ + "channel": "C12345678", + "filename": "labs.pdf", + "content_base64": "JVBER...", + "initial_comment": "Here is the PDF summary." +} +``` + +**Request body (Filesystem):** + +```json +{ + "channel": "U12345678", + "filename": "summary.pdf", + "filepath": "/slack_attachments/p_123.pdf" +} +``` + +> **Note:** `filepath` must be within the directory defined by `NW_SLACK_ATTACHMENTS_DIR` (default `/slack_attachments`). + +### Error Taxonomy + +| Category | Platform Code | Cause | +|---|---|---| +| `AUTH` | `SLACK_AUTH_ERROR` | Invalid or revoked token | +| `AUTH` | `SLACK_PERMISSION_ERROR` | Missing OAuth scope | +| `RETRYABLE` | `SLACK_RATE_LIMIT` | Slack rate limit (429) | +| `BUSINESS` | `SLACK_MESSAGE_ERROR` | Channel not found or invalid payload | +| `BUSINESS` | `SLACK_UPLOAD_ERROR` | File too large or bad content | + +--- + +### Related + +- Individual MCP Servers: [docs/mcp-servers.md](mcp-servers.md) +- Connector Architecture: [docs/connectors.md](connectors.md) diff --git a/packages/connectors/slack/pyproject.toml b/packages/connectors/slack/pyproject.toml new file mode 100644 index 0000000..199859e --- /dev/null +++ b/packages/connectors/slack/pyproject.toml @@ -0,0 +1,22 @@ +[project] +name = "node-wire-slack" +version = "0.1.0" +description = "Node Wire connector — Slack API (messaging and file uploads)" +requires-python = ">=3.11" +authors = [{ name = "AOT", email = "dev@aot.local" }] + +dependencies = [ + "node-wire-runtime>=0.1.0", + "httpx>=0.27.0", +] + +[project.entry-points."node_wire.connectors"] +slack = "node_wire_slack.logic" + +[build-system] +requires = ["setuptools>=69.0.0", "cython>=3.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["../../../src"] +include = ["node_wire_slack*"] diff --git a/packages/connectors/slack/setup.py b/packages/connectors/slack/setup.py new file mode 100644 index 0000000..801faaa --- /dev/null +++ b/packages/connectors/slack/setup.py @@ -0,0 +1,16 @@ +import glob, os +from Cython.Build import cythonize +from setuptools import setup +from setuptools.command.build_py import build_py as _BuildPy + +class NoPyBuild(_BuildPy): + def find_package_modules(self, package, package_dir): + return [] + +src_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../src/node_wire_slack")) +py_files = glob.glob(os.path.join(src_root, "**", "*.py"), recursive=True) + +setup( + cmdclass={"build_py": NoPyBuild}, + ext_modules=cythonize(py_files, compiler_directives={"language_level": "3"}, build_dir="build"), +) diff --git a/playground/app.js b/playground/app.js index f473a86..c75ac0f 100644 --- a/playground/app.js +++ b/playground/app.js @@ -45,6 +45,19 @@ document.addEventListener('DOMContentLoaded', () => { const gdriveUploadOnly = document.getElementById('gdrive-upload-only'); const gdriveListOnly = document.getElementById('gdrive-list-only'); const gdriveSubNav = document.getElementById('gdrive-sub-nav'); + const slackForm = document.getElementById('slack-form'); + const slackRunBtn = document.getElementById('slack-run-btn'); + const slackSpinner = slackRunBtn?.querySelector('.loading-spinner'); + const slackBtnText = slackRunBtn?.querySelector('.btn-lbl'); + const slackPanel = document.getElementById('slack-panel'); + const slackActionSelect = document.getElementById('slack-action-select'); + const slackMessageSection = document.getElementById('slack-message-section'); + const slackFileSection = document.getElementById('slack-file-section'); + const slackFileInput = document.getElementById('slack-file'); + const slackFileDropZone = document.getElementById('slack-file-drop-zone'); + const slackFileChosenPreview = document.getElementById('slack-file-chosen-preview'); + const slackPreviewName = slackFileChosenPreview?.querySelector('.preview-name'); + const slackRemoveFileBtn = slackFileChosenPreview?.querySelector('.remove-file-btn'); const fileDropZone = document.getElementById('file-drop-zone'); const fileChosenPreview = document.getElementById('file-chosen-preview'); const previewName = fileChosenPreview?.querySelector('.preview-name'); @@ -142,25 +155,31 @@ document.addEventListener('DOMContentLoaded', () => { "Verify file metadata", "Complete update" ], + slack: [ + "Format Slack Payload", + "Dispatch to Slack API", + "Verify Acknowledgment", + "Update Audit Trail", + ], stripe_charge: [ "Initialize Payment", "Process Charge", - "Verify Transaction" + "Verify Transaction", ], stripe_payment_intent: [ "Initialize Session", "Create Payment Intent", - "Verify Allocation" + "Verify Allocation", ], stripe_subscription: [ "Validate Customer", "Create Subscription", - "Verify Provisioning" + "Verify Provisioning", ], stripe_cancel_subscription: [ "Locate Resource", "Cancel Subscription", - "Verify Termination" + "Verify Termination", ], stripe_refund: [ "Validate Charge", @@ -191,9 +210,7 @@ document.addEventListener('DOMContentLoaded', () => { "Authenticate CRM", "Execute Soft Delete", "Verify Termination" - ] - - + ], }; const nodes = [ @@ -493,6 +510,7 @@ document.addEventListener('DOMContentLoaded', () => { gdrivePanel.classList.add('hidden'); stripePanel.classList.add('hidden'); salesforcePanel.classList.add('hidden'); + if (slackPanel) slackPanel.classList.add('hidden'); if (mode === 'ehr') { @@ -531,6 +549,12 @@ document.addEventListener('DOMContentLoaded', () => { tagline.textContent = 'CRM Orchestration'; document.documentElement.style.setProperty('--brand-accent', '#00A1E0'); log('Switched to Salesforce CRM Orchestration mode', 'system'); + } else if (mode === 'slack') { + if (slackPanel) slackPanel.classList.remove('hidden'); + connectorStatus.textContent = 'Slack Online'; + tagline.textContent = 'Team Collaboration & Notifications'; + document.documentElement.style.setProperty('--brand-accent', '#4A154B'); + log('Switched to Slack Operations mode', 'system'); } if (mode === 'gdrive') { syncGdriveActionForm(); @@ -1107,6 +1131,107 @@ document.addEventListener('DOMContentLoaded', () => { } }); + if (slackActionSelect) { + slackActionSelect.addEventListener('change', () => { + const action = slackActionSelect.value; + if (action === 'upload_file') { + if (slackMessageSection) slackMessageSection.classList.add('hidden'); + if (slackFileSection) slackFileSection.classList.remove('hidden'); + } else { + if (slackMessageSection) slackMessageSection.classList.remove('hidden'); + if (slackFileSection) slackFileSection.classList.add('hidden'); + } + }); + } + + if (slackFileInput && slackFileChosenPreview && slackPreviewName && slackFileDropZone) { + slackFileInput.addEventListener('change', () => { + if (slackFileInput.files.length > 0) { + const fileName = slackFileInput.files[0].name; + slackPreviewName.textContent = fileName; + slackFileChosenPreview.classList.remove('hidden'); + slackFileDropZone.classList.add('hidden'); + } + }); + } + + if (slackRemoveFileBtn && slackFileInput && slackFileChosenPreview && slackFileDropZone) { + slackRemoveFileBtn.addEventListener('click', (e) => { + e.stopPropagation(); + slackFileInput.value = ''; + slackFileChosenPreview.classList.add('hidden'); + slackFileDropZone.classList.remove('hidden'); + }); + } + + if (slackFileDropZone) { + slackFileDropZone.addEventListener('dragover', (e) => { + e.preventDefault(); + slackFileDropZone.style.borderColor = 'var(--brand-accent)'; + slackFileDropZone.style.background = 'rgba(255, 255, 255, 0.08)'; + }); + + slackFileDropZone.addEventListener('dragleave', () => { + slackFileDropZone.style.borderColor = ''; + slackFileDropZone.style.background = ''; + }); + + slackFileDropZone.addEventListener('drop', (e) => { + e.preventDefault(); + slackFileDropZone.style.borderColor = ''; + slackFileDropZone.style.background = ''; + if (slackFileInput && e.dataTransfer.files.length > 0) { + slackFileInput.files = e.dataTransfer.files; + slackFileInput.dispatchEvent(new Event('change')); + } + }); + } + + if (slackForm) { + slackForm.addEventListener('submit', async (e) => { + e.preventDefault(); + const formData = new FormData(slackForm); + const payload = Object.fromEntries(formData.entries()); + + if (payload.action === 'upload_file' && slackFileInput && slackFileInput.files.length > 0) { + const file = slackFileInput.files[0]; + const reader = new FileReader(); + + resetUI(); + if (slackRunBtn) slackRunBtn.disabled = true; + if (slackSpinner) slackSpinner.classList.remove('hidden'); + if (slackBtnText) slackBtnText.textContent = 'Formatting payload...'; + + reader.onload = async (event) => { + try { + const base64Data = event.target.result.split(',')[1]; + payload.content_base64 = base64Data; + // Always override filename with actual file name if uploaded directly + payload.filename = file.name; + + await handleSubmission(payload, '/scenarios/slack-messaging', slackRunBtn, slackBtnText, slackSpinner, 'Send to Slack'); + } catch (error) { + log(`File parsing error: ${error.message}`, 'error'); + if (slackBtnText) slackBtnText.textContent = 'System Error'; + if (slackRunBtn) slackRunBtn.disabled = false; + if (slackSpinner) slackSpinner.classList.add('hidden'); + } + }; + + reader.onerror = () => { + log('Failed to read binary file from memory.', 'error'); + if (slackBtnText) slackBtnText.textContent = 'System Error'; + if (slackRunBtn) slackRunBtn.disabled = false; + if (slackSpinner) slackSpinner.classList.add('hidden'); + }; + + reader.readAsDataURL(file); + } else { + await handleSubmission(payload, '/scenarios/slack-messaging', slackRunBtn, slackBtnText, slackSpinner, 'Send to Slack'); + } + }); + } + // ====================================================== // AI Agent Chat Logic // ====================================================== diff --git a/playground/index.html b/playground/index.html index 3d39c0f..65050ae 100644 --- a/playground/index.html +++ b/playground/index.html @@ -206,6 +206,18 @@

Google Drive

+
+
+ + + +
+
+

Slack

+

Intelligent Team Notifications & File Uploads.

+
+
+
@@ -231,7 +243,6 @@

Salesforce

Lead and contact management for CRM-driven enterprise workflows.

- @@ -535,19 +546,85 @@ + - +
@@ -539,7 +539,7 @@
- + + + diff --git a/playground/scenarios.py b/playground/scenarios.py index 9a80d03..883c7d3 100644 --- a/playground/scenarios.py +++ b/playground/scenarios.py @@ -57,11 +57,13 @@ SlackSendDirectMessageInput, SlackUploadFileInput, ) +from .ext_patient_viewer.schema import ExternalPatientViewerInput load_dotenv() ErrorMapper.register(ValidationError, ErrorCategory.BUSINESS, code="UNSUPPORTED_OPERATION") +ErrorMapper.register(ValueError, ErrorCategory.BUSINESS, code="VALIDATION_ERROR") load_dotenv() @@ -413,9 +415,13 @@ def add_step( patient_id = payload.patient_id else: patient_search_params = { - "family": payload.patient_family, - "given": payload.patient_given, - "birthdate": payload.patient_birthdate, + k: v + for k, v in { + "family": payload.patient_family, + "given": payload.patient_given, + "birthdate": payload.patient_birthdate, + }.items() + if v is not None } logger.info(f"Searching for patient: {patient_search_params}") p_res = await execute_with_retry( @@ -1361,7 +1367,7 @@ def add_step( add_step("Drive List", "pending", display_name="List Drive Files") try: raw_ps = payload.list_page_size - page_size = 10 if raw_ps is None else int(raw_ps) + page_size = 10 if raw_ps is None else raw_ps page_size = max(1, min(100, page_size)) q = (payload.list_query or "").strip() or None fields = (payload.list_fields or "").strip() or None @@ -2399,3 +2405,371 @@ def add_step(name, status, display_name): ) except Exception as e: return _safe_error_return(e, steps, trace_id, "Delete failed") + + +# --------------------------------------------------------------------------- +# External Patient Viewer — Read-Only Retrieval +# --------------------------------------------------------------------------- + + +def _get_viewer_connector(source_system: str) -> Any: + """Return the correct FHIR connector based on source_system string.""" + if source_system.lower() == "cerner": + return get_cerner_connector() + return get_fhir_connector() + + +@router.post("/external-patient-viewer", response_model=ScenarioResponse) +async def external_patient_viewer_scenario( + payload: ExternalPatientViewerInput, +) -> ScenarioResponse: + """ + 4-step read-only workflow: resolve patient identity, retrieve demographics, + retrieve encounter history, retrieve document metadata. + + No FHIR resource is created or mutated during this workflow. + Encounter-as-document fallback is applied when document_references are absent. + """ + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + is_epic = payload.source_system.lower() != "cerner" + + def add_step( + name: str, status: str, details: str = "", display_name: str = "", data: Any = None + ): + steps.append( + ScenarioStep( + name=name, status=status, details=details, display_name=display_name, data=data + ) + ) + + connector = _get_viewer_connector(payload.source_system) + system_label = "Epic FHIR R4" if is_epic else "Cerner FHIR R4" + + # ── STEP 1: Patient Resolution ────────────────────────────────────────── + add_step("Patient Resolution", "pending", display_name="Resolve Patient Identity") + try: + if payload.patient_id: + logger.info( + "[ExtViewer] Direct Patient ID lookup: %s on %s", + payload.patient_id, + system_label, + extra={"trace_id": trace_id}, + ) + if is_epic: + p_res = await execute_with_retry( + connector, + FhirPatientReadInput(resource_id=payload.patient_id), + trace_id, + steps[-1], + ) + else: + p_res = await execute_with_retry( + connector, + FhirCernerPatientReadInput(resource_id=payload.patient_id), + trace_id, + steps[-1], + ) + patient_id = payload.patient_id + patient_resource = p_res.resource or {} + else: + # Identity-layer search: resolve via name + birthdate + if not (payload.patient_family or payload.patient_given): + raise ValueError( + "Provide either patient_id or at least one name field (given/family) " + "to resolve patient identity." + ) + search_params = { + k: v + for k, v in { + "family": payload.patient_family, + "given": payload.patient_given, + "birthdate": payload.patient_birthdate, + }.items() + if v + } + logger.info( + "[ExtViewer] Identity-layer search: %s on %s", + search_params, + system_label, + extra={"trace_id": trace_id}, + ) + if is_epic: + p_res = await execute_with_retry( + connector, + FhirPatientReadInput(search_params=search_params), + trace_id, + steps[-1], + ) + else: + p_res = await execute_with_retry( + connector, + FhirCernerPatientReadInput(search_params=search_params), + trace_id, + steps[-1], + ) + patient_resource = p_res.resource or {} + patient_id = patient_resource.get("id") + + if not patient_id: + raise ValueError("Patient could not be resolved. No matching record found.") + + # Extract display name from FHIR resource + name_obj = patient_resource.get("name", [{}]) + if name_obj and isinstance(name_obj, list): + official = next((n for n in name_obj if n.get("use") == "official"), name_obj[0]) + else: + official = {} + given_parts = official.get("given", []) + family_part = official.get("family", "") + patient_display = f"{' '.join(given_parts)} {family_part}".strip() or ( + f"{payload.patient_given or ''} {payload.patient_family or ''}".strip() or patient_id + ) + patient_dob = patient_resource.get("birthDate", "Unknown") + patient_gender = patient_resource.get("gender", "Unknown") + + steps[-1].status = "success" + steps[-1].details = f"Resolved: {patient_display} (ID: {patient_id})" + steps[-1].display_name = f"Identity Resolved: {patient_display}" + steps[-1].data = { + "patient_id": patient_id, + "display_name": patient_display, + "dob": patient_dob, + "gender": patient_gender, + "source_system": system_label, + "raw": patient_resource, + } + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 1 — Patient Resolution failed") + + # ── STEP 2: Encounter History ──────────────────────────────────────────── + add_step("Encounter History", "pending", display_name="Retrieve Encounter History") + encounters: List[Any] = [] + try: + max_enc = max(1, min(20, payload.max_encounters)) + + if is_epic: + enc_res = await execute_with_retry( + connector, + FhirEncounterSearchInput( + search_params={"patient": patient_id, "_count": str(max_enc)} + ), + trace_id, + steps[-1], + ) + else: + enc_res = await execute_with_retry( + connector, + FhirCernerEncounterSearchInput( + search_params={"patient": patient_id, "_count": str(max_enc)} + ), + trace_id, + steps[-1], + ) + + encounters = enc_res.resources or [] + enc_count = len(encounters) + + most_recent_enc: dict = {} + if encounters: + most_recent_enc = encounters[0] + recent_enc_type = ( + (most_recent_enc.get("type") or [{}])[0].get("text", "Encounter") + if most_recent_enc + else "None" + ) + recent_enc_date = ( + most_recent_enc.get("period", {}).get("start", "Unknown date") + if most_recent_enc + else "N/A" + ) + + steps[-1].status = "success" + steps[-1].details = ( + f"Retrieved {enc_count} encounter(s). " + f"Most recent: {recent_enc_type} on {recent_enc_date}" + if enc_count + else "No encounters found for this patient." + ) + steps[-1].display_name = ( + f"Encounter History: {enc_count} record(s)" if enc_count else "No Encounters Found" + ) + steps[-1].data = { + "encounter_count": enc_count, + "encounters": encounters, + "raw": {"total": enc_count, "entries": encounters}, + } + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 2 — Encounter Retrieval failed") + + # ── STEP 3: Document Metadata ──────────────────────────────────────────── + add_step("Document Metadata", "pending", display_name="Retrieve Document Metadata") + documents: List[Any] = [] + doc_source = "fhir" + try: + max_docs = max(1, min(50, payload.max_documents)) + + if is_epic: + doc_res = await execute_with_retry( + connector, + FhirDocumentReferenceSearchInput( + search_params={"patient": patient_id, "_count": str(max_docs)} + ), + trace_id, + steps[-1], + ) + else: + doc_res = await execute_with_retry( + connector, + FhirCernerDocumentReferenceSearchInput( + search_params={"patient": patient_id, "_count": str(max_docs)} + ), + trace_id, + steps[-1], + ) + + documents = doc_res.resources or [] + + # Encounter-as-document fallback: when no DocumentReference exists, + # synthesise a lightweight document record from each encounter entry. + if not documents and encounters: + doc_source = "encounter_fallback" + for enc in encounters[:max_docs]: + enc_id = enc.get("id", "unknown") + enc_type_text = (enc.get("type") or [{}])[0].get("text", "Clinical Encounter") + enc_date = enc.get("period", {}).get("start", "Unknown") + enc_status = enc.get("status", "unknown") + documents.append( + { + "id": f"ENC-{enc_id}", + "resourceType": "EncounterFallback", + "status": enc_status, + "type": {"text": enc_type_text}, + "date": enc_date, + "description": "Encounter summary (no DocumentReference found)", + "subject": {"reference": f"Patient/{patient_id}"}, + "_synthetic": True, + } + ) + logger.info( + "[ExtViewer] No DocumentReferences found; using %d encounter fallback record(s)", + len(documents), + extra={"trace_id": trace_id}, + ) + + doc_count = len(documents) + fallback_note = " (encounter-fallback)" if doc_source == "encounter_fallback" else "" + + steps[-1].status = "success" + steps[-1].details = ( + f"Retrieved {doc_count} document(s){fallback_note}." + if doc_count + else "No documents or encounters available for this patient." + ) + steps[-1].display_name = ( + f"Documents: {doc_count} record(s){fallback_note}" + if doc_count + else "No Documents Found" + ) + steps[-1].data = { + "document_count": doc_count, + "source": doc_source, + "documents": documents, + "raw": {"total": doc_count, "source": doc_source, "entries": documents}, + } + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 3 — Document Retrieval failed") + + # ── STEP 4: Viewer Assembly ────────────────────────────────────────────── + add_step("Chart Assembly", "pending", display_name="Assemble External Chart View") + try: + enc_lines = [] + for i, enc in enumerate(encounters[:5]): + enc_type = (enc.get("type") or [{}])[0].get("text", "Encounter") + enc_date = enc.get("period", {}).get("start", "Unknown") + enc_status = enc.get("status", "?") + enc_lines.append(f" [{i + 1}] {enc_type} | {enc_date} | Status: {enc_status}") + + doc_lines = [] + for i, doc in enumerate(documents[:5]): + d_type = doc.get("type", {}).get("text") or doc.get("description", "Document") + d_date = doc.get("date") or doc.get("period", {}).get("start", "Unknown") + d_status = doc.get("status", "?") + d_synth = " [enc-fallback]" if doc.get("_synthetic") else "" + doc_lines.append(f" [{i + 1}] {d_type}{d_synth} | {d_date} | Status: {d_status}") + + content_lines = ( + [ + f"=== External Patient Chart ({system_label}) ===", + f"Patient : {patient_display}", + f"FHIR ID : {patient_id}", + f"DOB : {patient_dob}", + f"Gender : {patient_gender}", + "", + f"--- Encounter History ({len(encounters)} record(s)) ---", + ] + + (enc_lines if enc_lines else [" No encounters found."]) + + [ + "", + f"--- Documents ({len(documents)} record(s)" + + (" — encounter fallback" if doc_source == "encounter_fallback" else "") + + ") ---", + ] + + (doc_lines if doc_lines else [" No documents found."]) + + [ + "", + "[READ-ONLY] No data was written to the source system.", + ] + ) + + beautiful_data = { + "id": f"CHART-{patient_id}", + "type": "External Patient Chart", + "date": datetime.now(tz=timezone.utc).isoformat(), + "status": "READ-ONLY", + "patient_name": patient_display, + "author": system_label, + "category": "Clinical Chart View", + "description": ( + f"{len(encounters)} Encounter(s) · " + f"{len(documents)} Document(s)" + + (" [enc-fallback]" if doc_source == "encounter_fallback" else "") + ), + "content_text": "\n".join(content_lines), + } + + steps[-1].status = "success" + steps[-1].details = ( + f"Chart assembled. {len(encounters)} encounter(s), " + f"{len(documents)} document(s). Read-only — 0 writes." + ) + steps[-1].display_name = "Chart Ready (Read-Only)" + steps[-1].data = { + "patient_id": patient_id, + "encounter_count": len(encounters), + "document_count": len(documents), + "document_source": doc_source, + "read_only": True, + "raw": { + "patient_id": patient_id, + "source_system": system_label, + "encounters": len(encounters), + "documents": len(documents), + "document_source": doc_source, + }, + "beautiful_data": beautiful_data, + } + + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=patient_id, + human_summary=( + f"External chart loaded for {patient_display} from {system_label}. " + f"{len(encounters)} encounter(s) and {len(documents)} document(s) retrieved. " + "No data was written to the source system." + ), + trace_id=trace_id, + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 4 — Chart Assembly failed") diff --git a/playground/style.css b/playground/style.css index 94a703b..c8441e7 100644 --- a/playground/style.css +++ b/playground/style.css @@ -653,13 +653,14 @@ textarea:focus { .completion-icon { width: 2rem; height: 2rem; - background: rgba(16, 185, 129, 0.1); - color: var(--success); + background: var(--success); + color: white; border-radius: 50%; display: flex; align-items: center; justify-content: center; flex-shrink: 0; + box-shadow: 0 6px 14px rgba(16, 185, 129, 0.24); } .completion-icon svg { @@ -694,8 +695,9 @@ textarea:focus { } .completion-card.error-toast .completion-icon { - background: rgba(244, 63, 94, 0.1); - color: var(--error); + background: var(--error); + color: white; + box-shadow: 0 6px 14px rgba(244, 63, 94, 0.24); } /* Logs */ @@ -980,6 +982,109 @@ textarea:focus { border: 1px solid rgba(0,0,0,0.07); } +/* Slack: purple brand */ +.bg-slack { + background: #4A154B; +} + +/* Stripe: indigo-purple brand */ +.bg-stripe { + background: #635BFF; +} + +/* Salesforce: sky blue brand */ +.bg-salesforce { + background: #00A1E0; +} + +/* External Patient Viewer: teal read-only indicator */ +.bg-ext-viewer { + background: linear-gradient(135deg, #0d9488, #0891b2); +} + +/* Read-only badge used inside the chart viewer */ +.readonly-badge { + display: inline-flex; + align-items: center; + gap: 0.35rem; + background: rgba(13, 148, 136, 0.12); + color: #0d9488; + border: 1px solid rgba(13, 148, 136, 0.25); + padding: 0.2rem 0.6rem; + border-radius: 999px; + font-size: 0.7rem; + font-weight: 700; + text-transform: uppercase; + letter-spacing: 0.04em; +} + +/* Viewer scope controls (range slider row) */ +.viewer-scope-row { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 1.5rem; + margin-bottom: 1.5rem; +} + +.scope-field { + display: flex; + flex-direction: column; + gap: 0.4rem; +} + +.scope-field label { + font-size: 0.875rem; + font-weight: 600; + color: var(--text-main); + display: flex; + justify-content: space-between; + align-items: center; +} + +.scope-field label span.scope-val { + font-size: 0.8rem; + color: #0d9488; + font-weight: 700; +} + +input[type="range"] { + -webkit-appearance: none; + appearance: none; + width: 100%; + height: 6px; + background: linear-gradient(to right, #0d9488 0%, #0d9488 var(--pct, 25%), #e2e8f0 var(--pct, 25%), #e2e8f0 100%); + border-radius: 999px; + outline: none; + border: none; + padding: 0; + cursor: pointer; +} + +input[type="range"]::-webkit-slider-thumb { + -webkit-appearance: none; + width: 18px; + height: 18px; + border-radius: 50%; + background: #0d9488; + cursor: pointer; + box-shadow: 0 0 0 3px rgba(13,148,136,0.2); + transition: box-shadow 0.2s; +} + +input[type="range"]::-webkit-slider-thumb:hover { + box-shadow: 0 0 0 5px rgba(13,148,136,0.3); +} + +/* Viewer action button teal variant */ +.action-btn.btn-viewer { + background: linear-gradient(135deg, #0d9488, #0891b2); +} + +.action-btn.btn-viewer:hover { + background: #0f172a; + box-shadow: 0 15px 30px -10px rgba(13, 148, 136, 0.4); +} + .connector-details h3 { font-family: 'Outfit', sans-serif; font-size: 1.25rem; From d89f91398c6c6aa108b5000305b70b598fe6597c Mon Sep 17 00:00:00 2001 From: Rahul Ap Date: Fri, 15 May 2026 09:46:02 +0530 Subject: [PATCH 38/60] fix(factory): skip unregistered connectors filtered by NW_ALLOWED_CONNECTORS (#46) prevent MCP server startup crash when enabled connectors are filtered from registry add safety check in ConnectorFactory.load() log warning instead of raising RuntimeError allow Salesforce and Stripe MCP servers to continue startup gracefully improve resilience against connector config and registry mismatch --- src/bindings/factory.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/bindings/factory.py b/src/bindings/factory.py index bf0aff0..214425d 100644 --- a/src/bindings/factory.py +++ b/src/bindings/factory.py @@ -190,6 +190,16 @@ def load(self) -> None: ) continue + if connector_id not in _CONNECTOR_REGISTRY: + logger.warning( + "Connector enabled in configuration but not registered; skipping instantiation", + extra={ + "connector_id": connector_id, + "reason": "Filtered by NW_ALLOWED_CONNECTORS or not installed", + }, + ) + continue + instance = self._instantiate(connector_id) self._connectors[connector_id] = instance From 74177fe6d858ff22fbeac7f25069b1be82244720 Mon Sep 17 00:00:00 2001 From: kesav-aot Date: Wed, 20 May 2026 07:18:18 -0700 Subject: [PATCH 39/60] Consolidate setup into README and add compliance docs (#50) Remove the large standalone Setup.md and migrate essential setup, quick-start, and MCP image build/run instructions into README and existing docs. Add a new docs/code-quality-compliance.md describing Ruff, Mypy, Bandit, pip-audit, REUSE usage and local developer commands. Update docs: connectors (HTTP generic method normalization/log sanitization), installation (allowlist note, run modes, mypy guidance), mcp-servers (add Stripe/Salesforce images and thv examples, build script usage), and quality-security-gates (local check workflow, deterministic pytest env, SonarQube notes). These changes centralize onboarding, clarify MCP image/workflow, and surface compliance tooling for developers. --- README.md | 153 +++++---- Setup.md | 572 -------------------------------- docs/code-quality-compliance.md | 95 ++++++ docs/connectors.md | 2 +- docs/installation.md | 41 ++- docs/mcp-servers.md | 30 +- docs/quality-security-gates.md | 75 +++-- 7 files changed, 311 insertions(+), 657 deletions(-) delete mode 100644 Setup.md create mode 100644 docs/code-quality-compliance.md diff --git a/README.md b/README.md index 2e0bfae..1c1bc4d 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,18 @@ SPDX-License-Identifier: Apache-2.0 Node Wire is a three-layer Python platform that runs connector adapters (Google Drive, SMTP, Stripe, FHIR, etc.) and exposes them over REST, gRPC, or MCP. It provides a consistent execution contract with built-in validation, resilience, and telemetry. +## Prerequisites + +Before getting started, make sure you have: + +| Requirement | Version | Notes | +|---|---|---| +| Python | 3.11+ | Required to run the platform | +| `uv` or `pip` | Latest | `uv` is recommended for local development | +| Git | Any recent version | Required to clone the repository | +| Docker | Latest | Required for MCP server image builds and `docker-compose.mcp.yml` | +| Node.js | Any LTS | Only needed for MCP Inspector | + ## Quick Start ### 1. Install @@ -56,99 +68,124 @@ The platform includes an interactive web playground at [http://localhost:8000/pl --- -## Documentation - -For more detailed information, please refer to the following guides: - -- **[Architecture](docs/architecture.md)** — Layered design and data flow. -- **[Installation](docs/installation.md)** — Detailed setup and prerequisites. -- **[Configuration](docs/configuration.md)** — Environment variables and `connectors.yaml`. -- **[Connectors Guide](docs/connectors.md)** — How to use and build connectors. -- **[MCP Integration](docs/mcp.md)** — Using Node Wire with AI agents. -- **[Troubleshooting](docs/troubleshooting.md)** — Common errors and fixes. -- **[MCP Servers & Docker](docs/mcp-servers.md)** — Deploying individual connectors as MCP servers. -- **[Packaging & Publishing](docs/packaging.md)** — Wheel builds and CI flow. +## Build MCP Server Images -## Setup and development docs +Use this workflow when you want Docker images for the individual MCP servers such as Google Drive, SMTP, Stripe, Salesforce, or Slack. -- Platform setup (REST/gRPC/agents MCP): [Setup.md](Setup.md) -- Individual connector MCP servers (ToolHive): [docs/mcp-servers.md](docs/mcp-servers.md) -- Creating a new connector: [docs/connectors.md](docs/connectors.md) -- Quality/security gates (Bandit, SonarQube): [docs/quality-security-gates.md](docs/quality-security-gates.md) +### Build prerequisites ---- +Before building images, make sure: -## Code Quality (Linting & Formatting) +- Docker is installed and available on your shell path. +- You are running commands from the repository root. +- Local wheels have been built first. -This project uses **Ruff** for linting and formatting, and **Mypy** for static type checking. +See [docs/local-packages-to-images.md](docs/local-packages-to-images.md) for the full package -> image workflow and required wheel artifacts per image. -These checks are configured to run automatically in CI on Pull Requests against the `main` branch. +### Build all MCP server images -### Manual Usage for Developers -Make sure you have dev dependencies installed (`pip install -e ".[dev]"`). +All MCP server images are built from the repository root using the automation script: -- **Check formatting and linting errors:** `ruff check .` -- **Auto-fix and format code:** `ruff check --fix . && ruff format .` -- **Run static type validation:** `mypy` (paths default from `[tool.mypy]` `files` in `pyproject.toml`; avoid `mypy .`, which scans packaging `setup.py` scripts under `packages/`). To include tests: `mypy src tests`. +```bash +./scripts/build-mcp-images.sh +``` -### Pre-commit Hooks -You can attach `.pre-commit-config.yaml` so checks run on every commit: +To tag with a specific version (defaults to the version in `pyproject.toml`): ```bash -pre-commit install +./scripts/build-mcp-images.sh --version 0.1.0 ``` -To run all hooks across the repository: +This produces images tagged as both `latest` and the version string: + +| Image name | Tags | +|---|---| +| `nw-google-drive` | `nw-google-drive:latest`, `nw-google-drive:0.1.0` | +| `nw-smartonfhir-epic` | `nw-smartonfhir-epic:latest`, `nw-smartonfhir-epic:0.1.0` | +| `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner:latest`, `nw-smartonfhir-cerner:0.1.0` | +| `nw-smtp` | `nw-smtp:latest`, `nw-smtp:0.1.0` | +| `nw-stripe` | `nw-stripe:latest`, `nw-stripe:0.1.0` | +| `nw-salesforce` | `nw-salesforce:latest`, `nw-salesforce:0.1.0` | +| `nw-slack` | `nw-slack:latest`, `nw-slack:0.1.0` | + +### Build one image manually + +To build a single image manually from the repo root: ```bash -pre-commit run --all-files +# Google Drive only +docker build -f docker/google-drive/Dockerfile -t nw-google-drive:latest . + +# Epic FHIR only +docker build -f docker/fhir-epic/Dockerfile -t nw-smartonfhir-epic:latest . + +# Cerner FHIR only +docker build -f docker/fhir-cerner/Dockerfile -t nw-smartonfhir-cerner:latest . + +# SMTP only +docker build -f docker/smtp/Dockerfile -t nw-smtp:latest . + +# Stripe only +docker build -f docker/stripe/Dockerfile -t nw-stripe:latest . + +# Salesforce only +docker build -f docker/salesforce/Dockerfile -t nw-salesforce:latest . + +# Slack only +docker build -f docker/slack/Dockerfile -t nw-slack:latest . ``` +> **Note:** The build context must be the repository root (`.`) so the `COPY src/` and `COPY config/` instructions resolve correctly. + --- -## Copyright Headers & Compliance +## Run MCP Servers with Docker Compose + +### Compose prerequisites -This repository enforces open-source licensing compliance using [REUSE](https://reuse.software/). All first-party source files must contain the appropriate SPDX copyright and license headers. +Before starting the MCP containers, make sure: -### Testing Compliance +- The MCP server images have already been built locally. +- Your `.env` file is populated with the credentials needed by the connectors you want to run. -To verify that all files have the correct headers, run the `reuse` lint tool: +`docker-compose.mcp.yml` starts all MCP servers as stdio containers in one command. This is useful for local validation before configuring ToolHive. ```bash -uv pip install reuse -uv run reuse lint +# Ensure your .env is populated, then: +docker compose -f docker-compose.mcp.yml up ``` -### Adding Missing Headers - -If `reuse lint` reports missing headers on new files, you can automatically add them by running: +To start only a specific server: ```bash -bash scripts/add-license-headers.sh +docker compose -f docker-compose.mcp.yml up nw-smartonfhir-epic ``` -## Dependency Inventory & Compliance +--- + +## Documentation -To maintain an open-source compliant ecosystem, we track all third-party dependencies and their licenses. +For more detailed information, please refer to the following guides: + +- **[Architecture](docs/architecture.md)** — Layered design and data flow. +- **[Installation](docs/installation.md)** — Detailed setup and prerequisites. +- **[Configuration](docs/configuration.md)** — Environment variables and `connectors.yaml`. +- **[Connectors Guide](docs/connectors.md)** — How to use and build connectors. +- **[MCP Integration](docs/mcp.md)** — Using Node Wire with AI agents. +- **[Troubleshooting](docs/troubleshooting.md)** — Common errors and fixes. +- **[MCP Servers & Docker](docs/mcp-servers.md)** — Deploying individual connectors as MCP servers. +- **[Packaging & Publishing](docs/packaging.md)** — Wheel builds and CI flow. +- **[Code Quality & Compliance](docs/code-quality-compliance.md)** — Ruff, Mypy, pre-commit, REUSE, and dependency compliance. -### License Classification Criteria -Dependencies are strictly evaluated against the following criteria: -* **✅ Safe (Permissive):** MIT, Apache-2.0, BSD, PSF. Universally safe for our Apache 2.0 release. -* **⚠️ Needs Review:** Custom or obscure licenses require manual review to ensure no conflicting obligations. -* **⛔ Risky (Copyleft):** GPLv2, GPLv3, AGPL. Strictly prohibited in the runtime application. Permitted *only* as isolated, non-distributed Development/Linting tools. +## Developer docs -### Updating the Dependency Inventory & Security Checks -When a new package is added to the project, or before creating a release, you must run the unified compliance script. +- Individual connector MCP servers (ToolHive): [docs/mcp-servers.md](docs/mcp-servers.md) +- Creating a new connector: [docs/connectors.md](docs/connectors.md) +- Code quality/compliance (Ruff, Mypy, REUSE, pip-audit): [docs/code-quality-compliance.md](docs/code-quality-compliance.md) +- Quality/security gates (Bandit, SonarQube): [docs/quality-security-gates.md](docs/quality-security-gates.md) -This script will: -1. Generate the `DEPENDENCIES.md` inventory. -2. Run **Bandit** for Static Application Security Testing (SAST). -3. Run **pip-audit** for vulnerability scanning across all dependencies. +--- -To automatically run these checks, execute: -```bash -bash scripts/run-compliance-checks.sh -``` ## License diff --git a/Setup.md b/Setup.md deleted file mode 100644 index a2fddac..0000000 --- a/Setup.md +++ /dev/null @@ -1,572 +0,0 @@ -# Node Wire — Setup Guide - -Node Wire is a Python framework that runs connector adapters (Google Drive, SMTP, FHIR, Stripe, and more) and exposes them over REST, gRPC, or MCP. It includes a built-in AI agent layer so LLMs can discover and orchestrate these connectors automatically. - ---- - -## Table of Contents - -- [Prerequisites](#prerequisites) -- [Installation](#installation) -- [Configuration](#configuration) -- [Running the Platform](#running-the-platform) -- [Connectors Overview](#connectors-overview) -- [Connector Setup](#connector-setup) -- [MCP Server & ToolHive](#mcp-server--toolhive) -- [Running Tests](#running-tests) -- [Playground UI](#playground-ui) - ---- - -## Prerequisites - - -| Requirement | Version | Notes | -| ----------- | ------- | --------------------------------------- | -| Python | 3.11+ | `python --version` to check | -| pip or uv | Latest | `pip install --upgrade pip` | -| Git | Any | To clone the repo | -| Docker | Latest | Only needed for ToolHive MCP deployment | -| Node.js | Any LTS | Only needed for `npx @modelcontextprotocol/inspector` | - - ---- - -## Installation - -```bash -# 1. Clone the repository -git clone -cd # the folder git creates (rename if you like) - -# 2. Install dependencies (recommended: uv) -uv sync --extra agents - -# 3. Verify the install -python -m uv run node-wire --help -``` - -> **Install uv:** See the official installer docs at `https://docs.astral.sh/uv/`. -> -> **REST/gRPC only** (no AI agent features): `uv sync` without the extra is sufficient. -> -> **Alternative (pip):** If you’re not using `uv`, install editable deps with pip: -> -> - `pip install -e ".[agents]"` (includes MCP/LLM agent dependencies) -> - `pip install -e .` (REST/gRPC only, no agent dependencies) - -> **Installing from PyPI wheels instead of source?** See [docs/packaging.md](docs/packaging.md) for the wheel build lifecycle, client install model, and pre-publish validation checklist. - ---- - -## Configuration - -All secrets and settings are loaded from environment variables. A template is provided at `sample.env`. - -```bash -# Copy the template -cp sample.env .env - -# Open and fill in the values you need -``` - -You only need to fill in the sections for the connectors you plan to use. The platform starts successfully even if some credentials are missing — those connectors will simply return an error when called. - -> [!IMPORTANT] -> **Connector Allowlist:** For security, Node Wire uses a fail-closed allowlist for connector entry points. You **must** set `NW_ALLOWED_CONNECTORS` in your `.env` file to a comma-separated list of the connectors you want to enable (e.g., `fhir_epic,http_generic`). If this variable is missing or empty, no connectors will be loaded even if they are enabled in `connectors.yaml`. - - -> **Doc convention:** Environment variable names in the docs follow `sample.env`. Some legacy keys (like `stripe_api_key`) are intentionally lower-case because that is what the connector reads. - -### Environment Variable Sections - - -| Section | Key Variables | When Needed | -| ---------------- | ------------------------------------------------------------------------------------------------------------------- | ---------------------- | -| **FHIR Epic** | `EPIC_FHIR_BASE_URL`, `EPIC_TOKEN_URL`, `EPIC_CLIENT_ID`, `EPIC_KID`, `EPIC_PRIVATE_KEY` | Epic EHR integration | -| **FHIR Cerner** | `CERNER_FHIR_BASE_URL`, `CERNER_TOKEN_URL`, `CERNER_CLIENT_ID`, `CERNER_KID`, `CERNER_PRIVATE_KEY`, `CERNER_SCOPES` | Cerner EHR integration | -| **Google Drive** | `GOOGLE_DRIVE_SA_JSON`, `GOOGLE_DRIVE_FOLDER_ID` | Google Drive connector | -| **SMTP** | `SMTP_HOST`, `SMTP_PORT`, `SMTP_USERNAME`, `SMTP_PASSWORD` | Sending emails | -| **Slack** | `SLACK_BOT_TOKEN` | Sending Slack messages | -| **LLM / Agent** | `LLM_PROVIDER`, `GROQ_API_KEY` (or other provider key) | AI agent / ToolHive | -| **ToolHive / MCP**| `TOOLHIVE_MCP_URLS` (multi-server), `NW_MCP_TRANSPORT`, `NW_MCP_PORT`, `NW_STREAM_BUFFER_MS` | AI agent / ToolHive | -| **ToolHive** | `TOOLHIVE_MCP_URL` (single) or `TOOLHIVE_MCP_URLS` (comma-separated, multi-server) | ToolHive MCP proxy | -| **Plugin Allowlist** | `NW_ALLOWED_CONNECTORS` (comma-separated list of allowed connector names) | **Required** for any connector to load | - - - -See `sample.env` for the full list with example values. - ---- - -## Running the Platform - -The platform supports three modes. Set the `MODE` environment variable to switch between them. - - -| Mode | Command | Default Port | Use Case | -| ---------------------- | --------------------------------- | ------------ | ----------------------------------- | -| **REST API** (default) | `python -m uv run node-wire` | `8000` | HTTP clients, Swagger UI, curl | -| **gRPC** | `MODE=GRPC python -m uv run node-wire` | `50051` | gRPC clients | -| **MCP** | `python -m agents.mcp_entrypoint` | stdio / 8080 | AI agents, ToolHive, Claude Desktop | - -> **Important:** `MODE=MCP` for `node-wire` / `python -m bindings_entrypoint` starts a minimal MCP-style placeholder server, not the full stdio MCP server used with ToolHive and the agent layer. For ToolHive/Inspector/agents, use `python -m agents.mcp_entrypoint` (or the per-connector MCP servers in `docs/mcp-servers.md`). - -### Configuration file (`config/connectors.yaml`) - -Connectors are loaded from `config/connectors.yaml`. Each connector has: - -- `enabled`: whether the connector is instantiated at startup -- `exposed_via`: which protocols can access it (`rest`, `grpc`, `mcp`) - -If a connector is disabled (or not exposed for a protocol), requests to it will fail with “not configured / not available” even if your `.env` is correct. - -For details on adding a new connector to the runtime, see [docs/connectors.md](docs/connectors.md). - - -### REST API Quick Start - -```bash -# Local development: disable REST auth (do not use in production) -export NW_REST_AUTH_DISABLED=true - -# Default port 8000 -python -m uv run node-wire - -# If port 8000 is in use, override with PORT -PORT=8001 python -m uv run node-wire -``` - -**Production / secured REST:** set `NW_REST_API_KEY` and send `Authorization: Bearer ` or `X-API-Key: ` on every route except `GET /health`. Set `NW_REST_LOAD_DOTENV=false` so secrets are not loaded from a `.env` file. See [docs/connectors.md](docs/connectors.md) (Security section). - -Once running: - -- **Health check (no auth):** `GET http://localhost:8000/health` -- **Interactive docs (Swagger UI):** `http://localhost:8000/docs` (requires API key when auth is enabled) -- **Call a connector:** `POST http://localhost:8000/connectors/{connector_id}/{action}` - -Example — send an HTTP request via the generic connector (with auth enabled): - -```bash -curl -X POST http://localhost:8000/connectors/http_generic/request \ - -H "Authorization: Bearer $NW_REST_API_KEY" \ - -H "Content-Type: application/json" \ - -d '{"url": "https://httpbin.org/get", "method": "GET"}' -``` - -All responses use the same standard shape: - -```json -{ - "success": true, - "data": { "raw": { ... }, "description": "..." }, - "error_code": null, - "error_category": null, - "message": null, - "trace_id": "..." -} -``` - ---- - -## Connectors Overview - -**Developer guide (`BaseConnector`, config, factory):** [docs/connectors.md](docs/connectors.md). - - - -| Connector | What It Does | Credentials Needed | Setup Guide | -| ---------------- | ------------------------------------------ | -------------------------------------- | --------------------------------------------------------------------------------------------- | -| **http_generic** | Make HTTP requests to any URL | None | No setup needed | -| **smtp** | Send emails via SMTP | SMTP host/port/username/password | [SMTP Setup](#smtp) | -| **stripe** | Process Stripe payments | Stripe API key | [Stripe Setup](#stripe) | -| **google_drive** | List, upload, download, manage Drive files | GCP service account JSON | [Google Drive setup & API](docs/google_drive_connector.md#google-drive-service-account-setup) | -| **fhir_epic** | Read/write patient data from Epic EHR | Epic SMART credentials + private key | [FHIR Epic Setup](#fhir-epic) | -| **fhir_cerner** | Read/write patient data from Cerner EHR | Cerner SMART credentials + private key | [FHIR Cerner Setup](#fhir-cerner) | -| **slack** | Send messages and files to Slack channels | Slack Bot OAuth Token | [Slack Setup](#slack) | - - ---- - -## Connector Setup - -### HTTP Generic - -No credentials required. Works out of the box. - -Security defaults: -- Allowed methods: `GET`, `POST`, `PUT`, `PATCH`, `DELETE` (input is normalized to uppercase). -- Internal targets are blocked: `localhost`, loopback, private/link-local IPs, and metadata endpoints. -- Connector logs omit URL query strings and fragments (scheme/host/path only). - -```bash -curl -X POST http://localhost:8000/connectors/http_generic/request \ - -H "Content-Type: application/json" \ - -d '{ - "url": "https://api.example.com/data", - "method": "POST", - "headers": {"Authorization": "Bearer your-token"}, - "body": {"key": "value"} - }' -``` - ---- - -### SMTP - -Add these to your `.env`: - -```env -SMTP_HOST=smtp.gmail.com -SMTP_PORT=587 -SMTP_USERNAME=you@gmail.com -SMTP_PASSWORD=your-app-password -``` - -> **Gmail users:** You must use an [App Password](https://support.google.com/accounts/answer/185833), not your regular Gmail password. Enable 2-Factor Authentication on your Google account first, then generate an App Password under Security settings. - -Supported configurations: - -- Port `587` with STARTTLS (recommended for Gmail, most SMTP providers) -- Port `465` with implicit TLS - ---- - -### Stripe - -Add to your `.env`: - -```env -STRIPE_API_KEY=sk_test_your_key_here -``` - -Use a **test key** (`sk_test_...`) during development. Switch to a live key (`sk_live_...`) for production. - ---- - -### Google Drive - -The Google Drive connector uses a **service account** — a non-human Google account your application uses to authenticate with Google Drive APIs. - -**Full documentation:** [docs/google_drive_connector.md](docs/google_drive_connector.md) — service account setup, verification, and REST `execute` API (all seven operations). - -Quick summary of what you'll need: - -1. A Google Cloud project with the Drive API enabled -2. A service account with a downloaded JSON key file -3. A shared Drive folder (share it with the service account's email) - -Add to your `.env`: - -```env -GOOGLE_DRIVE_SA_JSON=/absolute/path/to/service-account.json -GOOGLE_DRIVE_FOLDER_ID=your-folder-id-from-drive-url -``` - ---- - -### FHIR Epic - -Epic EHR integration uses the SMART Backend Services OAuth2 flow with RS384 JWT authentication. - -Add to your `.env`: - -```env -EPIC_FHIR_BASE_URL=https://fhir.epic.com/interconnect-fhir-oauth/api/FHIR/R4 -EPIC_TOKEN_URL=https://fhir.epic.com/interconnect-fhir-oauth/oauth2/token -EPIC_CLIENT_ID=your-epic-client-id -EPIC_KID=your-key-id -EPIC_PRIVATE_KEY="-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----" -``` - -You obtain these credentials by registering a backend application in the [Epic App Orchard](https://appmarket.epic.com/) (or your organization's Epic sandbox). - -**Available actions:** `read_patient`, `search_patients`, `search_encounter`, `create_document_reference`, `search_document_reference` - ---- - -### FHIR Cerner - -Cerner EHR integration also uses SMART Backend Services with `private_key_jwt` client authentication. - -Add to your `.env`: - -```env -CERNER_FHIR_BASE_URL=https://fhir-ehr-code.cerner.com/r4/your-tenant-id -CERNER_TOKEN_URL=https://authorization.cerner.com/tenants/your-tenant-id/protocols/oauth2/profiles/smart-v1/token -CERNER_CLIENT_ID=your-cerner-client-id -CERNER_KID=your-key-id -CERNER_PRIVATE_KEY="-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----" -CERNER_SCOPES="system/Patient.read system/Encounter.read system/DocumentReference.read system/DocumentReference.write" -``` - -Register your application in the [Cerner Developer Portal](https://code.cerner.com/) to obtain these credentials. - -**Available actions:** `read_patient`, `search_patients`, `search_encounter`, `create_document_reference`, `search_document_reference` - ---- - -## MCP Transport Modes - -Node Wire supports two transport modes for AI agents. Switch between them using the `NW_MCP_TRANSPORT` environment variable: - -- **`stdio`** (Default): Communicates via standard I/O. Best for local development, subprocess-based clients, and ToolHive-style wrapping. The playground uses buffered agent responses in this mode. -- **`streamable-http`**: Native HTTP MCP server. Exposes a direct endpoint on `NW_MCP_HOST`, `NW_MCP_PORT`, and `NW_MCP_PATH`. The playground streams tool progress and final answer chunks in this mode. - -### Streaming Features -- **Configurable Buffering (`NW_STREAM_BUFFER_MS`)**: When streaming, output can be buffered to reduce event spam. Set to the duration in milliseconds (e.g. `2000` for a 2-second batching window). Default is `0` (no buffering). -- **Completion Signals**: The core runtime emits structured "done" signals (`stream_completion_log`) via Python logging when streaming ends, allowing package consumers to easily detect when a stream finishes. - -**Example: stdio mode** - -```powershell -$env:NW_MCP_TRANSPORT="stdio" -python -m agents.mcp_entrypoint -``` - -**Example: Shift to HTTP mode on Port 8081** -```powershell -# Windows -$env:NW_MCP_TRANSPORT="streamable-http" -$env:NW_MCP_HOST="127.0.0.1" -$env:NW_MCP_PORT="8081" -$env:NW_MCP_PATH="/mcp" -python -m agents.mcp_entrypoint -``` - -The HTTP MCP endpoint is then: - -```text -http://127.0.0.1:8081/mcp -``` - -When the REST playground is running, the Agentic Workflow panel displays the active transport by reading `/scenarios/agent-transport`: - -- `Transport: stdio`: the UI waits for the complete backend agent result. -- `Transport: Streamable HTTP`: tool cards appear as tools finish, and the final answer renders progressively as chunks arrive. - -### Testing with MCP Inspector - -MCP Inspector can be launched with `npx`: - -```powershell -npx @modelcontextprotocol/inspector -``` - -For stdio testing, let Inspector launch the server: - -```powershell -$env:NW_MCP_TRANSPORT="stdio" -npx @modelcontextprotocol/inspector python -m agents.mcp_entrypoint -``` - -For streamable HTTP testing, start the server first: - -```powershell -$env:NW_MCP_TRANSPORT="streamable-http" -$env:NW_MCP_HOST="127.0.0.1" -$env:NW_MCP_PORT="8081" -$env:NW_MCP_PATH="/mcp" -python -m agents.mcp_entrypoint -``` - -Then run Inspector: - -```powershell -npx @modelcontextprotocol/inspector -``` - -In Inspector, choose `Streamable HTTP`, enter `http://127.0.0.1:8081/mcp`, connect, then use `Tools -> List Tools` and run a safe tool call with valid JSON arguments. - ---- -### Slack - -Add the bot token to your `.env`: - -```env -SLACK_BOT_TOKEN=xoxb-your-slack-bot-token -``` - -1. Create a Slack App at [api.slack.com/apps](https://api.slack.com/apps). -2. Go to **OAuth & Permissions** and add **Bot Token Scopes**: - - `chat:write` (to post messages) - - `files:write` (to upload files) - - `im:write` (to send direct messages) - - `conversations:open` (to resolve User IDs) -3. Install the app to your workspace. -4. Copy the **Bot User OAuth Token** (`xoxb-...`). -5. Invite the bot to any private channels you want it to access. - ---- - -## MCP Server & ToolHive - -The platform exposes connector tools for AI agents via the MCP (Model Context Protocol). There are two deployment modes: - -### Individual MCP servers (recommended) - -Each connector runs as its own independent MCP server. This is the preferred approach for modular, scalable deployments. - - -| Image | MCP tools (manifest) | Docker image | -| ----------------------- | -------------------- | -------------------------------- | -| `nw-google-drive` | All `google_drive.` (e.g. `google_drive.files.upload`) | `docker/google-drive/Dockerfile` | -| `nw-smartonfhir-epic` | All `fhir_epic.` (e.g. `fhir_epic.read_patient`) | `docker/fhir-epic/Dockerfile` | -| `nw-smartonfhir-cerner` | All `fhir_cerner.` (e.g. `fhir_cerner.read_patient`) | `docker/fhir-cerner/Dockerfile` | -| `nw-smtp` | `smtp.send_email` | `docker/smtp/Dockerfile` | -| `nw-slack` | All `slack.` (e.g. `slack.post_message`) | `docker/slack/Dockerfile` | - - -**Full guide (build, env config, ToolHive registration, multi-server agent usage):** [docs/mcp-servers.md](docs/mcp-servers.md) - -**FHIR tool arguments (Cerner / Epic)** — tool names are `fhir_cerner.` and `fhir_epic.`. Use field names from `tools/list` / the connector manifest. Typical payloads: - -| Action | When to use | Example arguments | -| ------ | ----------- | ------------------- | -| `read_patient` | You have a Patient id | `{"resource_id": "12724066"}` (Epic ids often start with `e`) | -| `search_patients` | No id, or name-based search | `{"resource_ids": ["id1"]}` or `{"given_name": "...", "family_name": "..."}` or `{"search_params": {"identifier": "...", "family": "..."}}` (FHIR search param names) | - -The MCP server normalizes common LLM/legacy aliases (`patientId` / `patient_id` → `resource_id`; `patientId` inside `search_params` → `identifier`) before validation. Prefer canonical fields above when authoring prompts or clients. - -Quick start: - -```bash -# Build all four images -./scripts/build-mcp-images.sh - -# Start all four locally -docker compose -f docker-compose.mcp.yml up -``` - -### Combined MCP server (all connectors in one) - -For simpler setups all connectors can be exposed from a single MCP server: - -```bash -python -m agents.mcp_entrypoint -``` - -**ToolHive** runs the MCP server inside a secure Docker container, manages secrets injection, and provides an HTTP proxy that any MCP-compatible client (Claude Desktop, Cursor, custom agents) can connect to. - -**See the full ToolHive workflow guide:** [docs/toolhive_agent_scenario.md](docs/toolhive_agent_scenario.md) - -### Quick Local Test (No ToolHive) - -```bash -# Inspect any individual server with MCP Inspector -npx @modelcontextprotocol/inspector python -m agents.fhir_epic_mcp -npx @modelcontextprotocol/inspector python -m agents.google_drive_mcp - -# Or test the combined server -npx @modelcontextprotocol/inspector python -m agents.mcp_entrypoint -``` - -For native streamable HTTP testing, use the `streamable-http` setup above and connect Inspector to `http://127.0.0.1:8081/mcp`. - -### Troubleshooting quick hits - -- **Port 8000 in use**: set `PORT=8001` (or any free port) when starting the REST API. -- **Connector “not configured”**: confirm it is `enabled: true` (and exposed for your protocol) in `config/connectors.yaml`. -- **ToolHive + Google Drive auth failure**: inside ToolHive, `GOOGLE_DRIVE_SA_JSON` must be the JSON **contents** (not a file path). Locally, it can be an absolute file path (see `docs/mcp-servers.md`). - ---- - -## Running Tests - -```bash -# Install dev dependencies (if not already installed) -pip install -e ".[dev,agents]" - -# Run all tests -pytest tests/ -v - -# Run a specific connector's tests -pytest tests/test_google_drive.py -v -pytest tests/test_fhir_epic.py -v -pytest tests/test_toolhive_agent.py -v -``` - -Most tests are unit tests that run without real credentials. Integration tests that call live APIs are skipped unless the relevant environment variables are set. - -For deterministic pytest runs (especially when a repo-root `.env` exists locally), `tests/conftest.py` sets **`NW_REST_LOAD_DOTENV=false`** so REST startup does not merge `.env` over test variables, **`NW_CONFIG_PATH`** to [`tests/fixtures/connectors_for_tests.yaml`](tests/fixtures/connectors_for_tests.yaml) so optional connectors not on the test allowlist stay **`enabled: false`** (e.g. slack, salesforce), and a fixed **`NW_ALLOWED_CONNECTORS`** list. Do not rely on `.env` values during collection. - ---- - -## Code quality and security gates - -Node Wire enforces security and coverage-backed analysis in CI for pull requests and pushes to `main`/`master`: - -- Bandit: JSON report + log summary (artifact), then `bandit -c pyproject.toml -r src --severity-level high` for the failing gate. The JSON step uses `--exit-zero` because Bandit otherwise exits 1 on *any* finding while the gate only blocks **high** severity. -- SonarQube Community Edition scan with `sonar.qualitygate.wait=true` so PRs fail when the quality gate fails. - -### Run checks locally - -```bash -# Install dev tools -pip install -e ".[dev,agents]" - -# Security gate (matches CI failure threshold) -bandit -c pyproject.toml -r src --severity-level high - -# Optional: JSON report + same summary as CI logs -bandit -c pyproject.toml -r src -f json -o bandit-report.json --exit-zero -python scripts/bandit_report_summary.py bandit-report.json - -# Tests + coverage.xml (required by SonarQube) -pytest tests/ -v -``` - -### Pre-commit - -```bash -pre-commit install -pre-commit run --all-files -``` - -### Run SonarQube scan locally (Docker) - -```bash -# from repository root, after coverage.xml is generated -docker run --rm \ - -e SONAR_TOKEN=YOUR_TOKEN \ - -v "G:\SPACE\node-wire:/usr/src" \ - -w /usr/src \ - sonarsource/sonar-scanner-cli \ - -Dsonar.host.url=http://host.docker.internal:9000 \ - -Dsonar.token=YOUR_TOKEN -``` - -### SonarQube configuration - -The repository includes [`sonar-project.properties`](sonar-project.properties) and CI expects these GitHub secrets: - -- `SONAR_HOST_URL` (example: `https://sonarqube.company.internal`) -- `SONAR_TOKEN` (project analysis token) - -For server setup and quality gate policy details, see [docs/quality-security-gates.md](docs/quality-security-gates.md). - ---- - -## Playground UI - -The repository includes an interactive web playground that showcases 5 orchestration scenarios: - -> **Note:** The UI is served under the `/playground/` path (not at the server root). - -```bash -# Start the REST API (if not already running) -python -m uv run node-wire - -# Open in your browser -open http://localhost:8000/playground/ -``` - -Scenarios include: - -1. Epic FHIR patient lookup and clinical note upload -2. IT Ops automation via HTTP Generic -3. Cerner FHIR orchestration -4. Google Drive document archival -5. AI agent orchestration via MCP - -See `playground/README.md` for details on each scenario and how to configure them. diff --git a/docs/code-quality-compliance.md b/docs/code-quality-compliance.md new file mode 100644 index 0000000..b9e5c4f --- /dev/null +++ b/docs/code-quality-compliance.md @@ -0,0 +1,95 @@ + + +# Code Quality and Compliance + +This project uses **Ruff** for linting and formatting, **Mypy** for static type checking, **Bandit** for SAST, **pip-audit** for dependency vulnerability checks, and **REUSE** for open-source licensing compliance. + +Linting and type checks run automatically in CI on pull requests against the `main` branch via `.github/workflows/lint.yml`. Security and package compliance checks are additionally enforced through `.github/workflows/quality-gates.yml` and `.github/workflows/security-pr.yml`. + +## Manual usage for developers + +Install development dependencies first: + +```bash +pip install -e ".[dev]" +``` + +Then run the local quality checks: + +- **Check formatting and linting errors:** `ruff check .` +- **Auto-fix and format code:** `ruff check --fix . && ruff format .` +- **Run static type validation:** `mypy` + +`mypy` uses the default `files` target from `[tool.mypy]` in `pyproject.toml`, which is currently `src`. Avoid `mypy .`, because it can pull in packaging `setup.py` scripts under `packages/` and produce duplicate-module noise. To include tests explicitly, run: + +```bash +mypy src tests +``` + +## Pre-commit hooks + +You can attach `.pre-commit-config.yaml` so checks run before each commit: + +```bash +pre-commit install +``` + +To run all configured hooks across the repository: + +```bash +pre-commit run --all-files +``` + +The current pre-commit setup includes Ruff, Ruff formatting, Mypy, and Bandit. + +## Copyright headers and REUSE compliance + +This repository enforces open-source licensing compliance using [REUSE](https://reuse.software/). First-party files should contain the appropriate SPDX copyright and license headers. + +### Verify compliance + +```bash +uv pip install reuse +uv run reuse lint +``` + +### Add missing headers + +If `reuse lint` reports missing headers, you can apply the repository header template with: + +```bash +bash scripts/add-license-headers.sh +``` + +## Dependency inventory and compliance + +To maintain an open-source compliant dependency set, the repository tracks third-party packages and their licenses in `DEPENDENCIES.md`. + +### License classification criteria + +- **Safe (permissive):** MIT, Apache-2.0, BSD, PSF. Safe for the Apache-2.0 release. +- **Needs review:** Custom or uncommon licenses that require manual review. +- **Risky (copyleft):** GPLv2, GPLv3, AGPL. Not allowed in the runtime application. They may be acceptable only for isolated, non-distributed development tooling. + +### Update the inventory and run compliance checks + +When adding dependencies or preparing a release, run the unified compliance script: + +```bash +bash scripts/run-compliance-checks.sh +``` + +That script: + +1. Regenerates `DEPENDENCIES.md`. +2. Runs **Bandit** for static application security testing. +3. Runs **pip-audit** for dependency vulnerability scanning. + +## Related docs + +- [Quality and security gates](quality-security-gates.md) +- [Installation guide](installation.md) diff --git a/docs/connectors.md b/docs/connectors.md index 1fdf318..9259a14 100644 --- a/docs/connectors.md +++ b/docs/connectors.md @@ -532,7 +532,7 @@ connectors: **REST API (`bindings.rest_api`)** — `GET /health` is unauthenticated. All other routes (`/connectors/...`, `/playground/...`, `/scenarios/...`, OpenAPI) require **`NW_REST_API_KEY`** via `Authorization: Bearer ` or `X-API-Key: `, optional **`NW_REST_JWT_SECRET`** for HS256 JWTs. API key scopes use **`NW_REST_API_KEY_SCOPES`** (same format as MCP). Set **`NW_REST_AUTH_DISABLED=true`** only for local development. Production: set **`NW_REST_LOAD_DOTENV=false`** so secrets are not read from a `.env` file on disk. -**HTTP Generic outbound policy** — `http_generic.request` allows only `GET`, `POST`, `PUT`, `PATCH`, `DELETE`. URLs targeting internal destinations are rejected (`localhost`, loopback, private/link-local IP ranges, metadata endpoints). Connector logs sanitize URL fields by dropping query strings and fragments. +**HTTP Generic outbound policy** — `http_generic.request` allows only `GET`, `POST`, `PUT`, `PATCH`, `DELETE`, and input methods are normalized to uppercase before validation. URLs targeting internal destinations are rejected (`localhost`, loopback, private/link-local IP ranges, metadata endpoints). Connector logs sanitize URL fields by dropping query strings and fragments so only scheme/host/path are retained. **Connector entry points** — Any installed distribution may register `node_wire.connectors`. For production, set **`NW_ALLOWED_CONNECTORS`** to a comma-separated list of entry point names (e.g. `fhir_epic,http_generic`). **`NW_CONNECTOR_MODULE_PREFIX`** defaults to `node_wire_`; modules not under that prefix are skipped. diff --git a/docs/installation.md b/docs/installation.md index 327d9fe..f2b8bff 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -31,6 +31,8 @@ copy sample.env .env ``` *(Edit `.env` and set `NW_ALLOWED_CONNECTORS=http_generic` or others)* +Node Wire uses a fail-closed connector allowlist. If `NW_ALLOWED_CONNECTORS` is missing or empty, no connectors are loaded even when they are enabled in `config/connectors.yaml`. + ### 3. Install dependencies **Using `uv` (recommended):** @@ -50,6 +52,41 @@ uv run node-wire --help --- +## Running the Platform + +Node Wire supports REST, gRPC, and MCP entry modes: + +| Mode | Command | Default port / transport | Use case | +|------|---------|--------------------------|----------| +| REST API | `uv run node-wire` | `8000` | HTTP clients, Swagger UI, playground | +| gRPC | `MODE=GRPC uv run node-wire` | `50051` | gRPC clients | +| MCP | `python -m agents.mcp_entrypoint` | `stdio` or HTTP | AI agents, ToolHive, Inspector | + +### REST quick start + +```bash +# Local development only +export NW_REST_AUTH_DISABLED=true + +# Start the API +uv run node-wire +``` + +Once it is running: + +- Health check: `GET http://localhost:8000/health` +- Swagger UI: `http://localhost:8000/docs` +- Playground: `http://localhost:8000/playground/` + +### MCP notes + +For MCP transport modes, Inspector usage, and multi-server deployment: + +- See [mcp.md](mcp.md) for transport setup and local MCP usage. +- See [mcp-servers.md](mcp-servers.md) for per-connector images, ToolHive, and Docker-based MCP deployment. + +--- + ## Development Setup ### Code Quality (Linting & Formatting) @@ -57,7 +94,9 @@ We use **Ruff** for linting/formatting and **Mypy** for type checking. - **Check:** `ruff check .` - **Fix:** `ruff check --fix . && ruff format .` -- **Types:** `mypy .` +- **Types:** `mypy` + +`mypy` defaults to the `[tool.mypy].files` targets from `pyproject.toml`. To include tests explicitly, run `mypy src tests`. ### Pre-commit Hooks ```bash diff --git a/docs/mcp-servers.md b/docs/mcp-servers.md index 26e8665..8aabc7d 100644 --- a/docs/mcp-servers.md +++ b/docs/mcp-servers.md @@ -38,12 +38,16 @@ flowchart TD Epic[nw-smartonfhir-epic] Cerner[nw-smartonfhir-cerner] SMTP[nw-smtp] + Stripe[nw-stripe] + Salesforce[nw-salesforce] Slack[nw-slack] end Agent -->|"TOOLHIVE_MCP_URLS"| GDrive Agent -->|"TOOLHIVE_MCP_URLS"| Epic Agent -->|"TOOLHIVE_MCP_URLS"| Cerner Agent -->|"TOOLHIVE_MCP_URLS"| SMTP + Agent -->|"TOOLHIVE_MCP_URLS"| Stripe + Agent -->|"TOOLHIVE_MCP_URLS"| Salesforce Agent -->|"TOOLHIVE_MCP_URLS"| Slack ``` @@ -413,7 +417,7 @@ GROQ_API_KEY=your-groq-api-key Before building images, build local wheels first. See [docs/local-packages-to-images.md](local-packages-to-images.md) for the full package -> image workflow and required wheel artifacts per image. -All four images are built from the repository root using the automation script: +All MCP server images are built from the repository root using the automation script: ```bash ./scripts/build-mcp-images.sh @@ -433,6 +437,8 @@ This produces images tagged as both `latest` and the version string: | `nw-smartonfhir-epic` | `nw-smartonfhir-epic:latest`, `nw-smartonfhir-epic:0.1.0` | | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner:latest`, `nw-smartonfhir-cerner:0.1.0` | | `nw-smtp` | `nw-smtp:latest`, `nw-smtp:0.1.0` | +| `nw-stripe` | `nw-stripe:latest`, `nw-stripe:0.1.0` | +| `nw-salesforce` | `nw-salesforce:latest`, `nw-salesforce:0.1.0` | | `nw-slack` | `nw-slack:latest`, `nw-slack:0.1.0` | To build a single image manually from the repo root: @@ -450,6 +456,12 @@ docker build -f docker/fhir-cerner/Dockerfile -t nw-smartonfhir-cerner:latest . # SMTP only docker build -f docker/smtp/Dockerfile -t nw-smtp:latest . +# Stripe only +docker build -f docker/stripe/Dockerfile -t nw-stripe:latest . + +# Salesforce only +docker build -f docker/salesforce/Dockerfile -t nw-salesforce:latest . + # Slack only docker build -f docker/slack/Dockerfile -t nw-slack:latest . ``` @@ -460,7 +472,7 @@ docker build -f docker/slack/Dockerfile -t nw-slack:latest . ## Run with docker-compose -`docker-compose.mcp.yml` starts all four MCP servers as stdio containers in one command. This is useful for local validation before configuring ToolHive. +`docker-compose.mcp.yml` starts all MCP servers as stdio containers in one command. This is useful for local validation before configuring ToolHive. ```bash # Ensure your .env is populated, then: @@ -515,6 +527,20 @@ thv run --name nw-smtp --transport stdio \ --secret FROM_EMAIL,target=FROM_EMAIL \ nw-smtp:latest +# Stripe +thv run --name nw-stripe --transport stdio \ + --secret STRIPE_API_KEY,target=STRIPE_API_KEY \ + nw-stripe:latest + +# Salesforce +thv run --name nw-salesforce --transport stdio \ + --secret SALESFORCE_INSTANCE_URL,target=SALESFORCE_INSTANCE_URL \ + --secret SALESFORCE_CLIENT_ID,target=SALESFORCE_CLIENT_ID \ + --secret SALESFORCE_CLIENT_SECRET,target=SALESFORCE_CLIENT_SECRET \ + --secret SALESFORCE_USERNAME,target=SALESFORCE_USERNAME \ + --secret SALESFORCE_PASSWORD,target=SALESFORCE_PASSWORD \ + nw-salesforce:latest + # Slack thv run --name nw-slack --transport stdio \ --secret SLACK_BOT_TOKEN,target=SLACK_BOT_TOKEN \ diff --git a/docs/quality-security-gates.md b/docs/quality-security-gates.md index 5e8ad97..3858c9a 100644 --- a/docs/quality-security-gates.md +++ b/docs/quality-security-gates.md @@ -39,23 +39,43 @@ Configure branch protection so pull requests cannot merge unless all required ch **Monorepo install note:** Connector packages under `packages/connectors/*` declare `node-wire-runtime>=0.1.0` as a normal PyPI dependency name. The security workflow installs `packages/runtime` from the checkout **together with** each matrix package (`pip install packages/runtime ""`) so `pip` can resolve `node-wire-runtime` without requiring a published wheel on PyPI. Locally, mirror that when auditing a single connector: `pip install packages/runtime packages/connectors/`. -## Local commands - -```bash -pip install -e ".[dev,agents]" -# Enforce the same threshold as CI (non-zero exit if any HIGH finding) -bandit -c pyproject.toml -r src --severity-level high -# Full JSON report without failing the shell (Bandit otherwise exits 1 on any finding) -bandit -c pyproject.toml -r src -f json -o bandit-report.json --exit-zero -python scripts/bandit_report_summary.py bandit-report.json -pytest tests/ -v -pre-commit install -pre-commit run --all-files -``` - -## Local Sonar scan with Docker - -After generating `coverage.xml`, run scanner from the repository root: +## Run checks locally + +```bash +# Install dev tools +pip install -e ".[dev,agents]" + +# Security gate (matches CI failure threshold) +bandit -c pyproject.toml -r src --severity-level high + +# Optional: JSON report + same summary as CI logs +bandit -c pyproject.toml -r src -f json -o bandit-report.json --exit-zero +python scripts/bandit_report_summary.py bandit-report.json + +# Tests + coverage.xml (required by SonarQube) +pytest tests/ -v +``` + +## Deterministic pytest environment + +To keep pytest collection and REST app startup deterministic, `tests/conftest.py` sets a fixed environment before imports: + +- `NW_REST_LOAD_DOTENV=false` so REST startup does not merge a repo-root `.env` over test variables. +- `NW_CONFIG_PATH=tests/fixtures/connectors_for_tests.yaml` so optional connectors outside the pytest allowlist remain `enabled: false` (for example `slack` and `salesforce`). +- `NW_ALLOWED_CONNECTORS=http_generic,smtp,stripe,google_drive,fhir_epic,fhir_cerner` so only the supported test connector set is loaded during collection. + +Do not rely on `.env` values during pytest collection. The test harness intentionally overrides them so local developer state does not affect CI or test outcomes. + +### Pre-commit + +```bash +pre-commit install +pre-commit run --all-files +``` + +## Local Sonar scan with Docker + +After generating `coverage.xml`, run scanner from the repository root: ```bash docker run --rm \ @@ -63,11 +83,20 @@ docker run --rm \ -v "G:\SPACE\node-wire:/usr/src" \ -w /usr/src \ sonarsource/sonar-scanner-cli \ - -Dsonar.host.url=http://host.docker.internal:9000 \ - -Dsonar.token=YOUR_TOKEN -``` - -## Bandit policy + -Dsonar.host.url=http://host.docker.internal:9000 \ + -Dsonar.token=YOUR_TOKEN +``` + +## SonarQube configuration + +The repository includes `sonar-project.properties` and CI expects these GitHub secrets: + +- `SONAR_HOST_URL` (example: `https://sonarqube.company.internal`) +- `SONAR_TOKEN` (project analysis token) + +For server setup and quality gate policy details, see this document's [SonarQube Community Edition setup](#sonarqube-community-edition-setup) section. + +## Bandit policy Bandit is configured in `pyproject.toml` under `[tool.bandit]`. @@ -80,7 +109,7 @@ CI splits responsibilities: 1. **JSON artifact + log summary** — `bandit ... -f json -o bandit-report.json --exit-zero` so the workflow always produces the report and runs `scripts/bandit_report_summary.py` for readable logs. Low/medium issues are visible here and in Sonar/import without failing the job. 2. **Enforcement** — `bandit ... --severity-level high` fails the job only on high-severity findings (matches branch-protection intent). -Locally, mirror CI with the commands in [Local commands](#local-commands). +Locally, mirror CI with the commands in [Run checks locally](#run-checks-locally). ### Scope From c1691a9e3c1ada50e8b1c326484fbf20c383e620 Mon Sep 17 00:00:00 2001 From: Rahul Ap Date: Thu, 21 May 2026 10:10:39 +0530 Subject: [PATCH 40/60] expose snake_case tool inputs while preserving Salesforce API aliases (#48) --- src/node_wire_runtime/manifest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/node_wire_runtime/manifest.py b/src/node_wire_runtime/manifest.py index d490c27..6e5c268 100644 --- a/src/node_wire_runtime/manifest.py +++ b/src/node_wire_runtime/manifest.py @@ -17,7 +17,7 @@ def _schema_for(model: Type[BaseModel], *, strict: bool = True) -> Dict[str, Any]: - schema = copy.deepcopy(model.model_json_schema()) + schema = copy.deepcopy(model.model_json_schema(by_alias=False)) # Remove `action` from `required`: it is always auto-injected from the tool # name by invoke_tool (run_args.setdefault("action", action)), so LLMs must # not be required to pass it. Keeping it as an optional property is fine. From 2fab69886c92aa88eb6b0fb1f73306764adcb5a4 Mon Sep 17 00:00:00 2001 From: kesav-aot Date: Thu, 21 May 2026 12:03:57 +0530 Subject: [PATCH 41/60] Add Connector Apps view and streaming timers (#47) * Add Connector Apps view and streaming timers Introduce a Connector Apps marketplace and app cards, plus improved navigation and streaming UI. - index.html: add a new Connector Apps selection card and a connector-apps-selection-view with an app-card for an External Patient Viewer and an apps back button. - app.js: include .app-card in selection handling and implement new views (connector-apps-menu, ext-patient-viewer), refined back-button logic (apps-back-btn, backSelectionBtn behavior, backToConnectorsBtn behavior), and setMode adjustments for mode-specific back labels. Insert step cards before active streaming bubbles when present. - Streaming: add an inline running timer to streaming bubbles, start/clear a timer interval during stream lifecycle, pass final elapsed time to appendStreamEndMessage, and ensure proper cleanup on errors/interrupts. - style.css: add styles for .app-card, responsive selection layout, visual tweaks for selection cards, and styling for the .stream-running-timer and apps selection view. These changes improve UX for browsing connector apps and provide visible timing/cleanup for streaming agent responses. * Update Slack scopes and bump Gemini model Remove the `conversations:open` scope from the Slack connector docs (it's no longer listed as required). Also update sample.env to use `GEMINI_MODEL=gemini-2.5-flash` instead of `gemini-2.0-flash` to reflect the newer model. * Update style.css --- docs/slack_connector.md | 1 - playground/app.js | 187 +++++++++++++++++++++++++++++++++------- playground/index.html | 57 +++++++----- playground/style.css | 113 ++++++++++++++++++------ sample.env | 2 +- 5 files changed, 281 insertions(+), 79 deletions(-) diff --git a/docs/slack_connector.md b/docs/slack_connector.md index d5ac828..e69c4a5 100644 --- a/docs/slack_connector.md +++ b/docs/slack_connector.md @@ -33,7 +33,6 @@ The Slack connector uses a **Bot User OAuth Token** to interact with your worksp - `chat:write` — Send messages to channels and DMs. - `files:write` — Upload and share files. - `im:write` — Start direct messages with users. - - `conversations:open` — Resolve User IDs to DM channel IDs. - `groups:read` (optional) — If you need to post to private channels the bot is invited to. - `channels:read` (optional) — If you need to resolve channel names. diff --git a/playground/app.js b/playground/app.js index 0dd1602..b2c4952 100644 --- a/playground/app.js +++ b/playground/app.js @@ -291,7 +291,7 @@ document.addEventListener('DOMContentLoaded', () => { } const rootSelectionView = document.getElementById('root-selection-view'); - const selectionCards = document.querySelectorAll('.selection-card'); + const selectionCards = document.querySelectorAll('.selection-card, .app-card'); const rootTabContainer = document.querySelector('.root-tab-container'); const backToHomeBtn = document.getElementById('back-to-home'); @@ -311,28 +311,79 @@ document.addEventListener('DOMContentLoaded', () => { selectionCards.forEach(card => { card.addEventListener('click', () => { const view = card.dataset.target; - rootSelectionView.classList.add('hidden'); - layoutMain.classList.remove('hidden'); - headerActions.classList.remove('hidden'); if (view === 'agent') { + rootSelectionView.classList.add('hidden'); + layoutMain.classList.remove('hidden'); + headerActions.classList.remove('hidden'); agentPanel.classList.remove('hidden'); connectorsView.classList.add('hidden'); layoutMain.classList.add('agent-mode'); connectorStatus.textContent = 'AI Agent Online'; tagline.textContent = 'Autonomous Healthcare Assistant'; document.documentElement.style.setProperty('--brand-accent', '#8b5cf6'); + if (backSelectionBtn) { + backSelectionBtn.classList.remove('hidden'); + backSelectionBtn.innerHTML = ` + + Back to Workspace + `; + } log('Switched to AI Agent mode (MCP + LLM)', 'system'); + } else if (view === 'connector-apps-menu') { + rootSelectionView.classList.add('hidden'); + document.getElementById('connector-apps-selection-view').classList.remove('hidden'); + layoutMain.classList.add('hidden'); + headerActions.classList.remove('hidden'); + connectorStatus.textContent = 'Apps Marketplace'; + tagline.textContent = 'Ready-to-use experiences built on top of connectors'; + document.documentElement.style.setProperty('--brand-accent', '#0d9488'); + if (backSelectionBtn) { + backSelectionBtn.classList.remove('hidden'); + backSelectionBtn.innerHTML = ` + + Back to Workspace + `; + } + log('Opened Connector Apps menu', 'system'); + } else if (view === 'ext-patient-viewer') { + document.getElementById('connector-apps-selection-view').classList.add('hidden'); + rootSelectionView.classList.add('hidden'); + layoutMain.classList.remove('hidden'); + headerActions.classList.remove('hidden'); + agentPanel.classList.add('hidden'); + connectorsView.classList.remove('hidden'); + layoutMain.classList.remove('agent-mode'); + connectorsListPanel.classList.add('hidden'); + playgroundView.classList.remove('hidden'); + if (backToConnectorsBtn) backToConnectorsBtn.classList.add('hidden'); + if (backSelectionBtn) { + backSelectionBtn.classList.remove('hidden'); + backSelectionBtn.innerHTML = ` + + Back to Apps + `; + } + setMode('ext-patient-viewer'); } else { + rootSelectionView.classList.add('hidden'); + layoutMain.classList.remove('hidden'); + headerActions.classList.remove('hidden'); agentPanel.classList.add('hidden'); connectorsView.classList.remove('hidden'); layoutMain.classList.remove('agent-mode'); connectorsListPanel.classList.remove('hidden'); playgroundView.classList.add('hidden'); - connectorStatus.textContent = 'Connectors Ready'; tagline.textContent = 'Enterprise Integration Suite'; document.documentElement.style.setProperty('--brand-accent', '#2563eb'); + if (backSelectionBtn) { + backSelectionBtn.classList.remove('hidden'); + backSelectionBtn.innerHTML = ` + + Back to Workspace + `; + } log('Switched to Connectors view', 'system'); } }); @@ -341,13 +392,38 @@ document.addEventListener('DOMContentLoaded', () => { const returnToHome = (e) => { if (e) e.preventDefault(); rootSelectionView.classList.remove('hidden'); + document.getElementById('connector-apps-selection-view').classList.add('hidden'); layoutMain.classList.add('hidden'); headerActions.classList.add('hidden'); + tagline.textContent = 'Autonomous Connector Orchestration Platform'; log('Returned to main selection screen', 'system'); }; backToHomeBtn.addEventListener('click', returnToHome); - if (backSelectionBtn) backSelectionBtn.addEventListener('click', returnToHome); + + if (backSelectionBtn) { + backSelectionBtn.addEventListener('click', (e) => { + if (e) e.preventDefault(); + const btnText = backSelectionBtn.textContent.trim(); + if (btnText.includes('Back to Apps')) { + // Return from patient viewer to apps marketplace directory + layoutMain.classList.add('hidden'); + document.getElementById('connector-apps-selection-view').classList.remove('hidden'); + connectorStatus.textContent = 'Apps Marketplace'; + tagline.textContent = 'Ready-to-use experiences built on top of connectors'; + document.documentElement.style.setProperty('--brand-accent', '#0d9488'); + log('Returned to Apps Marketplace directory', 'system'); + } else { + // Return from other direct pages to workspace selection home + returnToHome(e); + } + }); + } + + const appsBackBtn = document.getElementById('apps-back-btn'); + if (appsBackBtn) { + appsBackBtn.addEventListener('click', returnToHome); + } const newChatBtn = document.getElementById('new-chat-btn'); newChatBtn.addEventListener('click', () => { @@ -518,6 +594,20 @@ document.addEventListener('DOMContentLoaded', () => { function setMode(mode) { currentMode = mode; + if (backToConnectorsBtn) { + if (mode === 'ext-patient-viewer') { + backToConnectorsBtn.innerHTML = ` + + Back to Workspace + `; + } else { + backToConnectorsBtn.innerHTML = ` + + Back to All Connectors + `; + } + } + // Hide all panels first ehrPanel.classList.add('hidden'); itopsPanel.classList.add('hidden'); @@ -632,19 +722,24 @@ document.addEventListener('DOMContentLoaded', () => { connectorsListPanel.classList.add('hidden'); playgroundView.classList.remove('hidden'); if (backSelectionBtn) backSelectionBtn.classList.add('hidden'); + if (backToConnectorsBtn) backToConnectorsBtn.classList.remove('hidden'); setMode(mode); }); }); // Back to Connectors List backToConnectorsBtn.addEventListener('click', () => { - playgroundView.classList.add('hidden'); - connectorsListPanel.classList.remove('hidden'); - if (backSelectionBtn) backSelectionBtn.classList.remove('hidden'); - connectorStatus.textContent = 'Connectors Ready'; - tagline.textContent = 'Enterprise Integration Suite'; - document.documentElement.style.setProperty('--brand-accent', '#2563eb'); - log('Returned to Connectors list', 'system'); + if (currentMode === 'ext-patient-viewer') { + returnToHome(); + } else { + playgroundView.classList.add('hidden'); + connectorsListPanel.classList.remove('hidden'); + if (backSelectionBtn) backSelectionBtn.classList.remove('hidden'); + connectorStatus.textContent = 'Connectors Ready'; + tagline.textContent = 'Enterprise Integration Suite'; + document.documentElement.style.setProperty('--brand-accent', '#2563eb'); + log('Returned to Connectors list', 'system'); + } }); // Google Drive Sub-mode Switching @@ -1281,7 +1376,7 @@ document.addEventListener('DOMContentLoaded', () => { - Streaming response... + Streaming response... 0.0s `; @@ -1291,6 +1386,7 @@ document.addEventListener('DOMContentLoaded', () => { bubble, text: bubble.querySelector('.streaming-text'), loader: bubble.querySelector('.stream-tail-loader'), + timer: bubble.querySelector('.stream-running-timer') }; } @@ -1304,10 +1400,14 @@ document.addEventListener('DOMContentLoaded', () => { agentChatHistory.scrollTop = agentChatHistory.scrollHeight; } - function appendStreamEndMessage(message, success = true) { + function appendStreamEndMessage(message, success = true, finalTime = null) { const end = document.createElement('div'); end.className = `stream-end-message ${success ? 'success' : 'error'}`; - end.textContent = message || (success ? 'Streaming completed.' : 'Streaming ended with an error.'); + let displayMessage = message || (success ? 'Streaming completed.' : 'Streaming ended with an error.'); + if (finalTime) { + displayMessage += ` (Total Time: ${finalTime}s)`; + } + end.textContent = displayMessage; agentChatHistory.appendChild(end); agentChatHistory.scrollTop = agentChatHistory.scrollHeight; } @@ -1380,7 +1480,13 @@ document.addEventListener('DOMContentLoaded', () => {
${escapeHTML(argsStr)}
${resultPreview ? `
${resultIcon} ${escapeHTML(resultPreview)}
` : ''} `; - agentChatHistory.appendChild(card); + + const streamingBubble = agentChatHistory.querySelector('.streaming-bubble'); + if (streamingBubble) { + agentChatHistory.insertBefore(card, streamingBubble); + } else { + agentChatHistory.appendChild(card); + } agentChatHistory.scrollTop = agentChatHistory.scrollHeight; } @@ -1430,8 +1536,27 @@ document.addEventListener('DOMContentLoaded', () => { log(`Agent Chat: Sending message...`, 'system'); + let timerInterval = null; + let streamView = null; + const startTime = Date.now(); + try { if (agentTransportMode === 'streamable-http') { + // Instantly display the streaming bubble and start the active timer + streamView = appendStreamingBubble(); + agentTyping.classList.add('hidden'); // Hide generic typing dot loader + + function startRunningTimer() { + if (timerInterval) return; + timerInterval = setInterval(() => { + if (streamView && streamView.timer) { + const elapsed = ((Date.now() - startTime) / 1000).toFixed(1); + streamView.timer.textContent = `${elapsed}s`; + } + }, 100); + } + startRunningTimer(); + const response = await fetch('/scenarios/agent-chat-stream', { method: 'POST', headers: { 'Content-Type': 'application/json' }, @@ -1447,7 +1572,6 @@ document.addEventListener('DOMContentLoaded', () => { let traceId = ''; let success = true; let doneMessage = ''; - let streamView = null; await readNdjsonStream(response, { meta: (event) => { @@ -1463,16 +1587,9 @@ document.addEventListener('DOMContentLoaded', () => { args: event.args || {}, result: event.result || '' }); - if (!streamView) { - streamView = appendStreamingBubble(); - } else { - agentChatHistory.appendChild(streamView.bubble); - agentChatHistory.scrollTop = agentChatHistory.scrollHeight; - } }, final_chunk: (event) => { agentTyping.classList.add('hidden'); - if (!streamView) streamView = appendStreamingBubble(); finalText += event.content || ''; streamView.text.textContent = finalText; agentChatHistory.scrollTop = agentChatHistory.scrollHeight; @@ -1480,7 +1597,6 @@ document.addEventListener('DOMContentLoaded', () => { error: (event) => { success = false; agentTyping.classList.add('hidden'); - if (!streamView) streamView = appendStreamingBubble(); finalText += event.message || ''; streamView.text.textContent = finalText; }, @@ -1488,18 +1604,27 @@ document.addEventListener('DOMContentLoaded', () => { traceId = event.trace_id || traceId; success = Boolean(event.success); doneMessage = event.message || `Streaming ${success ? 'completed' : 'failed'}. trace_id=${traceId}`; - if (!streamView) streamView = appendStreamingBubble(); + + if (timerInterval) { + clearInterval(timerInterval); + timerInterval = null; + } + const finalElapsed = ((Date.now() - startTime) / 1000).toFixed(2); streamView.loader.classList.add('hidden'); - appendStreamEndMessage(doneMessage, success); + appendStreamEndMessage(doneMessage, success, finalElapsed); } }); agentTyping.classList.add('hidden'); if (!doneMessage) { - if (!streamView) streamView = appendStreamingBubble(); + if (timerInterval) { + clearInterval(timerInterval); + timerInterval = null; + } + const finalElapsed = ((Date.now() - startTime) / 1000).toFixed(2); streamView.loader.classList.add('hidden'); doneMessage = `Streaming connection closed before done event. trace_id=${traceId || 'unknown'}`; - appendStreamEndMessage(doneMessage, false); + appendStreamEndMessage(doneMessage, false, finalElapsed); success = false; } if (!finalText) { @@ -1545,6 +1670,8 @@ document.addEventListener('DOMContentLoaded', () => { log(`Agent Chat: ${data.success ? 'Success' : 'Responded'} | steps=${data.steps ? data.steps.length : 0}`, data.success ? 'success' : 'system'); } catch (error) { + if (timerInterval) clearInterval(timerInterval); + if (streamView && streamView.loader) streamView.loader.classList.add('hidden'); agentTyping.classList.add('hidden'); appendChatBubble('assistant', `Sorry, I couldn't reach the server: ${error.message}. Please check that the backend is running.`); log(`Agent Chat Error: ${error.message}`, 'error'); diff --git a/playground/index.html b/playground/index.html index 84bcec4..0882d12 100644 --- a/playground/index.html +++ b/playground/index.html @@ -88,6 +88,43 @@

Connectors

+ +
+
+
+ +
+
+

Connector Apps

+

Ready-to-use experiences built on top of connectors

+
+
+ +
+
+
+ + + + @@ -249,26 +286,6 @@

Salesforce

Lead and contact management for CRM-driven enterprise workflows.

- -
-
- - - - - - - - - - - -
-
-

External Patient Viewer

-

Read-only on-demand retrieval of demographics, encounters, and documents from source EHR.

-
-
diff --git a/playground/style.css b/playground/style.css index c8441e7..934f763 100644 --- a/playground/style.css +++ b/playground/style.css @@ -923,7 +923,7 @@ textarea:focus { margin-top: 1.5rem; } -.connector-card { +.connector-card, .app-card { background: var(--card-bg); backdrop-filter: blur(12px); border: 1px solid var(--border); @@ -939,7 +939,7 @@ textarea:focus { overflow: hidden; } -.connector-card:hover { +.connector-card:hover, .app-card:hover { transform: translateY(-5px); border-color: var(--brand-accent); box-shadow: 0 15px 30px rgba(0, 0, 0, 0.06); @@ -1427,6 +1427,17 @@ input[type="range"]::-webkit-slider-thumb:hover { font-weight: 600; } +.stream-running-timer { + font-family: monospace; + font-size: 0.8rem; + font-weight: 700; + color: var(--brand-accent); + background: rgba(37, 99, 235, 0.08); + padding: 0.1rem 0.4rem; + border-radius: 4px; + margin-left: 0.25rem; +} + .stream-end-message { align-self: flex-start; max-width: 85%; @@ -1636,13 +1647,22 @@ input[type="range"]::-webkit-slider-thumb:hover { flex-direction: column; align-items: center; justify-content: center; - min-height: 80vh; - padding: 2rem; - gap: 4rem; + min-height: calc(100vh - 120px); + padding: 1.5rem 0 3rem; + gap: 2.5rem; /* background: radial-gradient(circle at 10% 20%, rgba(139, 92, 246, 0.05) 0%, transparent 40%), */ /* radial-gradient(circle at 90% 80%, rgba(37, 99, 235, 0.05) 0%, transparent 40%); */ } +.apps-selection-view { + justify-content: flex-start; + padding-top: 4rem; +} + +.apps-selection-view .selection-grid { + justify-content: flex-start; +} + .selection-welcome h1 { font-size: 3.5rem; font-weight: 700; @@ -1653,22 +1673,25 @@ input[type="range"]::-webkit-slider-thumb:hover { .selection-grid { display: grid; - grid-template-columns: repeat(2, 1fr); - gap: 3rem; + grid-template-columns: repeat(3, minmax(280px, 1fr)); + align-items: stretch; + gap: 2rem; width: 100%; - max-width: 900px; + max-width: 1240px; } .selection-card { + width: 100%; background: white; border: 1px solid #e2e8f0; - border-radius: 20px; + border-radius: 1.8rem; cursor: pointer; transition: all 0.5s cubic-bezier(0.4, 0, 0.2, 1); - box-shadow: 0 20px 50px rgba(0, 0, 0, 0.04); + box-shadow: 0 18px 48px rgba(148, 163, 184, 0.18); position: relative; overflow: hidden; display: flex; + min-height: 356px; } .card-inner { @@ -1676,53 +1699,72 @@ input[type="range"]::-webkit-slider-thumb:hover { display: flex; flex-direction: column; align-items: center; - padding: 3rem 2rem 0; + justify-content: center; + padding: 2.2rem 2rem 2rem; } .selection-card:hover { transform: translateY(-8px); - box-shadow: 0 40px 80px rgba(0, 0, 0, 0.08); + box-shadow: 0 30px 70px rgba(148, 163, 184, 0.28); } .selection-icon { - width: 80px; - height: 80px; + width: 64px; + height: 64px; border-radius: 50%; display: flex; align-items: center; justify-content: center; - margin-bottom: 2rem; + margin-bottom: 1.25rem; transition: all 0.3s; } -.card-mcp .selection-icon { background: #f5f3ff; color: #8b5cf6; } -.card-connectors .selection-icon { background: #eff6ff; color: #2563eb; } +.card-mcp .selection-icon { background: #f3e8ff; color: #7c3aed; } +.card-connectors .selection-icon { background: #e0f2fe; color: #0284c7; } +.card-ext-viewer .selection-icon, +.card-apps-directory .selection-icon { background: #ccfbf1; color: #0d9488; } + +.selection-details { + text-align: center; + width: 100%; + display: flex; + flex-direction: column; + align-items: center; +} .selection-details h3 { - font-size: 1.75rem; + font-size: 1.5rem; font-family: "Outfit", sans-serif; color: #1e293b; - margin-bottom: 0.75rem; + margin-bottom: 0.7rem; + line-height: 1.2; } .selection-details p { - font-size: 1rem; + font-size: 0.92rem; color: #64748b; - margin-bottom: 2.5rem; /* Space before action bar */ + margin-bottom: 1.4rem; + line-height: 1.4; + max-width: 310px; + min-height: 2.6em; + padding: 0 0.5rem; } .action-bar { - width: calc(100% + 4rem); - margin: 0 -2rem; - height: 80px; + width: 100%; + max-width: 240px; + height: 60px; display: flex; align-items: center; justify-content: center; transition: all 0.3s; + margin-top: auto; } -.card-mcp .action-bar { background: #f5f3ff; color: #8b5cf6; } -.card-connectors .action-bar { background: #eff6ff; color: #2563eb; } +.card-mcp .action-bar { background: #f3e8ff; color: #7c3aed; } +.card-connectors .action-bar { background: #e0f2fe; color: #0284c7; } +.card-ext-viewer .action-bar, +.card-apps-directory .action-bar { background: #ccfbf1; color: #0d9488; } .selection-card:hover .action-bar { filter: brightness(0.95); @@ -1742,9 +1784,26 @@ input[type="range"]::-webkit-slider-thumb:hover { .card-mcp:hover { border-color: #8b5cf6; } .card-connectors { border: 2px solid transparent; } .card-connectors:hover { border-color: #2563eb; } +.card-ext-viewer, +.card-apps-directory { border: 2px solid transparent; } +.card-ext-viewer:hover, +.card-apps-directory:hover { border-color: #0d9488; } @media (max-width: 900px) { - .selection-grid { grid-template-columns: 1fr; } + .root-selection-view { + min-height: auto; + padding-top: 1rem; + } + + .selection-grid { + grid-template-columns: 1fr; + max-width: 420px; + } + + .selection-card { + min-height: 332px; + } + .selection-welcome h1 { font-size: 2.5rem; } } diff --git a/sample.env b/sample.env index 6a82c10..00383f6 100644 --- a/sample.env +++ b/sample.env @@ -83,7 +83,7 @@ OPENAI_MODEL=gpt-4o-mini # Google Gemini (optional) GEMINI_API_KEY=your-gemini-api-key -GEMINI_MODEL=gemini-2.0-flash +GEMINI_MODEL=gemini-2.5-flash # Anthropic / Claude (optional) ANTHROPIC_API_KEY=your-anthropic-api-key From babce8f0dd155a4d97b0d4a0b866f8e08a05c6b7 Mon Sep 17 00:00:00 2001 From: rahul-aot Date: Fri, 22 May 2026 14:35:30 +0530 Subject: [PATCH 42/60] Add integration tests for Google Drive connector and enhance playground navigation --- .github/workflows/lint.yml | 6 + .github/workflows/pytest.yml | 56 +++- .github/workflows/quality-gates.yml | 6 + .github/workflows/security-pr.yml | 7 +- pyproject.toml | 2 + sample.env | 4 + tests/playground/conftest.py | 79 +++++ tests/playground/gdrive/README.md | 86 ++++++ tests/playground/gdrive/conftest.py | 70 +++++ tests/playground/gdrive/gdrive_page.py | 140 +++++++++ .../gdrive/test_gdrive_integration.py | 275 ++++++++++++++++++ tests/playground/home_page.py | 52 ++++ .../playground/test_playground_integration.py | 72 +++++ 13 files changed, 853 insertions(+), 2 deletions(-) create mode 100644 tests/playground/conftest.py create mode 100644 tests/playground/gdrive/README.md create mode 100644 tests/playground/gdrive/conftest.py create mode 100644 tests/playground/gdrive/gdrive_page.py create mode 100644 tests/playground/gdrive/test_gdrive_integration.py create mode 100644 tests/playground/home_page.py create mode 100644 tests/playground/test_playground_integration.py diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 90c2680..2c3ae51 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,3 +1,9 @@ +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## + + name: Lint and Type Check on: diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index b2a258b..ecf25f8 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -1,3 +1,9 @@ + +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## + name: CI – Pytest on: @@ -45,4 +51,52 @@ jobs: with: name: coverage-html-py${{ matrix.python-version }} path: htmlcov/ - if-no-files-found: ignore \ No newline at end of file + if-no-files-found: ignore + + # ── Playground integration tests ────────────────────────────────────────── + # Runs only on manual workflow_dispatch. Requires real Google Drive + # credentials stored as repository secrets (see tests/playground/README.md). + playground-integration: + name: Playground integration tests + runs-on: ubuntu-latest + if: github.event_name == 'workflow_dispatch' + + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + cache-dependency-glob: | + pyproject.toml + + - name: Install dependencies + run: uv sync --all-extras --dev + + - name: Install Playwright browsers + run: uv run python -m playwright install chromium --with-deps + + - name: Run playground integration tests + env: + GOOGLE_DRIVE_SA_JSON: ${{ secrets.GOOGLE_DRIVE_SA_JSON }} + GOOGLE_DRIVE_FOLDER_ID: ${{ secrets.GOOGLE_DRIVE_FOLDER_ID }} + GDRIVE_TEST_RECIPIENT_EMAIL: ${{ secrets.GDRIVE_TEST_RECIPIENT_EMAIL }} + NW_REST_AUTH_DISABLED: "true" + NW_REST_LOAD_DOTENV: "false" + NW_ALLOWED_CONNECTORS: "google_drive" + run: uv run pytest tests/playground/ --no-cov -v + + - name: Upload Playwright traces on failure + if: failure() + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + with: + name: playwright-traces + path: test-results/ + if-no-files-found: ignore diff --git a/.github/workflows/quality-gates.yml b/.github/workflows/quality-gates.yml index d3c079b..524fc09 100644 --- a/.github/workflows/quality-gates.yml +++ b/.github/workflows/quality-gates.yml @@ -1,3 +1,9 @@ + +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## + name: Quality gates on: diff --git a/.github/workflows/security-pr.yml b/.github/workflows/security-pr.yml index 22cf5cb..264ef7f 100644 --- a/.github/workflows/security-pr.yml +++ b/.github/workflows/security-pr.yml @@ -1,4 +1,9 @@ -# Continuous security checks for publishable Python packages on pull requests. +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +### Continuous security checks for publishable Python packages on pull requests. + + name: Python package security PR checks on: diff --git a/pyproject.toml b/pyproject.toml index 273416d..2784685 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ dev = [ "mypy>=1.9.0", "bandit[toml]>=1.7.9", "pre-commit>=4.0.0", + "pytest-playwright>=0.4.0", ] [tool.uv] @@ -107,6 +108,7 @@ addopts = [ "--cov-report=term-missing", "--cov-report=html:htmlcov", "--cov-report=xml:coverage.xml", + "--ignore=tests/playground", ] [tool.coverage.run] diff --git a/sample.env b/sample.env index 6a82c10..aa8bd40 100644 --- a/sample.env +++ b/sample.env @@ -144,3 +144,7 @@ SALESFORCE_TOKEN_URL=https://login.salesforce.com/services/oauth2/token SALESFORCE_CLIENT_ID=your-client-id SALESFORCE_CLIENT_SECRET=your-client-secret SALESFORCE_REFRESH_TOKEN=your-refresh-token + + +# Playwright playground headed execution - set to "true" to view the browser and its activities +PLAYGROUND_HEADED=false diff --git a/tests/playground/conftest.py b/tests/playground/conftest.py new file mode 100644 index 0000000..caabf0a --- /dev/null +++ b/tests/playground/conftest.py @@ -0,0 +1,79 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import os +import socket +import threading +import time + +from dotenv import load_dotenv +import httpx +import pytest + +# Load .env before any app imports so connectors initialise with real credentials. +load_dotenv(override=False) + + +@pytest.fixture(scope="session") +def browser_type_launch_args(browser_type_launch_args): + """Override Playwright launch arguments dynamically via environment variables.""" + env_val = ( + os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") or os.getenv("PLAYWRIGHT_HEADLESS") + ) + is_headed = False + if env_val: + env_val_lower = env_val.lower().strip() + if env_val_lower in ("true", "1", "yes"): + is_headed = True + elif env_val_lower in ("false", "0", "no") and os.getenv("PLAYWRIGHT_HEADLESS"): + is_headed = True + return {**browser_type_launch_args, "headless": not is_headed} + + +@pytest.fixture(scope="session") +def api_server_url(): + """Start the real FastAPI server on a free port and yield its base URL. + + The playground UI is served at /playground/ and the scenarios API at + /scenarios/*, so browser fetch() calls with relative paths resolve + correctly without any Playwright route interception. + """ + import uvicorn # noqa: PLC0415 + from bindings.rest_api.app import app as rest_app # noqa: PLC0415 + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + + config = uvicorn.Config(rest_app, host="127.0.0.1", port=port, log_level="error") + server = uvicorn.Server(config) + + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + base = f"http://127.0.0.1:{port}" + with httpx.Client(timeout=2) as probe: + for _ in range(60): + try: + probe.get(f"{base}/health") + break + except Exception: + time.sleep(0.3) + else: + pytest.fail("FastAPI server did not start within 18 seconds") + + yield base + + server.should_exit = True + thread.join(timeout=5) + + +@pytest.fixture +def playground_page(page, api_server_url: str): + """Navigate to the playground served by the real FastAPI server.""" + page.goto(f"{api_server_url}/playground/") + page.wait_for_load_state("domcontentloaded") + return page diff --git a/tests/playground/gdrive/README.md b/tests/playground/gdrive/README.md new file mode 100644 index 0000000..fe68471 --- /dev/null +++ b/tests/playground/gdrive/README.md @@ -0,0 +1,86 @@ + + +# Google Drive Playground Integration Tests + +End-to-end Playwright tests that open the Playground UI in a real browser, +navigate to the Google Drive connector panel, and assert on the rendered +pipeline state. No mocking — every test hits the real Google Drive API. + +## What is tested + +| Test | Action | +|------|--------| +| `test_gdrive_list_files_default_page_size` | `files.list` — default page size | +| `test_gdrive_list_files_explicit_page_size` | `files.list` — explicit page_size=5 | +| `test_gdrive_list_files_with_query` | `files.list` — mimeType filter | +| `test_gdrive_get_file` | `files.get` — valid file ID with field mask | +| `test_gdrive_get_file_without_fields` | `files.get` — no fields mask | +| `test_gdrive_get_file_invalid_id` | `files.get` — nonexistent ID, expects error state | +| `test_gdrive_update_file_name` | `files.update` — rename file | +| `test_gdrive_update_file_name_and_mime` | `files.update` — rename + mime_type | +| `test_gdrive_upload_file` | `files.upload` — attach file, fill recipient, assert 4-step pipeline | +| `test_gdrive_upload_remove_and_reattach` | `files.upload` — remove attachment UI, re-attach | +| `test_gdrive_switch_list_then_get` | Cross-action switch on same page | + +## How it works + +The test session starts a real FastAPI server on a random local port. Playwright +navigates to `/playground/`. The browser's `fetch("/scenarios/gdrive-archival")` +calls route to the real backend, which calls the real Google Drive API. +No `page.route()` interception. + +## Running locally + +```bash +# Install Playwright browsers (once) +uv run python -m playwright install chromium + +# Run all GDrive tests +uv run pytest tests/playground/gdrive/ --no-cov -v + +# Run headed (watch the browser) +PLAYGROUND_HEADED=true uv run pytest tests/playground/gdrive/ --no-cov -v -s +``` + +> **Note:** GDrive tests are excluded from the default `uv run pytest` run and +> from regular CI (push/PR). They must be triggered explicitly. + +## Required environment variables + +Set these before running (`.env` is loaded automatically if present): + +| Variable | Description | +|----------|-------------| +| `GOOGLE_DRIVE_SA_JSON` | Service-account JSON (path to file or full JSON string inline) | +| `GOOGLE_DRIVE_FOLDER_ID` | Google Drive folder ID where test files are uploaded | +| `GDRIVE_TEST_RECIPIENT_EMAIL` | Sharing recipient email for upload tests (default: `test@mailinator.com`) | +| `NW_REST_AUTH_DISABLED` | Set to `true` to skip REST auth middleware | + +## CI / GitHub Actions + +GDrive tests run **only on manual `workflow_dispatch`** trigger +(`Actions → CI – Pytest → Run workflow`). + +Credentials are read from repository secrets: + +| Secret | Maps to env var | +|--------|----------------| +| `GOOGLE_DRIVE_SA_JSON` | `GOOGLE_DRIVE_SA_JSON` | +| `GOOGLE_DRIVE_FOLDER_ID` | `GOOGLE_DRIVE_FOLDER_ID` | +| `GDRIVE_TEST_RECIPIENT_EMAIL` | `GDRIVE_TEST_RECIPIENT_EMAIL` | + +Set these under **Settings → Secrets and variables → Actions** before triggering +the workflow. + +## Test data and cleanup + +The `uploaded_test_file_id` session fixture uploads a small file +(`nw-integration-test.txt`) to Google Drive once per test session. This file +is **not automatically deleted** after the tests finish — clean it up manually +via the Google Drive UI if needed. + +The `files.update` tests rename this file but do not delete it. diff --git a/tests/playground/gdrive/conftest.py b/tests/playground/gdrive/conftest.py new file mode 100644 index 0000000..6214692 --- /dev/null +++ b/tests/playground/gdrive/conftest.py @@ -0,0 +1,70 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import base64 +import os + +import httpx +import pytest + +_TEST_RECIPIENT_EMAIL = os.environ.get( + "GDRIVE_TEST_RECIPIENT_EMAIL", "test@mailinator.com" +) + + +@pytest.fixture(scope="session") +def real_gdrive_file_id(api_server_url: str) -> str: + """Return a real Google Drive file ID by listing the Drive via the API. + + Skips the test if no files exist in the configured Drive folder. + """ + with httpx.Client(timeout=30) as client: + resp = client.post( + f"{api_server_url}/scenarios/gdrive-archival", + json={"action": "files.list", "list_page_size": 5}, + ) + resp.raise_for_status() + data = resp.json() + files = data.get("steps", [{}])[0].get("data", {}).get("raw", {}).get("files", []) + if not files: + pytest.skip("No files found in Google Drive — skipping tests that need a real file ID") + return files[0]["id"] + + +@pytest.fixture(scope="session") +def uploaded_test_file_id(api_server_url: str) -> str: + """Upload a small test file to Google Drive once per session and return its ID. + + Used by files.update tests so they operate on a disposable file. + Note: the file is left in Google Drive after the session (manual cleanup needed). + """ + content = b"node-wire integration test file - safe to delete" + with httpx.Client(timeout=60) as client: + resp = client.post( + f"{api_server_url}/scenarios/gdrive-archival", + json={ + "action": "files.upload", + "document_name": "nw-integration-test.txt", + "recipient_email": _TEST_RECIPIENT_EMAIL, + "file_base64": base64.b64encode(content).decode(), + "file_mime_type": "text/plain", + }, + ) + resp.raise_for_status() + data = resp.json() + file_id = data.get("final_resource_id") + if not file_id: + pytest.skip( + f"Setup upload failed — cannot run update tests. " + f"Error: {data.get('error_message') or 'no file_id returned'}" + ) + return file_id + + +@pytest.fixture(scope="session") +def test_recipient_email() -> str: + """Email address used as the sharing recipient in upload tests.""" + return _TEST_RECIPIENT_EMAIL diff --git a/tests/playground/gdrive/gdrive_page.py b/tests/playground/gdrive/gdrive_page.py new file mode 100644 index 0000000..46ae6de --- /dev/null +++ b/tests/playground/gdrive/gdrive_page.py @@ -0,0 +1,140 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class GoogleDrivePage: + """Page Object Model for the Google Drive connector panel in the Playground.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Selector for the Google Drive card inside system connectors view + self.connector_card: Locator = page.locator(".connector-card[data-mode='gdrive']") + + # Panel root and main headers + self.panel: Locator = page.locator("#gdrive-panel") + self.title: Locator = page.locator("#gdrive-panel .card-title h2") + self.action_select: Locator = page.locator("#gdrive-action-select") + self.run_btn: Locator = page.locator("#gdrive-run-btn") + self.back_to_connectors: Locator = page.locator("#back-to-connectors") + + # --- files.upload action elements --- + self.upload_section: Locator = page.locator("#gdrive-upload-only") + self.recipient_email: Locator = page.locator( + "#gdrive-upload-only input[name='recipient_email']" + ) + self.doc_name_group: Locator = page.locator("#gdrive-doc-name-group") + self.document_name: Locator = page.locator( + "#gdrive-doc-name-group input[name='document_name']" + ) + self.file_section: Locator = page.locator("#gdrive-file-section") + self.file_input: Locator = page.locator("#gdrive-file") + self.file_drop_zone: Locator = page.locator("#file-drop-zone") + self.file_chosen_preview: Locator = page.locator("#file-chosen-preview") + self.preview_name: Locator = page.locator("#file-chosen-preview .preview-name") + self.remove_file_btn: Locator = page.locator("#file-chosen-preview .remove-file-btn") + + # --- files.get action elements --- + self.get_section: Locator = page.locator("#gdrive-get-only") + self.get_file_id: Locator = page.locator("#gdrive-get-only input[name='get_file_id']") + self.get_fields: Locator = page.locator("#gdrive-get-only input[name='get_fields']") + + # --- files.update action elements --- + self.update_section: Locator = page.locator("#gdrive-update-only") + self.update_file_id: Locator = page.locator( + "#gdrive-update-only input[name='update_file_id']" + ) + self.update_name: Locator = page.locator("#gdrive-update-only input[name='update_name']") + self.update_mime_type: Locator = page.locator( + "#gdrive-update-only input[name='update_mime_type']" + ) + self.update_add_parents: Locator = page.locator( + "#gdrive-update-only input[name='update_add_parents']" + ) + self.update_remove_parents: Locator = page.locator( + "#gdrive-update-only input[name='update_remove_parents']" + ) + + # --- files.list action elements --- + self.list_section: Locator = page.locator("#gdrive-list-only") + self.list_page_size: Locator = page.locator( + "#gdrive-list-only input[name='list_page_size']" + ) + self.list_query: Locator = page.locator("#gdrive-list-only input[name='list_query']") + self.list_fields: Locator = page.locator("#gdrive-list-only input[name='list_fields']") + + # --- Output and Logs elements --- + self.pipeline_steps: Locator = page.locator(".flow-node") + self.step_nodes: list[Locator] = [page.locator(f"#step-{i}") for i in range(4)] + self.final_result: Locator = page.locator("#final-result") + self.summary_text: Locator = page.locator("#human-summary") + self.result_tag: Locator = page.locator("#result-id") + self.log_terminal: Locator = page.locator("#log-terminal") + + def navigate_to_panel(self) -> None: + """Click the Google Drive card in system connectors to open the panel.""" + self.connector_card.click() + + def select_action(self, action: str) -> None: + """Change the action via the select element.""" + self.action_select.select_option(action) + + def fill_upload_fields(self, recipient_email: str, doc_name: str | None = None) -> None: + """Fill upload parameters.""" + self.recipient_email.fill(recipient_email) + if doc_name is not None: + # First ensure doc_name field is shown by switching to Write Note/sub-mode if needed, + # or fill directly if exposed. + self.document_name.fill(doc_name) + + def fill_get_fields(self, file_id: str, fields: str | None = None) -> None: + """Fill get parameters.""" + self.get_file_id.fill(file_id) + if fields is not None: + self.get_fields.fill(fields) + + def fill_update_fields( + self, + file_id: str, + new_name: str | None = None, + mime_type: str | None = None, + add_parents: str | None = None, + remove_parents: str | None = None, + ) -> None: + """Fill update parameters.""" + self.update_file_id.fill(file_id) + if new_name is not None: + self.update_name.fill(new_name) + if mime_type is not None: + self.update_mime_type.fill(mime_type) + if add_parents is not None: + self.update_add_parents.fill(add_parents) + if remove_parents is not None: + self.update_remove_parents.fill(remove_parents) + + def fill_list_fields( + self, + page_size: int | None = None, + query: str | None = None, + fields: str | None = None, + ) -> None: + """Fill list parameters.""" + if page_size is not None: + self.list_page_size.fill(str(page_size)) + if query is not None: + self.list_query.fill(query) + if fields is not None: + self.list_fields.fill(fields) + + def submit(self) -> None: + """Submit the form to execute the archival/orchestration workflow.""" + self.run_btn.click() + + def go_back(self) -> None: + """Click 'Back to All Connectors' to return to connectors selection view.""" + self.back_to_connectors.click() diff --git a/tests/playground/gdrive/test_gdrive_integration.py b/tests/playground/gdrive/test_gdrive_integration.py new file mode 100644 index 0000000..eee9eee --- /dev/null +++ b/tests/playground/gdrive/test_gdrive_integration.py @@ -0,0 +1,275 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Google Drive connector Playground real integration tests. + +Each test opens the Playground UI, navigates to the Google Drive panel, +selects an action, fills the form, clicks the run button, and asserts the +resulting pipeline state — no API mocking, real Google Drive calls. + +Required env vars (loaded from .env): + GOOGLE_DRIVE_SA_JSON — service-account JSON (path or inline JSON) + GOOGLE_DRIVE_FOLDER_ID — target folder for uploads + GDRIVE_TEST_RECIPIENT_EMAIL — email used as sharing recipient (default: rahul.ap@aot-technologies.com) +""" + +from __future__ import annotations + +import os +import tempfile +import time + +from playwright.sync_api import Page, expect + +from tests.playground.gdrive.gdrive_page import GoogleDrivePage +from tests.playground.home_page import PlaygroundHomePage + +_TIMEOUT_STEP = 20_000 # ms — single-step operations (list, get) +_TIMEOUT_MULTI = 45_000 # ms — multi-step operations (upload, update) + + +def _navigate_to_gdrive(page: Page) -> GoogleDrivePage: + PlaygroundHomePage(page).click_connectors() + gdrive = GoogleDrivePage(page) + gdrive.navigate_to_panel() + return gdrive + + +def _maybe_sleep() -> None: + env = os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") + if env and env.lower().strip() in ("true", "1", "yes"): + time.sleep(3) + + +# ── files.list ──────────────────────────────────────────────────────────────── + + +def test_gdrive_list_files_default_page_size(playground_page: Page) -> None: + """List files with the default page size; assert the pipeline step succeeds.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.list") + gdrive.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.summary_text).to_contain_text("file(s)") + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(gdrive.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_gdrive_list_files_explicit_page_size(playground_page: Page) -> None: + """List files with page_size=5; summary must mention the requested page size.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.list") + gdrive.fill_list_fields(page_size=5) + gdrive.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.summary_text).to_contain_text("page size 5") + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(gdrive.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_gdrive_list_files_with_query(playground_page: Page) -> None: + """List files filtered by mimeType query; step label and success state must appear.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.list") + gdrive.fill_list_fields(page_size=10, query="mimeType='text/plain'") + gdrive.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() + + +# ── files.get ───────────────────────────────────────────────────────────────── + + +def test_gdrive_get_file(playground_page: Page, real_gdrive_file_id: str) -> None: + """Retrieve metadata for a real file; assert single-step success and result card.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.get") + gdrive.fill_get_fields(real_gdrive_file_id, "id,name,mimeType") + gdrive.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.summary_text).to_contain_text("Google Drive file metadata") + expect(gdrive.result_tag).to_contain_text(real_gdrive_file_id) + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(gdrive.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_gdrive_get_file_without_fields(playground_page: Page, real_gdrive_file_id: str) -> None: + """files.get without a fields mask; Drive returns default metadata fields.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.get") + gdrive.fill_get_fields(real_gdrive_file_id) # no fields argument + gdrive.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() + + +def test_gdrive_get_file_invalid_id(playground_page: Page) -> None: + """files.get with a nonexistent ID; the pipeline step must show the error state.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.get") + gdrive.fill_get_fields("this-id-does-not-exist-9999999999") + gdrive.submit() + + expect(playground_page.locator("#step-0.error")).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.final_result).to_be_hidden() + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Failed") + expect(gdrive.log_terminal).to_contain_text("FAILED") + + _maybe_sleep() + + +# ── files.update ────────────────────────────────────────────────────────────── + + +def test_gdrive_update_file_name(playground_page: Page, uploaded_test_file_id: str) -> None: + """Rename the integration-test file; assert all 4 update pipeline steps succeed.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.update") + gdrive.fill_update_fields( + file_id=uploaded_test_file_id, + new_name="nw-integration-test-renamed.txt", + ) + gdrive.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT_MULTI) + + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_MULTI) + expect(gdrive.summary_text).to_contain_text("Updated") + expect(gdrive.result_tag).to_contain_text(uploaded_test_file_id) + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(gdrive.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_gdrive_update_file_name_and_mime( + playground_page: Page, uploaded_test_file_id: str +) -> None: + """Update both the file name and mime_type; all 4 steps must succeed.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.update") + gdrive.fill_update_fields( + file_id=uploaded_test_file_id, + new_name="nw-integration-test-v2.txt", + mime_type="text/plain", + ) + gdrive.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT_MULTI) + + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_MULTI) + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() + + +# ── files.upload ────────────────────────────────────────────────────────────── + + +def test_gdrive_upload_file(playground_page: Page, test_recipient_email: str) -> None: + """Attach a temp file, fill recipient email, submit, assert all 4 steps succeed.""" + gdrive = _navigate_to_gdrive(playground_page) + + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False, prefix="nw_ui_test_") as tmp: + tmp.write(b"Integration test document - uploaded via Playwright UI test.") + tmp_path = tmp.name + + gdrive.file_input.set_input_files(tmp_path) + expect(gdrive.file_chosen_preview).to_be_visible(timeout=3_000) + expect(gdrive.file_drop_zone).to_be_hidden() + expect(gdrive.preview_name).to_contain_text("nw_ui_test_") + + gdrive.fill_upload_fields(test_recipient_email) + gdrive.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT_MULTI) + + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_MULTI) + expect(gdrive.summary_text).to_contain_text("archived to Google Drive") + expect(gdrive.summary_text).to_contain_text(test_recipient_email) + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(gdrive.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_gdrive_upload_remove_and_reattach(playground_page: Page) -> None: + """Remove an attached file → drop zone reappears; re-attach → preview is restored.""" + gdrive = _navigate_to_gdrive(playground_page) + + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False, prefix="nw_reattach_") as tmp: + tmp.write(b"Reattach UI test content - safe to delete") + tmp_path = tmp.name + + # Attach + gdrive.file_input.set_input_files(tmp_path) + expect(gdrive.file_chosen_preview).to_be_visible(timeout=3_000) + expect(gdrive.file_drop_zone).to_be_hidden() + + # Remove + gdrive.remove_file_btn.click() + expect(gdrive.file_chosen_preview).to_be_hidden(timeout=3_000) + expect(gdrive.file_drop_zone).to_be_visible() + + # Re-attach + gdrive.file_input.set_input_files(tmp_path) + expect(gdrive.file_chosen_preview).to_be_visible(timeout=3_000) + expect(gdrive.preview_name).to_contain_text("nw_reattach_") + + +# ── cross-action switch ─────────────────────────────────────────────────────── + + +def test_gdrive_switch_list_then_get(playground_page: Page, real_gdrive_file_id: str) -> None: + """Run files.list, switch to files.get on the same page — both must complete successfully.""" + gdrive = _navigate_to_gdrive(playground_page) + + # First run: files.list + gdrive.select_action("files.list") + gdrive.submit() + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.summary_text).to_contain_text("file(s)") + + # Switch action and run again + gdrive.select_action("files.get") + gdrive.fill_get_fields(real_gdrive_file_id) + gdrive.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.summary_text).to_contain_text("Google Drive file metadata") + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() diff --git a/tests/playground/home_page.py b/tests/playground/home_page.py new file mode 100644 index 0000000..08cf7e6 --- /dev/null +++ b/tests/playground/home_page.py @@ -0,0 +1,52 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class PlaygroundHomePage: + """Page Object Model for the node-wire Playground Home (landing/selection) page.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Section views + self.root_selection_view: Locator = page.locator("#root-selection-view") + self.main_layout: Locator = page.locator(".layout-main") + + # Header components + self.brand_header: Locator = page.locator(".dashboard-header") + self.brand_title: Locator = page.locator(".brand-text h1") + self.tagline: Locator = page.locator(".tagline") + self.header_actions: Locator = page.locator("#header-actions") + + # Selection Cards + self.selection_cards: Locator = page.locator(".selection-card") + + # Agentic Workflow Card + self.agentic_card: Locator = page.locator(".selection-card.card-mcp") + self.agentic_card_title: Locator = self.agentic_card.locator("h3") + self.agentic_card_desc: Locator = self.agentic_card.locator("p") + + # Connectors Card + self.connectors_card: Locator = page.locator(".selection-card.card-connectors") + self.connectors_card_title: Locator = self.connectors_card.locator("h3") + self.connectors_card_desc: Locator = self.connectors_card.locator("p") + + # Navigation + self.back_selection_btn: Locator = page.locator("#back-selection-btn") + + def click_agentic_workflow(self) -> None: + """Click the Agentic Workflow (MCP) selection card to navigate to the agent view.""" + self.agentic_card.click() + + def click_connectors(self) -> None: + """Click the Connectors selection card to navigate to the clinical workflows view.""" + self.connectors_card.click() + + def go_back_to_selection(self) -> None: + """Click the back button to return to the selection page.""" + self.back_selection_btn.click() diff --git a/tests/playground/test_playground_integration.py b/tests/playground/test_playground_integration.py new file mode 100644 index 0000000..f0244d1 --- /dev/null +++ b/tests/playground/test_playground_integration.py @@ -0,0 +1,72 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Playground Home page integration test. + +This test loads the playground page, asserts all header elements and cards, +and verifies interactive transitions between the home page selection view and +individual dashboard views (Agentic Workflow and Connectors) using the Page Object Model. +""" + +from __future__ import annotations + +import os +import time +from playwright.sync_api import Page, expect + +from tests.playground.home_page import PlaygroundHomePage + + +def test_playground_home_page_flow(playground_page: Page) -> None: + """Verify elements visibility, cards presence, and navigation transitions on the Playground Home page.""" + home = PlaygroundHomePage(playground_page) + + # 1. Assert overall page title + assert playground_page.title() == "node-wire Playground" + + # 2. Verify visibility of root components and headers + expect(home.root_selection_view).to_be_visible() + expect(home.main_layout).to_be_hidden() + expect(home.brand_header).to_be_visible() + expect(home.brand_title).to_contain_text("node-") + expect(home.tagline).to_be_visible() + expect(home.header_actions).to_be_hidden() + + # 3. Assert card counts and detailed card contents + assert home.selection_cards.count() == 2 + + # Agentic Workflow Card + expect(home.agentic_card).to_be_visible() + expect(home.agentic_card_title).to_have_text("Agentic Workflow") + expect(home.agentic_card_desc).to_contain_text("via ToolHive") + + # Connectors Card + expect(home.connectors_card).to_be_visible() + expect(home.connectors_card_title).to_have_text("Connectors") + expect(home.connectors_card_desc).to_contain_text("Pre-built Clinical Workflows") + + # 4. Test Navigation Flow: Root -> Agentic Workflow -> Root + home.click_agentic_workflow() + expect(home.root_selection_view).to_be_hidden() + expect(home.main_layout).to_be_visible() + + # Return back to home + home.go_back_to_selection() + expect(home.root_selection_view).to_be_visible() + expect(home.main_layout).to_be_hidden() + + # 5. Test Navigation Flow: Root -> Connectors -> Root + home.click_connectors() + expect(home.root_selection_view).to_be_hidden() + expect(home.main_layout).to_be_visible() + + # Return back to home + home.go_back_to_selection() + expect(home.root_selection_view).to_be_visible() + expect(home.main_layout).to_be_hidden() + + # 6. Optional visual delay for headed mode + is_headed = os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") + if is_headed and is_headed.lower().strip() in ("true", "1", "yes"): + time.sleep(5) From 0e49ba46fb2c665e6a909ea98ff4879b35d2d0aa Mon Sep 17 00:00:00 2001 From: rahul-aot Date: Fri, 22 May 2026 15:09:11 +0530 Subject: [PATCH 43/60] streamline environment variable retrieval for test recipient email --- tests/playground/gdrive/conftest.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/playground/gdrive/conftest.py b/tests/playground/gdrive/conftest.py index 6214692..3c40aec 100644 --- a/tests/playground/gdrive/conftest.py +++ b/tests/playground/gdrive/conftest.py @@ -10,9 +10,7 @@ import httpx import pytest -_TEST_RECIPIENT_EMAIL = os.environ.get( - "GDRIVE_TEST_RECIPIENT_EMAIL", "test@mailinator.com" -) +_TEST_RECIPIENT_EMAIL = os.environ.get("GDRIVE_TEST_RECIPIENT_EMAIL", "test@mailinator.com") @pytest.fixture(scope="session") From 80ef687dd0baa5e97ad972a13e9e407a97bca9d9 Mon Sep 17 00:00:00 2001 From: rahul-aot Date: Fri, 22 May 2026 15:36:29 +0530 Subject: [PATCH 44/60] Add Connector Apps card and navigation to Playground Home integration tests --- tests/playground/home_page.py | 17 +++++++++++++++++ .../playground/test_playground_integration.py | 19 +++++++++++++++++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/playground/home_page.py b/tests/playground/home_page.py index 08cf7e6..08910a6 100644 --- a/tests/playground/home_page.py +++ b/tests/playground/home_page.py @@ -36,6 +36,15 @@ def __init__(self, page: Page) -> None: self.connectors_card_title: Locator = self.connectors_card.locator("h3") self.connectors_card_desc: Locator = self.connectors_card.locator("p") + # Connector Apps Card + self.connector_apps_card: Locator = page.locator(".selection-card.card-apps-directory") + self.connector_apps_card_title: Locator = self.connector_apps_card.locator("h3") + self.connector_apps_card_desc: Locator = self.connector_apps_card.locator("p") + + # Connector Apps sub-menu view + self.connector_apps_view: Locator = page.locator("#connector-apps-selection-view") + self.apps_back_btn: Locator = page.locator("#apps-back-btn") + # Navigation self.back_selection_btn: Locator = page.locator("#back-selection-btn") @@ -47,6 +56,14 @@ def click_connectors(self) -> None: """Click the Connectors selection card to navigate to the clinical workflows view.""" self.connectors_card.click() + def click_connector_apps(self) -> None: + """Click the Connector Apps selection card to navigate to the apps sub-menu.""" + self.connector_apps_card.click() + + def go_back_from_apps(self) -> None: + """Click the back button inside the Connector Apps sub-menu.""" + self.apps_back_btn.click() + def go_back_to_selection(self) -> None: """Click the back button to return to the selection page.""" self.back_selection_btn.click() diff --git a/tests/playground/test_playground_integration.py b/tests/playground/test_playground_integration.py index f0244d1..e04df1e 100644 --- a/tests/playground/test_playground_integration.py +++ b/tests/playground/test_playground_integration.py @@ -34,7 +34,7 @@ def test_playground_home_page_flow(playground_page: Page) -> None: expect(home.header_actions).to_be_hidden() # 3. Assert card counts and detailed card contents - assert home.selection_cards.count() == 2 + assert home.selection_cards.count() == 3 # Agentic Workflow Card expect(home.agentic_card).to_be_visible() @@ -46,6 +46,11 @@ def test_playground_home_page_flow(playground_page: Page) -> None: expect(home.connectors_card_title).to_have_text("Connectors") expect(home.connectors_card_desc).to_contain_text("Pre-built Clinical Workflows") + # Connector Apps Card + expect(home.connector_apps_card).to_be_visible() + expect(home.connector_apps_card_title).to_have_text("Connector Apps") + expect(home.connector_apps_card_desc).to_contain_text("built on top of connectors") + # 4. Test Navigation Flow: Root -> Agentic Workflow -> Root home.click_agentic_workflow() expect(home.root_selection_view).to_be_hidden() @@ -66,7 +71,17 @@ def test_playground_home_page_flow(playground_page: Page) -> None: expect(home.root_selection_view).to_be_visible() expect(home.main_layout).to_be_hidden() - # 6. Optional visual delay for headed mode + # 6. Test Navigation Flow: Root -> Connector Apps -> Root + home.click_connector_apps() + expect(home.root_selection_view).to_be_hidden() + expect(home.connector_apps_view).to_be_visible() + + # Return back to home + home.go_back_from_apps() + expect(home.root_selection_view).to_be_visible() + expect(home.connector_apps_view).to_be_hidden() + + # 7. Optional visual delay for headed mode is_headed = os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") if is_headed and is_headed.lower().strip() in ("true", "1", "yes"): time.sleep(5) From 2f05fb918844876f89f741a39003893b9fc3ed42 Mon Sep 17 00:00:00 2001 From: rahul-aot Date: Mon, 25 May 2026 01:20:42 -0700 Subject: [PATCH 45/60] Add integration tests for Stripe connector in Playground --- .github/workflows/pytest.yml | 9 +- tests/playground/stripe/README.md | 78 +++++ tests/playground/stripe/conftest.py | 63 ++++ tests/playground/stripe/stripe_page.py | 139 +++++++++ .../stripe/test_stripe_integration.py | 273 ++++++++++++++++++ 5 files changed, 558 insertions(+), 4 deletions(-) create mode 100644 tests/playground/stripe/README.md create mode 100644 tests/playground/stripe/conftest.py create mode 100644 tests/playground/stripe/stripe_page.py create mode 100644 tests/playground/stripe/test_stripe_integration.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index ecf25f8..c811975 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -53,9 +53,7 @@ jobs: path: htmlcov/ if-no-files-found: ignore - # ── Playground integration tests ────────────────────────────────────────── - # Runs only on manual workflow_dispatch. Requires real Google Drive - # credentials stored as repository secrets (see tests/playground/README.md). + playground-integration: name: Playground integration tests runs-on: ubuntu-latest @@ -88,9 +86,12 @@ jobs: GOOGLE_DRIVE_SA_JSON: ${{ secrets.GOOGLE_DRIVE_SA_JSON }} GOOGLE_DRIVE_FOLDER_ID: ${{ secrets.GOOGLE_DRIVE_FOLDER_ID }} GDRIVE_TEST_RECIPIENT_EMAIL: ${{ secrets.GDRIVE_TEST_RECIPIENT_EMAIL }} + STRIPE_API_KEY: ${{ secrets.STRIPE_API_KEY }} + STRIPE_TEST_CUSTOMER_ID: ${{ secrets.STRIPE_TEST_CUSTOMER_ID }} + STRIPE_TEST_PRICE_ID: ${{ secrets.STRIPE_TEST_PRICE_ID }} NW_REST_AUTH_DISABLED: "true" NW_REST_LOAD_DOTENV: "false" - NW_ALLOWED_CONNECTORS: "google_drive" + NW_ALLOWED_CONNECTORS: "google_drive,stripe" run: uv run pytest tests/playground/ --no-cov -v - name: Upload Playwright traces on failure diff --git a/tests/playground/stripe/README.md b/tests/playground/stripe/README.md new file mode 100644 index 0000000..9292c8e --- /dev/null +++ b/tests/playground/stripe/README.md @@ -0,0 +1,78 @@ + + +# Stripe Playground Integration Tests + +End-to-end Playwright tests that open the Playground UI in a real browser, +navigate to the Stripe connector panel, and assert on the rendered pipeline +state. No mocking — every test hits the real Stripe test-mode API. + +## What is tested + +| Test | Action | +|------|--------| +| `test_stripe_charge_default` | `charge` — default values (2000 usd) | +| `test_stripe_charge_custom_amount` | `charge` — custom amount + description | +| `test_stripe_charge_no_description` | `charge` — empty description | +| `test_stripe_payment_intent_default` | `payment_intent` — defaults (5000 usd, pm_card_visa) | +| `test_stripe_payment_intent_custom_amount` | `payment_intent` — custom amount, result tag contains pi_ | +| `test_stripe_payment_intent_no_payment_method` | `payment_intent` — no payment method | +| `test_stripe_cancel_subscription_invalid_id` | `cancel_subscription` — nonexistent ID, expects error state | +| `test_stripe_cancel_subscription` | `cancel_subscription` — real subscription ID (requires env vars) | +| `test_stripe_refund_by_charge_id` | `refund` — full refund against a real charge | +| `test_stripe_refund_invalid_id` | `refund` — nonexistent ID, expects error state | +| `test_stripe_switch_charge_then_payment_intent` | Cross-action switch on same page | + +## How it works + +The test session starts a real FastAPI server on a random local port. Playwright +navigates to `/playground/`. The browser's `fetch("/scenarios/stripe-*")` calls +route to the real backend, which calls the real Stripe test-mode API. +No `page.route()` interception. + +## Running locally + +```bash +# Install Playwright browsers (once) +uv run python -m playwright install chromium + +# Run all Stripe tests +uv run pytest tests/playground/stripe/ --no-cov -v + +# Run headed (watch the browser) +PLAYGROUND_HEADED=true uv run pytest tests/playground/stripe/ --no-cov -v -s +``` + +> **Note:** Stripe tests are excluded from the default `uv run pytest` run and +> from regular CI (push/PR). They must be triggered explicitly. + +## Required environment variables + +Set these before running (`.env` is loaded automatically if present): + +| Variable | Description | +|----------|-------------| +| `STRIPE_API_KEY` | Stripe secret key (`sk_test_...`) | +| `NW_REST_AUTH_DISABLED` | Set to `true` to skip REST auth middleware | + +## Optional environment variables (for subscription tests) + +| Variable | Description | +|----------|-------------| +| `STRIPE_TEST_CUSTOMER_ID` | Pre-existing Stripe test customer (`cus_...`) | +| `STRIPE_TEST_PRICE_ID` | Pre-existing Stripe test price (`price_...`) | + +`test_stripe_cancel_subscription` is automatically skipped when these are absent. + +## Test data and cleanup + +The `real_stripe_charge_id` session fixture creates a small charge (`$5.00 usd`) +against the `tok_visa` test token once per session. The `test_stripe_refund_by_charge_id` +test immediately refunds this charge in full, so no balance is left outstanding. + +The optional `real_stripe_subscription_id` fixture creates a subscription that +is cancelled by `test_stripe_cancel_subscription` — leaving no active subscription +after the session. diff --git a/tests/playground/stripe/conftest.py b/tests/playground/stripe/conftest.py new file mode 100644 index 0000000..9ad006a --- /dev/null +++ b/tests/playground/stripe/conftest.py @@ -0,0 +1,63 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import os + +import httpx +import pytest + + +@pytest.fixture(scope="session") +def real_stripe_charge_id(api_server_url: str) -> str: + """Create a real Stripe test charge via the API and return its charge ID. + + Uses the default tok_visa source hardcoded in StripeChargeInput so no extra + env vars are needed beyond STRIPE_API_KEY. + """ + with httpx.Client(timeout=30) as client: + resp = client.post( + f"{api_server_url}/scenarios/stripe-charge", + json={"amount": 500, "currency": "usd", "description": "nw-integration-test charge"}, + ) + resp.raise_for_status() + data = resp.json() + charge_id = data.get("final_resource_id") + if not charge_id: + pytest.skip( + f"Stripe charge setup failed — cannot run refund tests. " + f"Error: {data.get('error_message') or 'no charge_id returned'}" + ) + return charge_id + + +@pytest.fixture(scope="session") +def real_stripe_subscription_id(api_server_url: str) -> str: + """Create a real Stripe subscription and return its subscription ID. + + Requires STRIPE_TEST_CUSTOMER_ID and STRIPE_TEST_PRICE_ID env vars. + Tests that use this fixture are skipped when the vars are absent. + """ + customer_id = os.environ.get("STRIPE_TEST_CUSTOMER_ID") + price_id = os.environ.get("STRIPE_TEST_PRICE_ID") + if not customer_id or not price_id: + pytest.skip( + "STRIPE_TEST_CUSTOMER_ID and STRIPE_TEST_PRICE_ID are required for subscription tests" + ) + + with httpx.Client(timeout=30) as client: + resp = client.post( + f"{api_server_url}/scenarios/stripe-subscription", + json={"customer_id": customer_id, "price_id": price_id}, + ) + resp.raise_for_status() + data = resp.json() + sub_id = data.get("final_resource_id") + if not sub_id: + pytest.skip( + f"Stripe subscription setup failed — cannot run cancel tests. " + f"Error: {data.get('error_message') or 'no subscription_id returned'}" + ) + return sub_id diff --git a/tests/playground/stripe/stripe_page.py b/tests/playground/stripe/stripe_page.py new file mode 100644 index 0000000..e2b8416 --- /dev/null +++ b/tests/playground/stripe/stripe_page.py @@ -0,0 +1,139 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class StripePage: + """Page Object Model for the Stripe connector panel in the Playground.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Selector for the Stripe card inside system connectors view + self.connector_card: Locator = page.locator(".connector-card[data-mode='stripe']") + + # Panel root and main headers + self.panel: Locator = page.locator("#stripe-panel") + self.title: Locator = page.locator("#stripe-panel .card-title h2") + self.action_select: Locator = page.locator("#stripe-action-select") + self.run_btn: Locator = page.locator("#stripe-run-btn") + self.back_to_connectors: Locator = page.locator("#back-to-connectors") + + # --- charge action elements --- + self.charge_section: Locator = page.locator("#stripe-section-charge") + self.charge_amount: Locator = page.locator( + "#stripe-section-charge input[name='charge_amount']" + ) + self.charge_currency: Locator = page.locator( + "#stripe-section-charge input[name='charge_currency']" + ) + self.charge_description: Locator = page.locator( + "#stripe-section-charge input[name='charge_description']" + ) + + # --- payment_intent action elements --- + self.pi_section: Locator = page.locator("#stripe-section-pi") + self.pi_amount: Locator = page.locator("#stripe-section-pi input[name='pi_amount']") + self.pi_currency: Locator = page.locator("#stripe-section-pi input[name='pi_currency']") + self.pi_customer: Locator = page.locator("#stripe-section-pi input[name='pi_customer']") + self.pi_payment_method: Locator = page.locator( + "#stripe-section-pi input[name='pi_payment_method']" + ) + + # --- subscription action elements --- + self.sub_section: Locator = page.locator("#stripe-section-sub") + self.sub_customer: Locator = page.locator( + "#stripe-section-sub input[name='sub_customer']" + ) + self.sub_price: Locator = page.locator("#stripe-section-sub input[name='sub_price']") + + # --- cancel_subscription action elements --- + self.cancel_section: Locator = page.locator("#stripe-section-cancel") + self.cancel_sub_id: Locator = page.locator( + "#stripe-section-cancel input[name='cancel_sub_id']" + ) + + # --- refund action elements --- + self.refund_section: Locator = page.locator("#stripe-section-refund") + self.refund_target_id: Locator = page.locator( + "#stripe-section-refund input[name='refund_target_id']" + ) + self.refund_amount: Locator = page.locator( + "#stripe-section-refund input[name='refund_amount']" + ) + + # --- Output and Logs elements --- + self.pipeline_steps: Locator = page.locator(".flow-node") + self.step_nodes: list[Locator] = [page.locator(f"#step-{i}") for i in range(3)] + self.final_result: Locator = page.locator("#final-result") + self.summary_text: Locator = page.locator("#human-summary") + self.result_tag: Locator = page.locator("#result-id") + self.log_terminal: Locator = page.locator("#log-terminal") + + def navigate_to_panel(self) -> None: + """Click the Stripe card in system connectors to open the panel.""" + self.connector_card.click() + + def select_action(self, action: str) -> None: + """Change the action via the select element.""" + self.action_select.select_option(action) + + def fill_charge_fields( + self, + amount: int | None = None, + currency: str | None = None, + description: str | None = None, + ) -> None: + """Fill charge parameters (all optional — HTML defaults apply if not provided).""" + if amount is not None: + self.charge_amount.fill(str(amount)) + if currency is not None: + self.charge_currency.fill(currency) + if description is not None: + self.charge_description.fill(description) + + def fill_payment_intent_fields( + self, + amount: int | None = None, + currency: str | None = None, + customer_id: str | None = None, + payment_method: str | None = None, + ) -> None: + """Fill payment intent parameters.""" + if amount is not None: + self.pi_amount.fill(str(amount)) + if currency is not None: + self.pi_currency.fill(currency) + if customer_id is not None: + self.pi_customer.fill(customer_id) + if payment_method is not None: + self.pi_payment_method.fill(payment_method) + + def fill_subscription_fields(self, customer_id: str, price_id: str) -> None: + """Fill subscription parameters.""" + self.sub_customer.fill(customer_id) + self.sub_price.fill(price_id) + + def fill_cancel_fields(self, subscription_id: str) -> None: + """Fill cancel subscription parameter.""" + self.cancel_sub_id.fill(subscription_id) + + def fill_refund_fields( + self, target_id: str, amount: int | None = None + ) -> None: + """Fill refund parameters. target_id may be a ch_... or pi_... ID.""" + self.refund_target_id.fill(target_id) + if amount is not None: + self.refund_amount.fill(str(amount)) + + def submit(self) -> None: + """Submit the form to execute the Stripe workflow.""" + self.run_btn.click() + + def go_back(self) -> None: + """Click 'Back to All Connectors' to return to connectors selection view.""" + self.back_to_connectors.click() diff --git a/tests/playground/stripe/test_stripe_integration.py b/tests/playground/stripe/test_stripe_integration.py new file mode 100644 index 0000000..37c6fd9 --- /dev/null +++ b/tests/playground/stripe/test_stripe_integration.py @@ -0,0 +1,273 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Stripe connector Playground real integration tests. + +Each test opens the Playground UI, navigates to the Stripe panel, +selects an action, fills the form, clicks the run button, and asserts the +resulting pipeline state — no API mocking, real Stripe test-mode calls. + +Required env vars (loaded from .env): + STRIPE_API_KEY — Stripe secret key (sk_test_...) + +Optional env vars (for subscription-related tests): + STRIPE_TEST_CUSTOMER_ID — pre-existing Stripe test customer (cus_...) + STRIPE_TEST_PRICE_ID — pre-existing Stripe test price (price_...) +""" + +from __future__ import annotations + +import os +import time + +from playwright.sync_api import Page, expect + +from tests.playground.stripe.stripe_page import StripePage +from tests.playground.home_page import PlaygroundHomePage + +_TIMEOUT = 20_000 # ms — all Stripe scenarios are 3-step + + +def _navigate_to_stripe(page: Page) -> StripePage: + PlaygroundHomePage(page).click_connectors() + stripe = StripePage(page) + stripe.navigate_to_panel() + return stripe + + +def _maybe_sleep() -> None: + env = os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") + if env and env.lower().strip() in ("true", "1", "yes"): + time.sleep(3) + + +# ── charge ──────────────────────────────────────────────────────────────────── + + +def test_stripe_charge_default(playground_page: Page) -> None: + """Process a charge with the HTML default values (2000 usd); all 3 steps must succeed.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("charge") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(stripe.summary_text).to_contain_text("20.00 USD charge") + expect(stripe.result_tag).to_be_visible() + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(stripe.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_stripe_charge_custom_amount(playground_page: Page) -> None: + """Process a charge with a custom amount and description; summary must reflect the amount.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("charge") + stripe.fill_charge_fields(amount=1500, currency="usd", description="nw-test charge") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(stripe.summary_text).to_contain_text("15.00 USD charge") + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(stripe.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_stripe_charge_no_description(playground_page: Page) -> None: + """Process a charge with an empty description; pipeline must still complete successfully.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("charge") + stripe.fill_charge_fields(amount=1000, currency="usd", description="") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() + + +# ── payment_intent ──────────────────────────────────────────────────────────── + + +def test_stripe_payment_intent_default(playground_page: Page) -> None: + """Create a payment intent with the HTML defaults (5000 usd, pm_card_visa); 3 steps succeed.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("payment_intent") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(stripe.summary_text).to_contain_text("payment intent") + expect(stripe.result_tag).to_be_visible() + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(stripe.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_stripe_payment_intent_custom_amount(playground_page: Page) -> None: + """Create a payment intent with a custom amount; result tag must contain a pi_ ID.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("payment_intent") + stripe.fill_payment_intent_fields(amount=3000, currency="usd") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(stripe.result_tag).to_contain_text("pi_") + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() + + +def test_stripe_payment_intent_no_payment_method(playground_page: Page) -> None: + """Create a payment intent without a payment method; backend creates a requires_payment_method PI.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("payment_intent") + stripe.fill_payment_intent_fields(amount=2500, currency="usd", payment_method="") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() + + +# ── cancel_subscription ─────────────────────────────────────────────────────── + + +def test_stripe_cancel_subscription_invalid_id(playground_page: Page) -> None: + """Cancel with a nonexistent subscription ID; step-1 must show error state.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("cancel_subscription") + stripe.fill_cancel_fields("sub_this_does_not_exist_9999") + stripe.submit() + + # step-0 (Locate Resource) is a validation step — it always succeeds + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + # step-1 (Cancel Sub) calls the real Stripe API — it must fail + expect(playground_page.locator("#step-1.error")).to_be_visible(timeout=_TIMEOUT) + expect(stripe.final_result).to_be_hidden() + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Failed") + expect(stripe.log_terminal).to_contain_text("FAILED") + + _maybe_sleep() + + +def test_stripe_cancel_subscription( + playground_page: Page, real_stripe_subscription_id: str +) -> None: + """Cancel a real subscription; all 3 steps must succeed.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("cancel_subscription") + stripe.fill_cancel_fields(real_stripe_subscription_id) + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(stripe.summary_text).to_contain_text("canceled subscription") + expect(stripe.result_tag).to_contain_text(real_stripe_subscription_id) + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(stripe.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +# ── refund ──────────────────────────────────────────────────────────────────── + + +def test_stripe_refund_by_charge_id( + playground_page: Page, real_stripe_charge_id: str +) -> None: + """Issue a full refund against a real charge; all 3 steps must succeed.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("refund") + stripe.fill_refund_fields(real_stripe_charge_id) + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(stripe.summary_text).to_contain_text("issued refund") + expect(stripe.result_tag).to_be_visible() + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(stripe.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_stripe_refund_invalid_id(playground_page: Page) -> None: + """Refund with a nonexistent charge ID; step-1 must show error state.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("refund") + stripe.fill_refund_fields("ch_this_does_not_exist_9999") + stripe.submit() + + # step-0 (Validate Params) is a local validation step — it always succeeds + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + # step-1 (Issue Refund) calls the real Stripe API — it must fail + expect(playground_page.locator("#step-1.error")).to_be_visible(timeout=_TIMEOUT) + expect(stripe.final_result).to_be_hidden() + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Failed") + expect(stripe.log_terminal).to_contain_text("FAILED") + + _maybe_sleep() + + +# ── cross-action switch ─────────────────────────────────────────────────────── + + +def test_stripe_switch_charge_then_payment_intent(playground_page: Page) -> None: + """Run a charge, then switch to payment_intent on the same page — both must succeed.""" + stripe = _navigate_to_stripe(playground_page) + + # First run: charge + stripe.select_action("charge") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + + # Switch action and run again + stripe.select_action("payment_intent") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(stripe.summary_text).to_contain_text("payment intent") + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() From e8fb36010ef2bec302fcd1779f6327d963f5e034 Mon Sep 17 00:00:00 2001 From: rahul-aot Date: Mon, 25 May 2026 02:12:19 -0700 Subject: [PATCH 46/60] Refactor StripePage and test_stripe_integration for improved readability --- tests/playground/stripe/stripe_page.py | 8 ++------ tests/playground/stripe/test_stripe_integration.py | 4 +--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/playground/stripe/stripe_page.py b/tests/playground/stripe/stripe_page.py index e2b8416..cc8e883 100644 --- a/tests/playground/stripe/stripe_page.py +++ b/tests/playground/stripe/stripe_page.py @@ -46,9 +46,7 @@ def __init__(self, page: Page) -> None: # --- subscription action elements --- self.sub_section: Locator = page.locator("#stripe-section-sub") - self.sub_customer: Locator = page.locator( - "#stripe-section-sub input[name='sub_customer']" - ) + self.sub_customer: Locator = page.locator("#stripe-section-sub input[name='sub_customer']") self.sub_price: Locator = page.locator("#stripe-section-sub input[name='sub_price']") # --- cancel_subscription action elements --- @@ -122,9 +120,7 @@ def fill_cancel_fields(self, subscription_id: str) -> None: """Fill cancel subscription parameter.""" self.cancel_sub_id.fill(subscription_id) - def fill_refund_fields( - self, target_id: str, amount: int | None = None - ) -> None: + def fill_refund_fields(self, target_id: str, amount: int | None = None) -> None: """Fill refund parameters. target_id may be a ch_... or pi_... ID.""" self.refund_target_id.fill(target_id) if amount is not None: diff --git a/tests/playground/stripe/test_stripe_integration.py b/tests/playground/stripe/test_stripe_integration.py index 37c6fd9..7527692 100644 --- a/tests/playground/stripe/test_stripe_integration.py +++ b/tests/playground/stripe/test_stripe_integration.py @@ -204,9 +204,7 @@ def test_stripe_cancel_subscription( # ── refund ──────────────────────────────────────────────────────────────────── -def test_stripe_refund_by_charge_id( - playground_page: Page, real_stripe_charge_id: str -) -> None: +def test_stripe_refund_by_charge_id(playground_page: Page, real_stripe_charge_id: str) -> None: """Issue a full refund against a real charge; all 3 steps must succeed.""" stripe = _navigate_to_stripe(playground_page) From 34bc971e9a9b88546b91e99edfa6d24a1da07ad2 Mon Sep 17 00:00:00 2001 From: rahul-aot Date: Tue, 26 May 2026 01:06:49 -0700 Subject: [PATCH 47/60] Add integration tests for Salesforce connector in Playground --- .github/workflows/pytest.yml | 18 +- tests/playground/conftest.py | 14 + tests/playground/salesforce/README.md | 108 ++++++ tests/playground/salesforce/__init__.py | 0 tests/playground/salesforce/conftest.py | 88 +++++ .../playground/salesforce/salesforce_page.py | 155 ++++++++ .../salesforce/test_salesforce_integration.py | 364 ++++++++++++++++++ 7 files changed, 744 insertions(+), 3 deletions(-) create mode 100644 tests/playground/salesforce/README.md create mode 100644 tests/playground/salesforce/__init__.py create mode 100644 tests/playground/salesforce/conftest.py create mode 100644 tests/playground/salesforce/salesforce_page.py create mode 100644 tests/playground/salesforce/test_salesforce_integration.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index ecf25f8..b554d1e 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -54,8 +54,9 @@ jobs: if-no-files-found: ignore # ── Playground integration tests ────────────────────────────────────────── - # Runs only on manual workflow_dispatch. Requires real Google Drive - # credentials stored as repository secrets (see tests/playground/README.md). + # Runs only on manual workflow_dispatch. Requires real connector credentials + # stored as repository secrets (see tests/playground/gdrive/README.md and + # tests/playground/salesforce/README.md). playground-integration: name: Playground integration tests runs-on: ubuntu-latest @@ -85,12 +86,23 @@ jobs: - name: Run playground integration tests env: + + #gdrive GOOGLE_DRIVE_SA_JSON: ${{ secrets.GOOGLE_DRIVE_SA_JSON }} GOOGLE_DRIVE_FOLDER_ID: ${{ secrets.GOOGLE_DRIVE_FOLDER_ID }} GDRIVE_TEST_RECIPIENT_EMAIL: ${{ secrets.GDRIVE_TEST_RECIPIENT_EMAIL }} + + #salesforce + SALESFORCE_INSTANCE_URL: ${{ secrets.SALESFORCE_INSTANCE_URL }} + SALESFORCE_TOKEN_URL: ${{ secrets.SALESFORCE_TOKEN_URL }} + SALESFORCE_CLIENT_ID: ${{ secrets.SALESFORCE_CLIENT_ID }} + SALESFORCE_CLIENT_SECRET: ${{ secrets.SALESFORCE_CLIENT_SECRET }} + SALESFORCE_REFRESH_TOKEN: ${{ secrets.SALESFORCE_REFRESH_TOKEN }} + + # Disable authentication and dotenv loading for playground tests, and restrict connectors NW_REST_AUTH_DISABLED: "true" NW_REST_LOAD_DOTENV: "false" - NW_ALLOWED_CONNECTORS: "google_drive" + NW_ALLOWED_CONNECTORS: "google_drive,salesforce," run: uv run pytest tests/playground/ --no-cov -v - name: Upload Playwright traces on failure diff --git a/tests/playground/conftest.py b/tests/playground/conftest.py index caabf0a..641d81f 100644 --- a/tests/playground/conftest.py +++ b/tests/playground/conftest.py @@ -8,11 +8,25 @@ import socket import threading import time +from pathlib import Path from dotenv import load_dotenv import httpx import pytest +_REPO_ROOT = Path(__file__).resolve().parent.parent.parent + +# tests/conftest.py restricts NW_ALLOWED_CONNECTORS to the narrow CI-safe set and +# points NW_CONFIG_PATH at a fixture yaml that disables salesforce/slack. +# Playground integration tests hit real external services and need the full allowlist +# and the real config/connectors.yaml, so override those values here before any app +# import occurs. +os.environ["NW_ALLOWED_CONNECTORS"] = ( + "http_generic,smtp,stripe,google_drive,fhir_epic,fhir_cerner,salesforce,slack" +) +os.environ["NW_CONFIG_PATH"] = str(_REPO_ROOT / "config" / "connectors.yaml") +os.environ["NW_REST_LOAD_DOTENV"] = "true" + # Load .env before any app imports so connectors initialise with real credentials. load_dotenv(override=False) diff --git a/tests/playground/salesforce/README.md b/tests/playground/salesforce/README.md new file mode 100644 index 0000000..ed67970 --- /dev/null +++ b/tests/playground/salesforce/README.md @@ -0,0 +1,108 @@ + + +# Salesforce CRM Playground Integration Tests + +End-to-end Playwright tests that open the Playground UI in a real browser, +navigate to the Salesforce CRM connector panel, and assert on the rendered +pipeline state. No mocking — every test hits the real Salesforce API. + +## What is tested + +| Test | Action | +|------|--------| +| `test_sf_create_lead_minimal` | `create_lead` — required fields only (LastName + Company) | +| `test_sf_create_lead_full` | `create_lead` — with first name and email | +| `test_sf_create_contact_minimal` | `create_contact` — required field only (LastName) | +| `test_sf_create_contact_with_email` | `create_contact` — with first name and email | +| `test_sf_read_lead` | `read_lead` — valid Lead ID, asserts success state | +| `test_sf_read_lead_invalid_id` | `read_lead` — nonexistent ID, expects error state | +| `test_sf_read_contact` | `read_contact` — valid Contact ID, asserts success state | +| `test_sf_read_contact_invalid_id` | `read_contact` — nonexistent ID, expects error state | +| `test_sf_update_lead` | `update_lead` — rename + company change | +| `test_sf_update_lead_email` | `update_lead` — email-only update | +| `test_sf_update_contact` | `update_contact` — name update | +| `test_sf_update_contact_email` | `update_contact` — email-only update | +| `test_sf_delete_lead` | `delete_lead` — delete a freshly created Lead | +| `test_sf_delete_contact` | `delete_contact` — delete a freshly created Contact | +| `test_sf_switch_create_lead_to_read` | Cross-action switch on same page | + +## How it works + +The test session starts a real FastAPI server on a random local port. Playwright +navigates to `/playground/`. The browser's `fetch("/scenarios/salesforce-*")` +calls route to the real backend, which calls the real Salesforce API via OAuth2 +refresh token. No `page.route()` interception. + +Session fixtures create Lead and Contact records once via the REST API for use +across read and update tests. Delete tests each create their own fresh record. +All generated names and emails use random suffixes (e.g. `Lead839201`, +`test748203@mailinator.com`) so repeated runs never collide on duplicate values. + +## Running locally + +```bash +# Install Playwright browsers (once) +uv run python -m playwright install chromium + +# Run all Salesforce tests +uv run pytest tests/playground/salesforce/ --no-cov -v + +# Run headed (watch the browser) +PLAYGROUND_HEADED=true uv run pytest tests/playground/salesforce/ --no-cov -v -s + +# Run a single test +uv run pytest tests/playground/salesforce/ --no-cov -v -k test_sf_create_lead_minimal +``` + +> **Note:** Salesforce tests are excluded from the default `uv run pytest` run and +> from regular CI (push/PR). They must be triggered explicitly. + +## Required environment variables + +Set these before running (`.env` is loaded automatically if present): + +| Variable | Description | +|----------|-------------| +| `SALESFORCE_INSTANCE_URL` | Your Salesforce org URL, e.g. `https://orgname.my.salesforce.com` | +| `SALESFORCE_TOKEN_URL` | OAuth2 token endpoint, e.g. `https://login.salesforce.com/services/oauth2/token` | +| `SALESFORCE_CLIENT_ID` | Connected App client ID | +| `SALESFORCE_CLIENT_SECRET` | Connected App client secret | +| `SALESFORCE_REFRESH_TOKEN` | Long-lived OAuth2 refresh token | +| `NW_REST_AUTH_DISABLED` | Set to `true` to skip REST auth middleware | + +## CI / GitHub Actions + +Salesforce tests run **only on manual `workflow_dispatch`** trigger +(`Actions → CI – Pytest → Run workflow`). + +Credentials are read from repository secrets: + +| Secret | Maps to env var | +|--------|----------------| +| `SALESFORCE_INSTANCE_URL` | `SALESFORCE_INSTANCE_URL` | +| `SALESFORCE_TOKEN_URL` | `SALESFORCE_TOKEN_URL` | +| `SALESFORCE_CLIENT_ID` | `SALESFORCE_CLIENT_ID` | +| `SALESFORCE_CLIENT_SECRET` | `SALESFORCE_CLIENT_SECRET` | +| `SALESFORCE_REFRESH_TOKEN` | `SALESFORCE_REFRESH_TOKEN` | + +Set these under **Settings → Secrets and variables → Actions** before triggering +the workflow. + +## Test data and cleanup + +The `real_sf_lead_id` and `real_sf_contact_id` session fixtures create one Lead +and one Contact in Salesforce at the start of the test session. These records are +**not automatically deleted** after the tests finish — clean them up manually via +the Salesforce UI or Developer Console if needed. Look for records with names +matching the pattern `IntegLead` and `IntegContact`. + +The `deletable_lead_id` and `deletable_contact_id` fixtures create a fresh record +per delete test and those records are consumed (deleted) by the test itself. + +Update tests mutate the session-scoped records in place (name, email). Because +Salesforce does not enforce unique constraints on Lead/Contact names, this is safe +to run multiple times without conflicts. diff --git a/tests/playground/salesforce/__init__.py b/tests/playground/salesforce/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/playground/salesforce/conftest.py b/tests/playground/salesforce/conftest.py new file mode 100644 index 0000000..331c057 --- /dev/null +++ b/tests/playground/salesforce/conftest.py @@ -0,0 +1,88 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import random + +import httpx +import pytest + + +def _rnd() -> str: + return str(random.randint(100_000, 999_999)) + + +def _email() -> str: + return f"test{_rnd()}@mailinator.com" + + +def _create_lead(api_server_url: str, last_name: str, company: str) -> str: + """Create a Salesforce Lead via the REST API and return its record ID.""" + with httpx.Client(timeout=30) as client: + resp = client.post( + f"{api_server_url}/scenarios/salesforce-create-lead", + json={"last_name": last_name, "company": company}, + ) + resp.raise_for_status() + data = resp.json() + record_id = data.get("final_resource_id") + if not record_id: + pytest.skip( + f"Salesforce Lead creation failed — cannot run dependent tests. " + f"Error: {data.get('error_message') or 'no record ID returned'}" + ) + return record_id + + +def _create_contact(api_server_url: str, last_name: str) -> str: + """Create a Salesforce Contact via the REST API and return its record ID.""" + with httpx.Client(timeout=30) as client: + resp = client.post( + f"{api_server_url}/scenarios/salesforce-create-contact", + json={"last_name": last_name}, + ) + resp.raise_for_status() + data = resp.json() + record_id = data.get("final_resource_id") + if not record_id: + pytest.skip( + f"Salesforce Contact creation failed — cannot run dependent tests. " + f"Error: {data.get('error_message') or 'no record ID returned'}" + ) + return record_id + + +@pytest.fixture(scope="session") +def real_sf_lead_id(api_server_url: str) -> str: + """Create a Salesforce Lead once per session for read and update tests. + + The Lead is left in Salesforce after the session (manual cleanup needed). + """ + return _create_lead(api_server_url, f"IntegLead{_rnd()}", f"Corp{_rnd()}") + + +@pytest.fixture(scope="session") +def real_sf_contact_id(api_server_url: str) -> str: + """Create a Salesforce Contact once per session for read and update tests. + + The Contact is left in Salesforce after the session (manual cleanup needed). + """ + return _create_contact(api_server_url, f"IntegContact{_rnd()}") + + +@pytest.fixture +def deletable_lead_id(api_server_url: str) -> str: + """Create a fresh Salesforce Lead per test for delete tests. + + Each invocation creates a new record so the delete test always operates + on an existing record. + """ + return _create_lead(api_server_url, f"DelLead{_rnd()}", f"Corp{_rnd()}") + + +@pytest.fixture +def deletable_contact_id(api_server_url: str) -> str: + """Create a fresh Salesforce Contact per test for delete tests.""" + return _create_contact(api_server_url, f"DelContact{_rnd()}") diff --git a/tests/playground/salesforce/salesforce_page.py b/tests/playground/salesforce/salesforce_page.py new file mode 100644 index 0000000..ddfaa3c --- /dev/null +++ b/tests/playground/salesforce/salesforce_page.py @@ -0,0 +1,155 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class SalesforcePage: + """Page Object Model for the Salesforce CRM connector panel in the Playground.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Connector card inside system connectors view + self.connector_card: Locator = page.locator(".connector-card[data-mode='salesforce']") + + # Panel root and top-level controls + self.panel: Locator = page.locator("#salesforce-panel") + self.action_select: Locator = page.locator("#salesforce-action-select") + self.run_btn: Locator = page.locator("#salesforce-run-btn") + self.back_to_connectors: Locator = page.locator("#back-to-connectors") + + # --- Lead section (create_lead / update_lead) --- + self.lead_section: Locator = page.locator("#salesforce-section-lead") + self.lead_id: Locator = page.locator("#salesforce-section-lead input[name='lead_id']") + self.lead_first_name: Locator = page.locator( + "#salesforce-section-lead input[name='lead_first_name']" + ) + self.lead_last_name: Locator = page.locator( + "#salesforce-section-lead input[name='lead_last_name']" + ) + self.lead_company: Locator = page.locator( + "#salesforce-section-lead input[name='lead_company']" + ) + self.lead_email: Locator = page.locator( + "#salesforce-section-lead input[name='lead_email']" + ) + + # --- Contact section (create_contact / update_contact) --- + self.contact_section: Locator = page.locator("#salesforce-section-contact") + self.contact_id: Locator = page.locator( + "#salesforce-section-contact input[name='contact_id']" + ) + self.contact_first_name: Locator = page.locator( + "#salesforce-section-contact input[name='contact_first_name']" + ) + self.contact_last_name: Locator = page.locator( + "#salesforce-section-contact input[name='contact_last_name']" + ) + self.contact_email: Locator = page.locator( + "#salesforce-section-contact input[name='contact_email']" + ) + self.contact_account_id: Locator = page.locator( + "#salesforce-section-contact input[name='contact_account_id']" + ) + + # --- Generic ID section (read_lead / read_contact / delete_lead / delete_contact) --- + self.id_only_section: Locator = page.locator("#salesforce-section-id-only") + self.generic_record_id: Locator = page.locator( + "#salesforce-section-id-only input[name='generic_record_id']" + ) + + # --- Output / log elements (shared across connectors) --- + self.final_result: Locator = page.locator("#final-result") + self.summary_text: Locator = page.locator("#human-summary") + self.result_tag: Locator = page.locator("#result-id") + self.log_terminal: Locator = page.locator("#log-terminal") + + def navigate_to_panel(self) -> None: + """Click the Salesforce card in system connectors to open the panel.""" + self.connector_card.click() + + def select_action(self, action: str) -> None: + """Change the CRM action via the select element.""" + self.action_select.select_option(action) + + def fill_lead_fields( + self, + last_name: str, + company: str, + first_name: str | None = None, + email: str | None = None, + ) -> None: + """Fill Lead create form fields.""" + self.lead_last_name.fill(last_name) + self.lead_company.fill(company) + if first_name is not None: + self.lead_first_name.fill(first_name) + if email is not None: + self.lead_email.fill(email) + + def fill_lead_update_fields( + self, + record_id: str, + last_name: str | None = None, + company: str | None = None, + first_name: str | None = None, + email: str | None = None, + ) -> None: + """Fill Lead update form fields (record ID + any changed fields).""" + self.lead_id.fill(record_id) + if last_name is not None: + self.lead_last_name.fill(last_name) + if company is not None: + self.lead_company.fill(company) + if first_name is not None: + self.lead_first_name.fill(first_name) + if email is not None: + self.lead_email.fill(email) + + def fill_contact_fields( + self, + last_name: str, + first_name: str | None = None, + email: str | None = None, + account_id: str | None = None, + ) -> None: + """Fill Contact create form fields.""" + self.contact_last_name.fill(last_name) + if first_name is not None: + self.contact_first_name.fill(first_name) + if email is not None: + self.contact_email.fill(email) + if account_id is not None: + self.contact_account_id.fill(account_id) + + def fill_contact_update_fields( + self, + record_id: str, + last_name: str | None = None, + first_name: str | None = None, + email: str | None = None, + ) -> None: + """Fill Contact update form fields (record ID + any changed fields).""" + self.contact_id.fill(record_id) + if last_name is not None: + self.contact_last_name.fill(last_name) + if first_name is not None: + self.contact_first_name.fill(first_name) + if email is not None: + self.contact_email.fill(email) + + def fill_id_only(self, record_id: str) -> None: + """Fill the generic record ID field used by read/delete actions.""" + self.generic_record_id.fill(record_id) + + def submit(self) -> None: + """Click the run button to execute the selected CRM action.""" + self.run_btn.click() + + def go_back(self) -> None: + """Click 'Back to All Connectors' to return to the connectors selection view.""" + self.back_to_connectors.click() diff --git a/tests/playground/salesforce/test_salesforce_integration.py b/tests/playground/salesforce/test_salesforce_integration.py new file mode 100644 index 0000000..5e9c672 --- /dev/null +++ b/tests/playground/salesforce/test_salesforce_integration.py @@ -0,0 +1,364 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Salesforce CRM connector Playground real integration tests. + +Each test opens the Playground UI, navigates to the Salesforce panel, +selects an action, fills the form, clicks the run button, and asserts the +resulting pipeline state — no API mocking, real Salesforce calls. + +Required env vars (loaded from .env): + SALESFORCE_INSTANCE_URL — https://.my.salesforce.com + SALESFORCE_TOKEN_URL — OAuth2 token endpoint + SALESFORCE_CLIENT_ID — Connected App client ID + SALESFORCE_CLIENT_SECRET — Connected App client secret + SALESFORCE_REFRESH_TOKEN — Long-lived refresh token +""" + +from __future__ import annotations + +import os +import random +import time + +from playwright.sync_api import Page, expect + +from tests.playground.home_page import PlaygroundHomePage +from tests.playground.salesforce.salesforce_page import SalesforcePage + +_TIMEOUT = 20_000 # ms — all Salesforce operations are single-step + + +def _rnd() -> str: + """Return a 6-digit random suffix, unique enough to avoid duplicate-email rejections.""" + return str(random.randint(100_000, 999_999)) + + +def _email() -> str: + return f"test{_rnd()}@mailinator.com" + + +def _navigate_to_salesforce(page: Page) -> SalesforcePage: + PlaygroundHomePage(page).click_connectors() + sf = SalesforcePage(page) + sf.navigate_to_panel() + return sf + + +def _maybe_sleep() -> None: + env = os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") + if env and env.lower().strip() in ("true", "1", "yes"): + time.sleep(3) + + +# ── create_lead ─────────────────────────────────────────────────────────────── + + +def test_sf_create_lead_minimal(playground_page: Page) -> None: + """Create a Lead with only the required fields (LastName + Company).""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("create_lead") + sf.fill_lead_fields(last_name=f"Lead{_rnd()}", company=f"Corp{_rnd()}") + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text("Lead created successfully") + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_sf_create_lead_full(playground_page: Page) -> None: + """Create a Lead with first name and email in addition to required fields.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("create_lead") + sf.fill_lead_fields( + last_name=f"Lead{_rnd()}", + company=f"Corp{_rnd()}", + first_name="John", + email=_email(), + ) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text("Lead created successfully") + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() + + +# ── create_contact ──────────────────────────────────────────────────────────── + + +def test_sf_create_contact_minimal(playground_page: Page) -> None: + """Create a Contact with only the required LastName field.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("create_contact") + sf.fill_contact_fields(last_name=f"Contact{_rnd()}") + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text("Contact created successfully") + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_sf_create_contact_with_email(playground_page: Page) -> None: + """Create a Contact with first name and email.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("create_contact") + sf.fill_contact_fields( + last_name=f"Contact{_rnd()}", + first_name="Jane", + email=_email(), + ) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text("Contact created successfully") + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() + + +# ── read_lead ───────────────────────────────────────────────────────────────── + + +def test_sf_read_lead(playground_page: Page, real_sf_lead_id: str) -> None: + """Retrieve metadata for a real Lead; assert single-step success and result card.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("read_lead") + sf.fill_id_only(real_sf_lead_id) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text(real_sf_lead_id) + expect(sf.result_tag).to_contain_text(real_sf_lead_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_sf_read_lead_invalid_id(playground_page: Page) -> None: + """read_lead with a nonexistent ID; pipeline step must show the error state.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("read_lead") + sf.fill_id_only("00Q000000000001AAA") + sf.submit() + + expect(playground_page.locator("#step-0.error")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_hidden() + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Failed") + expect(sf.log_terminal).to_contain_text("FAILED") + + _maybe_sleep() + + +# ── read_contact ────────────────────────────────────────────────────────────── + + +def test_sf_read_contact(playground_page: Page, real_sf_contact_id: str) -> None: + """Retrieve metadata for a real Contact; assert single-step success and result card.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("read_contact") + sf.fill_id_only(real_sf_contact_id) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text(real_sf_contact_id) + expect(sf.result_tag).to_contain_text(real_sf_contact_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_sf_read_contact_invalid_id(playground_page: Page) -> None: + """read_contact with a nonexistent ID; pipeline step must show the error state.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("read_contact") + sf.fill_id_only("003000000000001AAA") + sf.submit() + + expect(playground_page.locator("#step-0.error")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_hidden() + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Failed") + expect(sf.log_terminal).to_contain_text("FAILED") + + _maybe_sleep() + + +# ── update_lead ─────────────────────────────────────────────────────────────── + + +def test_sf_update_lead(playground_page: Page, real_sf_lead_id: str) -> None: + """Update a Lead's last name; assert single-step success and summary contains the record ID.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("update_lead") + sf.fill_lead_update_fields( + record_id=real_sf_lead_id, + last_name=f"Lead{_rnd()}", + company=f"Corp{_rnd()}", + ) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text("updated successfully") + expect(sf.result_tag).to_contain_text(real_sf_lead_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_sf_update_lead_email(playground_page: Page, real_sf_lead_id: str) -> None: + """Update only a Lead's email; assert success with result ID.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("update_lead") + sf.fill_lead_update_fields( + record_id=real_sf_lead_id, + email=_email(), + ) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.result_tag).to_contain_text(real_sf_lead_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() + + +# ── update_contact ──────────────────────────────────────────────────────────── + + +def test_sf_update_contact(playground_page: Page, real_sf_contact_id: str) -> None: + """Update a Contact's name; assert single-step success and summary contains the record ID.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("update_contact") + sf.fill_contact_update_fields( + record_id=real_sf_contact_id, + last_name=f"Contact{_rnd()}", + first_name="Updated", + ) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text("updated successfully") + expect(sf.result_tag).to_contain_text(real_sf_contact_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_sf_update_contact_email(playground_page: Page, real_sf_contact_id: str) -> None: + """Update only a Contact's email; assert success.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("update_contact") + sf.fill_contact_update_fields( + record_id=real_sf_contact_id, + email=_email(), + ) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.result_tag).to_contain_text(real_sf_contact_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() + + +# ── delete_lead ─────────────────────────────────────────────────────────────── + + +def test_sf_delete_lead(playground_page: Page, deletable_lead_id: str) -> None: + """Delete a Lead; assert single-step success and the record ID appears in the result.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("delete_lead") + sf.fill_id_only(deletable_lead_id) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text(deletable_lead_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +# ── delete_contact ──────────────────────────────────────────────────────────── + + +def test_sf_delete_contact(playground_page: Page, deletable_contact_id: str) -> None: + """Delete a Contact; assert single-step success and the record ID appears in the result.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("delete_contact") + sf.fill_id_only(deletable_contact_id) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text(deletable_contact_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +# ── cross-action switch ─────────────────────────────────────────────────────── + + +def test_sf_switch_create_lead_to_read(playground_page: Page, real_sf_lead_id: str) -> None: + """Create a Lead, then switch to read_lead on the same page — both must succeed.""" + sf = _navigate_to_salesforce(playground_page) + + # First run: create_lead + sf.select_action("create_lead") + sf.fill_lead_fields(last_name=f"Lead{_rnd()}", company=f"Corp{_rnd()}") + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text("Lead created successfully") + + # Switch action and run read_lead + sf.select_action("read_lead") + sf.fill_id_only(real_sf_lead_id) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text(real_sf_lead_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() From cd08cb83c9998a9dbde6c2cba2032611d4e9de87 Mon Sep 17 00:00:00 2001 From: Rahul Ap Date: Tue, 26 May 2026 01:10:06 -0700 Subject: [PATCH 48/60] Modify NW_ALLOWED_CONNECTORS in pytest workflow Updated NW_ALLOWED_CONNECTORS to include 'stripe'. --- .github/workflows/pytest.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 71ec6c9..6c9195f 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -93,9 +93,6 @@ jobs: STRIPE_API_KEY: ${{ secrets.STRIPE_API_KEY }} STRIPE_TEST_CUSTOMER_ID: ${{ secrets.STRIPE_TEST_CUSTOMER_ID }} STRIPE_TEST_PRICE_ID: ${{ secrets.STRIPE_TEST_PRICE_ID }} - NW_REST_AUTH_DISABLED: "true" - NW_REST_LOAD_DOTENV: "false" - NW_ALLOWED_CONNECTORS: "google_drive,stripe" #salesforce SALESFORCE_INSTANCE_URL: ${{ secrets.SALESFORCE_INSTANCE_URL }} @@ -107,7 +104,7 @@ jobs: # Disable authentication and dotenv loading for playground tests, and restrict connectors NW_REST_AUTH_DISABLED: "true" NW_REST_LOAD_DOTENV: "false" - NW_ALLOWED_CONNECTORS: "google_drive,salesforce," + NW_ALLOWED_CONNECTORS: "google_drive,salesforce,stripe" run: uv run pytest tests/playground/ --no-cov -v - name: Upload Playwright traces on failure From 5d76914b9f43423eeff0f005d7563c1c723b1282 Mon Sep 17 00:00:00 2001 From: rahul-aot Date: Tue, 26 May 2026 01:12:05 -0700 Subject: [PATCH 49/60] Refactor lead_email locator for improved readability --- tests/playground/salesforce/salesforce_page.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/playground/salesforce/salesforce_page.py b/tests/playground/salesforce/salesforce_page.py index ddfaa3c..344a531 100644 --- a/tests/playground/salesforce/salesforce_page.py +++ b/tests/playground/salesforce/salesforce_page.py @@ -34,9 +34,7 @@ def __init__(self, page: Page) -> None: self.lead_company: Locator = page.locator( "#salesforce-section-lead input[name='lead_company']" ) - self.lead_email: Locator = page.locator( - "#salesforce-section-lead input[name='lead_email']" - ) + self.lead_email: Locator = page.locator("#salesforce-section-lead input[name='lead_email']") # --- Contact section (create_contact / update_contact) --- self.contact_section: Locator = page.locator("#salesforce-section-contact") From e5a1e10b6eba77765cd0aa73d9f9e54e8f41bd6d Mon Sep 17 00:00:00 2001 From: rahul-aot Date: Mon, 25 May 2026 22:25:16 -0700 Subject: [PATCH 50/60] Add integration tests for Slack connector in Playground --- .github/workflows/pytest.yml | 7 +- tests/playground/slack/README.md | 78 ++++++ tests/playground/slack/conftest.py | 56 ++++ tests/playground/slack/slack_page.py | 90 +++++++ .../slack/test_slack_integration.py | 245 ++++++++++++++++++ 5 files changed, 475 insertions(+), 1 deletion(-) create mode 100644 tests/playground/slack/README.md create mode 100644 tests/playground/slack/conftest.py create mode 100644 tests/playground/slack/slack_page.py create mode 100644 tests/playground/slack/test_slack_integration.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 6c9195f..5710c26 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -101,10 +101,15 @@ jobs: SALESFORCE_CLIENT_SECRET: ${{ secrets.SALESFORCE_CLIENT_SECRET }} SALESFORCE_REFRESH_TOKEN: ${{ secrets.SALESFORCE_REFRESH_TOKEN }} + #slack + SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} + SLACK_TEST_CHANNEL: ${{ secrets.SLACK_TEST_CHANNEL }} + SLACK_TEST_USER_ID: ${{ secrets.SLACK_TEST_USER_ID }} + # Disable authentication and dotenv loading for playground tests, and restrict connectors NW_REST_AUTH_DISABLED: "true" NW_REST_LOAD_DOTENV: "false" - NW_ALLOWED_CONNECTORS: "google_drive,salesforce,stripe" + NW_ALLOWED_CONNECTORS: "google_drive,salesforce,stripe,slack" run: uv run pytest tests/playground/ --no-cov -v - name: Upload Playwright traces on failure diff --git a/tests/playground/slack/README.md b/tests/playground/slack/README.md new file mode 100644 index 0000000..016a5f3 --- /dev/null +++ b/tests/playground/slack/README.md @@ -0,0 +1,78 @@ + + +# Slack Playground Integration Tests + +End-to-end Playwright tests that open the Playground UI in a real browser, +navigate to the Slack connector panel, and assert on the rendered pipeline +state. No mocking — every test hits the real Slack API. + +## What is tested + +| Test | Action | +|------|--------| +| `test_slack_post_message_default` | `post_message` — default message to test channel | +| `test_slack_post_message_custom_message` | `post_message` — custom message content | +| `test_slack_post_message_invalid_channel` | `post_message` — nonexistent channel, expects error at step-1 | +| `test_slack_send_direct_message` | `send_direct_message` — DM to real user (requires `SLACK_TEST_USER_ID`) | +| `test_slack_upload_file` | `upload_file` — attach and upload a temp file | +| `test_slack_upload_remove_and_reattach` | `upload_file` — remove attachment UI, re-attach | +| `test_slack_switch_post_message_then_upload` | Cross-action switch on same page | + +## How it works + +The test session starts a real FastAPI server on a random local port. Playwright +navigates to `/playground/`. The browser's `fetch("/scenarios/slack-messaging")` +calls route to the real backend, which calls the real Slack API. +No `page.route()` interception. + +## Running locally + +```bash +# Install Playwright browsers (once) +uv run python -m playwright install chromium + +# Run all Slack tests +uv run pytest tests/playground/slack/ --no-cov -v + +# Run headed (watch the browser) +PLAYGROUND_HEADED=true uv run pytest tests/playground/slack/ --no-cov -v -s +``` + +> **Note:** Slack tests are excluded from the default `uv run pytest` run and +> from regular CI (push/PR). They must be triggered explicitly. + +## Required environment variables + +Set these before running (`.env` is loaded automatically if present): + +| Variable | Description | +|----------|-------------| +| `SLACK_BOT_TOKEN` | Slack bot token (`xoxb-...`) | +| `NW_REST_AUTH_DISABLED` | Set to `true` to skip REST auth middleware | + +## Optional environment variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `SLACK_TEST_CHANNEL` | Target channel for post and upload tests | `#general` | +| `SLACK_TEST_USER_ID` | Slack user ID (`U...`) for DM tests | *(skipped if absent)* | + +The bot must be a member of `SLACK_TEST_CHANNEL`. +`test_slack_send_direct_message` is automatically skipped when `SLACK_TEST_USER_ID` is absent. + +## CI / GitHub Actions + +Slack tests run **only on manual `workflow_dispatch`** trigger alongside the other +playground integration tests. + +Credentials are read from repository secrets: + +| Secret | Maps to env var | +|--------|----------------| +| `SLACK_BOT_TOKEN` | `SLACK_BOT_TOKEN` | +| `SLACK_TEST_CHANNEL` | `SLACK_TEST_CHANNEL` | +| `SLACK_TEST_USER_ID` | `SLACK_TEST_USER_ID` | diff --git a/tests/playground/slack/conftest.py b/tests/playground/slack/conftest.py new file mode 100644 index 0000000..d15d903 --- /dev/null +++ b/tests/playground/slack/conftest.py @@ -0,0 +1,56 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import os + +import httpx +import pytest + +_DEFAULT_CHANNEL = os.environ.get("SLACK_TEST_CHANNEL", "#general") + + +@pytest.fixture(scope="session", autouse=True) +def slack_connector_available(api_server_url: str) -> None: + """Skip the entire Slack test session if the connector returns HTTP 500. + + This happens when SLACK_BOT_TOKEN is missing or when NW_ALLOWED_CONNECTORS + is set but does not include 'slack'. Converts a 25-second timeout per test + into a single fast skip with a clear reason. + """ + with httpx.Client(timeout=10) as client: + resp = client.post( + f"{api_server_url}/scenarios/slack-messaging", + json={"action": "post_message", "channel": "#general", "message": "health-check"}, + ) + if resp.status_code == 500: + detail = resp.json().get("detail", "unknown") + pytest.skip( + f"Slack connector not available ({detail}). " + "Ensure SLACK_BOT_TOKEN is set and 'slack' is in NW_ALLOWED_CONNECTORS (or leave it unset)." + ) + + +@pytest.fixture(scope="session") +def slack_test_channel() -> str: + """Slack channel used as the target for post_message and upload_file tests. + + Defaults to #general. Override via SLACK_TEST_CHANNEL env var. + The bot must be a member of this channel. + """ + return _DEFAULT_CHANNEL + + +@pytest.fixture(scope="session") +def slack_test_user_id() -> str: + """Slack user ID (U...) used as the target for send_direct_message tests. + + Requires SLACK_TEST_USER_ID env var. Tests that depend on this fixture + are skipped when the var is absent. + """ + user_id = os.environ.get("SLACK_TEST_USER_ID") + if not user_id: + pytest.skip("SLACK_TEST_USER_ID is required for direct message tests") + return user_id diff --git a/tests/playground/slack/slack_page.py b/tests/playground/slack/slack_page.py new file mode 100644 index 0000000..782abcd --- /dev/null +++ b/tests/playground/slack/slack_page.py @@ -0,0 +1,90 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class SlackPage: + """Page Object Model for the Slack connector panel in the Playground.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Selector for the Slack card inside system connectors view + self.connector_card: Locator = page.locator(".connector-card[data-mode='slack']") + + # Panel root and header + self.panel: Locator = page.locator("#slack-panel") + self.title: Locator = page.locator("#slack-panel .card-title h2") + self.action_select: Locator = page.locator("#slack-action-select") + self.run_btn: Locator = page.locator("#slack-run-btn") + self.back_to_connectors: Locator = page.locator("#back-to-connectors") + + # Shared channel input (always visible) + self.channel: Locator = page.locator("#slack-panel input[name='channel']") + + # --- post_message / send_direct_message section --- + self.message_section: Locator = page.locator("#slack-message-section") + self.message: Locator = page.locator("#slack-message-section textarea[name='message']") + + # --- upload_file section --- + self.file_section: Locator = page.locator("#slack-file-section") + self.filename: Locator = page.locator("#slack-file-section input[name='filename']") + self.initial_comment: Locator = page.locator( + "#slack-file-section input[name='initial_comment']" + ) + self.file_input: Locator = page.locator("#slack-file") + self.file_drop_zone: Locator = page.locator("#slack-file-drop-zone") + self.file_chosen_preview: Locator = page.locator("#slack-file-chosen-preview") + self.preview_name: Locator = page.locator("#slack-file-chosen-preview .preview-name") + self.remove_file_btn: Locator = page.locator("#slack-file-chosen-preview .remove-file-btn") + + # --- Output and Logs elements --- + self.pipeline_steps: Locator = page.locator(".flow-node") + self.step_nodes: list[Locator] = [page.locator(f"#step-{i}") for i in range(4)] + self.final_result: Locator = page.locator("#final-result") + self.summary_text: Locator = page.locator("#human-summary") + self.result_tag: Locator = page.locator("#result-id") + self.log_terminal: Locator = page.locator("#log-terminal") + + def navigate_to_panel(self) -> None: + """Click the Slack card in system connectors to open the panel.""" + self.connector_card.click() + + def select_action(self, action: str) -> None: + """Change the action via the select element.""" + self.action_select.select_option(action) + + def fill_message_fields( + self, channel: str | None = None, message: str | None = None + ) -> None: + """Fill post_message / send_direct_message parameters.""" + if channel is not None: + self.channel.fill(channel) + if message is not None: + self.message.fill(message) + + def fill_upload_fields( + self, + channel: str | None = None, + filename: str | None = None, + initial_comment: str | None = None, + ) -> None: + """Fill upload_file parameters (excluding the file attachment itself).""" + if channel is not None: + self.channel.fill(channel) + if filename is not None: + self.filename.fill(filename) + if initial_comment is not None: + self.initial_comment.fill(initial_comment) + + def submit(self) -> None: + """Submit the form to execute the Slack workflow.""" + self.run_btn.click() + + def go_back(self) -> None: + """Click 'Back to All Connectors' to return to connectors selection view.""" + self.back_to_connectors.click() diff --git a/tests/playground/slack/test_slack_integration.py b/tests/playground/slack/test_slack_integration.py new file mode 100644 index 0000000..98b5afb --- /dev/null +++ b/tests/playground/slack/test_slack_integration.py @@ -0,0 +1,245 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Slack connector Playground real integration tests. + +Each test opens the Playground UI, navigates to the Slack panel, +selects an action, fills the form, clicks the run button, and asserts the +resulting pipeline state — no API mocking, real Slack API calls. + +Required env vars (loaded from .env): + SLACK_BOT_TOKEN — Slack bot token (xoxb-...) + +Optional env vars: + SLACK_TEST_CHANNEL — target channel (default: #general; bot must be a member) + SLACK_TEST_USER_ID — Slack user ID for DM tests (U...); skipped when absent +""" + +from __future__ import annotations + +import os +import tempfile +import time + +from playwright.sync_api import Page, expect + +from tests.playground.slack.slack_page import SlackPage +from tests.playground.home_page import PlaygroundHomePage + +_TIMEOUT = 25_000 # ms — 4-step pipeline with async Slack API calls + + +def _navigate_to_slack(page: Page) -> SlackPage: + PlaygroundHomePage(page).click_connectors() + slack = SlackPage(page) + slack.navigate_to_panel() + return slack + + +def _maybe_sleep() -> None: + env = os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") + if env and env.lower().strip() in ("true", "1", "yes"): + time.sleep(3) + + +# ── post_message ────────────────────────────────────────────────────────────── + + +def test_slack_post_message_default(playground_page: Page, slack_test_channel: str) -> None: + """Post a message with default values; all 4 steps must succeed.""" + slack = _navigate_to_slack(playground_page) + + slack.select_action("post_message") + slack.fill_message_fields(channel=slack_test_channel) + slack.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(slack.final_result).to_be_visible(timeout=_TIMEOUT) + expect(slack.summary_text).to_contain_text("post_message") + expect(slack.summary_text).to_contain_text(slack_test_channel) + expect(slack.result_tag).to_be_visible() + expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(slack.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_slack_post_message_custom_message( + playground_page: Page, slack_test_channel: str +) -> None: + """Post a message with custom content; summary must reflect the channel.""" + slack = _navigate_to_slack(playground_page) + + slack.select_action("post_message") + slack.fill_message_fields( + channel=slack_test_channel, + message="node-wire integration test — safe to ignore.", + ) + slack.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(slack.final_result).to_be_visible(timeout=_TIMEOUT) + expect(slack.summary_text).to_contain_text(slack_test_channel) + expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() + + +def test_slack_post_message_invalid_channel(playground_page: Page) -> None: + """Post to a nonexistent channel; step-1 (Dispatch) must show error state.""" + slack = _navigate_to_slack(playground_page) + + slack.select_action("post_message") + slack.fill_message_fields(channel="this-channel-does-not-exist-99999") + slack.submit() + + # step-0 (Format Slack Payload) is local — always succeeds + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + # step-1 (Dispatch to Slack API) must fail for an invalid channel + expect(playground_page.locator("#step-1.error")).to_be_visible(timeout=_TIMEOUT) + expect(slack.final_result).to_be_hidden() + expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Failed") + expect(slack.log_terminal).to_contain_text("FAILED") + + _maybe_sleep() + + +# ── send_direct_message ─────────────────────────────────────────────────────── + + +def test_slack_send_direct_message( + playground_page: Page, slack_test_user_id: str +) -> None: + """Send a DM to a real user; all 4 steps must succeed.""" + slack = _navigate_to_slack(playground_page) + + slack.select_action("send_direct_message") + slack.fill_message_fields( + channel=slack_test_user_id, + message="node-wire DM integration test — safe to ignore.", + ) + slack.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(slack.final_result).to_be_visible(timeout=_TIMEOUT) + expect(slack.summary_text).to_contain_text("send_direct_message") + expect(slack.summary_text).to_contain_text(slack_test_user_id) + expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(slack.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +# ── upload_file ─────────────────────────────────────────────────────────────── + + +def test_slack_upload_file(playground_page: Page, slack_test_channel: str) -> None: + """Attach a temp file and upload it; all 4 steps must succeed.""" + slack = _navigate_to_slack(playground_page) + + slack.select_action("upload_file") + + with tempfile.NamedTemporaryFile( + suffix=".txt", delete=False, prefix="nw_slack_test_" + ) as tmp: + tmp.write(b"node-wire Slack upload integration test - safe to delete.") + tmp_path = tmp.name + + slack.file_input.set_input_files(tmp_path) + expect(slack.file_chosen_preview).to_be_visible(timeout=3_000) + expect(slack.file_drop_zone).to_be_hidden() + expect(slack.preview_name).to_contain_text("nw_slack_test_") + + slack.fill_upload_fields( + channel=slack_test_channel, + initial_comment="node-wire integration test upload — safe to delete.", + ) + slack.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(slack.final_result).to_be_visible(timeout=_TIMEOUT) + expect(slack.summary_text).to_contain_text("upload_file") + expect(slack.summary_text).to_contain_text(slack_test_channel) + expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(slack.log_terminal).to_contain_text("SUCCESS") + + _maybe_sleep() + + +def test_slack_upload_remove_and_reattach(playground_page: Page) -> None: + """Remove an attached file → drop zone reappears; re-attach → preview is restored.""" + slack = _navigate_to_slack(playground_page) + + slack.select_action("upload_file") + + with tempfile.NamedTemporaryFile( + suffix=".txt", delete=False, prefix="nw_slack_reattach_" + ) as tmp: + tmp.write(b"Reattach UI test content - safe to delete.") + tmp_path = tmp.name + + # Attach + slack.file_input.set_input_files(tmp_path) + expect(slack.file_chosen_preview).to_be_visible(timeout=3_000) + expect(slack.file_drop_zone).to_be_hidden() + + # Remove + slack.remove_file_btn.click() + expect(slack.file_chosen_preview).to_be_hidden(timeout=3_000) + expect(slack.file_drop_zone).to_be_visible() + + # Re-attach + slack.file_input.set_input_files(tmp_path) + expect(slack.file_chosen_preview).to_be_visible(timeout=3_000) + expect(slack.preview_name).to_contain_text("nw_slack_reattach_") + + +# ── cross-action switch ─────────────────────────────────────────────────────── + + +def test_slack_switch_post_message_then_upload( + playground_page: Page, slack_test_channel: str +) -> None: + """Run post_message then switch to upload_file on the same page — both must succeed.""" + slack = _navigate_to_slack(playground_page) + + # First run: post_message + slack.select_action("post_message") + slack.fill_message_fields(channel=slack_test_channel) + slack.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + expect(slack.final_result).to_be_visible(timeout=_TIMEOUT) + + # Switch to upload_file and run + slack.select_action("upload_file") + + with tempfile.NamedTemporaryFile( + suffix=".txt", delete=False, prefix="nw_slack_switch_" + ) as tmp: + tmp.write(b"Cross-action switch test - safe to delete.") + tmp_path = tmp.name + + slack.file_input.set_input_files(tmp_path) + expect(slack.file_chosen_preview).to_be_visible(timeout=3_000) + + slack.fill_upload_fields(channel=slack_test_channel) + slack.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + expect(slack.final_result).to_be_visible(timeout=_TIMEOUT) + expect(slack.summary_text).to_contain_text("upload_file") + expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") + + _maybe_sleep() From d760199988aaa5c2a281f1e8c3579e7dfeb285d4 Mon Sep 17 00:00:00 2001 From: rahul-aot Date: Tue, 26 May 2026 02:45:15 -0700 Subject: [PATCH 51/60] Add integration tests for Slack and Salesforce connectors, update environment variables, and refactor utility functions --- .github/workflows/pytest.yml | 1 + sample.env | 19 +++++-- .../gdrive/test_gdrive_integration.py | 31 +++++------ tests/playground/salesforce/conftest.py | 10 +--- tests/playground/salesforce/helpers.py | 15 ++++++ .../salesforce/test_salesforce_integration.py | 51 +++++++------------ tests/playground/slack/README.md | 7 ++- tests/playground/slack/conftest.py | 19 +++++++ .../slack/test_slack_integration.py | 31 +++++------ .../stripe/test_stripe_integration.py | 32 +++++------- tests/playground/utils.py | 15 ++++++ 11 files changed, 125 insertions(+), 106 deletions(-) create mode 100644 tests/playground/salesforce/helpers.py create mode 100644 tests/playground/utils.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 5710c26..1abd319 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -105,6 +105,7 @@ jobs: SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} SLACK_TEST_CHANNEL: ${{ secrets.SLACK_TEST_CHANNEL }} SLACK_TEST_USER_ID: ${{ secrets.SLACK_TEST_USER_ID }} + SLACK_TEST_CHANNEL_ID: ${{ secrets.SLACK_TEST_CHANNEL_ID }} # Disable authentication and dotenv loading for playground tests, and restrict connectors NW_REST_AUTH_DISABLED: "true" diff --git a/sample.env b/sample.env index 746048e..4db2b65 100644 --- a/sample.env +++ b/sample.env @@ -125,9 +125,6 @@ NW_REST_LOAD_DOTENV=true # NODE_WIRE_LEGACY_GDRIVE_ACTION_UPLOAD=warn # NODE_WIRE_LEGACY_GDRIVE_ACTION_UPLOAD=reject -NW_REST_AUTH_DISABLED=true - - NW_REST_RATE_LIMIT_ENABLED=true NW_REST_RATE_LIMIT_MAX_REQUESTS=120 NW_REST_RATE_LIMIT_WINDOW_SECONDS=60 @@ -148,3 +145,19 @@ SALESFORCE_REFRESH_TOKEN=your-refresh-token # Playwright playground headed execution - set to "true" to view the browser and its activities PLAYGROUND_HEADED=false + +# ----------------------------------------------------------------------------- +# Playground Integration Tests (pytest tests/playground/) +# ----------------------------------------------------------------------------- + +# Google Drive Playground Tests +GDRIVE_TEST_RECIPIENT_EMAIL=your-gdrive-test-recipient-email + +# Stripe Playground Tests +STRIPE_TEST_CUSTOMER_ID=your-stripe-test-customer-id +STRIPE_TEST_PRICE_ID=your-stripe-test-price-id + +# Slack Playground Tests +SLACK_TEST_CHANNEL=#your-slack-test-channel +SLACK_TEST_USER_ID=your-slack-test-user-id +SLACK_TEST_CHANNEL_ID=your-slack-test-channel-id diff --git a/tests/playground/gdrive/test_gdrive_integration.py b/tests/playground/gdrive/test_gdrive_integration.py index eee9eee..4d2e0a9 100644 --- a/tests/playground/gdrive/test_gdrive_integration.py +++ b/tests/playground/gdrive/test_gdrive_integration.py @@ -11,19 +11,18 @@ Required env vars (loaded from .env): GOOGLE_DRIVE_SA_JSON — service-account JSON (path or inline JSON) GOOGLE_DRIVE_FOLDER_ID — target folder for uploads - GDRIVE_TEST_RECIPIENT_EMAIL — email used as sharing recipient (default: rahul.ap@aot-technologies.com) + GDRIVE_TEST_RECIPIENT_EMAIL — email used as sharing recipient (default: test@mailinator.com) """ from __future__ import annotations -import os import tempfile -import time from playwright.sync_api import Page, expect from tests.playground.gdrive.gdrive_page import GoogleDrivePage from tests.playground.home_page import PlaygroundHomePage +from tests.playground.utils import maybe_sleep _TIMEOUT_STEP = 20_000 # ms — single-step operations (list, get) _TIMEOUT_MULTI = 45_000 # ms — multi-step operations (upload, update) @@ -36,12 +35,6 @@ def _navigate_to_gdrive(page: Page) -> GoogleDrivePage: return gdrive -def _maybe_sleep() -> None: - env = os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") - if env and env.lower().strip() in ("true", "1", "yes"): - time.sleep(3) - - # ── files.list ──────────────────────────────────────────────────────────────── @@ -58,7 +51,7 @@ def test_gdrive_list_files_default_page_size(playground_page: Page) -> None: expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(gdrive.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_gdrive_list_files_explicit_page_size(playground_page: Page) -> None: @@ -75,7 +68,7 @@ def test_gdrive_list_files_explicit_page_size(playground_page: Page) -> None: expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(gdrive.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_gdrive_list_files_with_query(playground_page: Page) -> None: @@ -90,7 +83,7 @@ def test_gdrive_list_files_with_query(playground_page: Page) -> None: expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() # ── files.get ───────────────────────────────────────────────────────────────── @@ -111,7 +104,7 @@ def test_gdrive_get_file(playground_page: Page, real_gdrive_file_id: str) -> Non expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(gdrive.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_gdrive_get_file_without_fields(playground_page: Page, real_gdrive_file_id: str) -> None: @@ -126,7 +119,7 @@ def test_gdrive_get_file_without_fields(playground_page: Page, real_gdrive_file_ expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() def test_gdrive_get_file_invalid_id(playground_page: Page) -> None: @@ -142,7 +135,7 @@ def test_gdrive_get_file_invalid_id(playground_page: Page) -> None: expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Failed") expect(gdrive.log_terminal).to_contain_text("FAILED") - _maybe_sleep() + maybe_sleep() # ── files.update ────────────────────────────────────────────────────────────── @@ -168,7 +161,7 @@ def test_gdrive_update_file_name(playground_page: Page, uploaded_test_file_id: s expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(gdrive.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_gdrive_update_file_name_and_mime( @@ -191,7 +184,7 @@ def test_gdrive_update_file_name_and_mime( expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_MULTI) expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() # ── files.upload ────────────────────────────────────────────────────────────── @@ -222,7 +215,7 @@ def test_gdrive_upload_file(playground_page: Page, test_recipient_email: str) -> expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(gdrive.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_gdrive_upload_remove_and_reattach(playground_page: Page) -> None: @@ -272,4 +265,4 @@ def test_gdrive_switch_list_then_get(playground_page: Page, real_gdrive_file_id: expect(gdrive.summary_text).to_contain_text("Google Drive file metadata") expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() diff --git a/tests/playground/salesforce/conftest.py b/tests/playground/salesforce/conftest.py index 331c057..67f141f 100644 --- a/tests/playground/salesforce/conftest.py +++ b/tests/playground/salesforce/conftest.py @@ -4,18 +4,10 @@ # from __future__ import annotations -import random - import httpx import pytest - -def _rnd() -> str: - return str(random.randint(100_000, 999_999)) - - -def _email() -> str: - return f"test{_rnd()}@mailinator.com" +from tests.playground.salesforce.helpers import rnd as _rnd, random_email as _email def _create_lead(api_server_url: str, last_name: str, company: str) -> str: diff --git a/tests/playground/salesforce/helpers.py b/tests/playground/salesforce/helpers.py new file mode 100644 index 0000000..ac20088 --- /dev/null +++ b/tests/playground/salesforce/helpers.py @@ -0,0 +1,15 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import random + + +def rnd() -> str: + return str(random.randint(100_000, 999_999)) + + +def random_email() -> str: + return f"test{rnd()}@mailinator.com" diff --git a/tests/playground/salesforce/test_salesforce_integration.py b/tests/playground/salesforce/test_salesforce_integration.py index 5e9c672..4caad9e 100644 --- a/tests/playground/salesforce/test_salesforce_integration.py +++ b/tests/playground/salesforce/test_salesforce_integration.py @@ -18,27 +18,16 @@ from __future__ import annotations -import os -import random -import time - from playwright.sync_api import Page, expect from tests.playground.home_page import PlaygroundHomePage +from tests.playground.salesforce.helpers import rnd as _rnd, random_email as _email from tests.playground.salesforce.salesforce_page import SalesforcePage +from tests.playground.utils import maybe_sleep _TIMEOUT = 20_000 # ms — all Salesforce operations are single-step -def _rnd() -> str: - """Return a 6-digit random suffix, unique enough to avoid duplicate-email rejections.""" - return str(random.randint(100_000, 999_999)) - - -def _email() -> str: - return f"test{_rnd()}@mailinator.com" - - def _navigate_to_salesforce(page: Page) -> SalesforcePage: PlaygroundHomePage(page).click_connectors() sf = SalesforcePage(page) @@ -46,12 +35,6 @@ def _navigate_to_salesforce(page: Page) -> SalesforcePage: return sf -def _maybe_sleep() -> None: - env = os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") - if env and env.lower().strip() in ("true", "1", "yes"): - time.sleep(3) - - # ── create_lead ─────────────────────────────────────────────────────────────── @@ -69,7 +52,7 @@ def test_sf_create_lead_minimal(playground_page: Page) -> None: expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(sf.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_sf_create_lead_full(playground_page: Page) -> None: @@ -90,7 +73,7 @@ def test_sf_create_lead_full(playground_page: Page) -> None: expect(sf.summary_text).to_contain_text("Lead created successfully") expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() # ── create_contact ──────────────────────────────────────────────────────────── @@ -110,7 +93,7 @@ def test_sf_create_contact_minimal(playground_page: Page) -> None: expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(sf.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_sf_create_contact_with_email(playground_page: Page) -> None: @@ -130,7 +113,7 @@ def test_sf_create_contact_with_email(playground_page: Page) -> None: expect(sf.summary_text).to_contain_text("Contact created successfully") expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() # ── read_lead ───────────────────────────────────────────────────────────────── @@ -151,7 +134,7 @@ def test_sf_read_lead(playground_page: Page, real_sf_lead_id: str) -> None: expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(sf.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_sf_read_lead_invalid_id(playground_page: Page) -> None: @@ -167,7 +150,7 @@ def test_sf_read_lead_invalid_id(playground_page: Page) -> None: expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Failed") expect(sf.log_terminal).to_contain_text("FAILED") - _maybe_sleep() + maybe_sleep() # ── read_contact ────────────────────────────────────────────────────────────── @@ -188,7 +171,7 @@ def test_sf_read_contact(playground_page: Page, real_sf_contact_id: str) -> None expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(sf.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_sf_read_contact_invalid_id(playground_page: Page) -> None: @@ -204,7 +187,7 @@ def test_sf_read_contact_invalid_id(playground_page: Page) -> None: expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Failed") expect(sf.log_terminal).to_contain_text("FAILED") - _maybe_sleep() + maybe_sleep() # ── update_lead ─────────────────────────────────────────────────────────────── @@ -229,7 +212,7 @@ def test_sf_update_lead(playground_page: Page, real_sf_lead_id: str) -> None: expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(sf.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_sf_update_lead_email(playground_page: Page, real_sf_lead_id: str) -> None: @@ -248,7 +231,7 @@ def test_sf_update_lead_email(playground_page: Page, real_sf_lead_id: str) -> No expect(sf.result_tag).to_contain_text(real_sf_lead_id) expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() # ── update_contact ──────────────────────────────────────────────────────────── @@ -273,7 +256,7 @@ def test_sf_update_contact(playground_page: Page, real_sf_contact_id: str) -> No expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(sf.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_sf_update_contact_email(playground_page: Page, real_sf_contact_id: str) -> None: @@ -292,7 +275,7 @@ def test_sf_update_contact_email(playground_page: Page, real_sf_contact_id: str) expect(sf.result_tag).to_contain_text(real_sf_contact_id) expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() # ── delete_lead ─────────────────────────────────────────────────────────────── @@ -312,7 +295,7 @@ def test_sf_delete_lead(playground_page: Page, deletable_lead_id: str) -> None: expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(sf.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() # ── delete_contact ──────────────────────────────────────────────────────────── @@ -332,7 +315,7 @@ def test_sf_delete_contact(playground_page: Page, deletable_contact_id: str) -> expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(sf.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() # ── cross-action switch ─────────────────────────────────────────────────────── @@ -361,4 +344,4 @@ def test_sf_switch_create_lead_to_read(playground_page: Page, real_sf_lead_id: s expect(sf.summary_text).to_contain_text(real_sf_lead_id) expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() diff --git a/tests/playground/slack/README.md b/tests/playground/slack/README.md index 016a5f3..d1e0b6b 100644 --- a/tests/playground/slack/README.md +++ b/tests/playground/slack/README.md @@ -58,11 +58,13 @@ Set these before running (`.env` is loaded automatically if present): | Variable | Description | Default | |----------|-------------|---------| -| `SLACK_TEST_CHANNEL` | Target channel for post and upload tests | `#general` | +| `SLACK_TEST_CHANNEL` | Target channel for post_message and send_direct_message tests | `#general` | +| `SLACK_TEST_CHANNEL_ID` | Channel **ID** (`C...`) for upload_file tests — required because the Slack external-upload API does not accept channel names | *(skipped if absent)* | | `SLACK_TEST_USER_ID` | Slack user ID (`U...`) for DM tests | *(skipped if absent)* | -The bot must be a member of `SLACK_TEST_CHANNEL`. +The bot must be a member of `SLACK_TEST_CHANNEL` and `SLACK_TEST_CHANNEL_ID`. `test_slack_send_direct_message` is automatically skipped when `SLACK_TEST_USER_ID` is absent. +`test_slack_upload_file` and `test_slack_switch_post_message_then_upload` are automatically skipped when `SLACK_TEST_CHANNEL_ID` is absent (and `SLACK_TEST_CHANNEL` is not already a bare ID). ## CI / GitHub Actions @@ -75,4 +77,5 @@ Credentials are read from repository secrets: |--------|----------------| | `SLACK_BOT_TOKEN` | `SLACK_BOT_TOKEN` | | `SLACK_TEST_CHANNEL` | `SLACK_TEST_CHANNEL` | +| `SLACK_TEST_CHANNEL_ID` | `SLACK_TEST_CHANNEL_ID` | | `SLACK_TEST_USER_ID` | `SLACK_TEST_USER_ID` | diff --git a/tests/playground/slack/conftest.py b/tests/playground/slack/conftest.py index d15d903..9c3a180 100644 --- a/tests/playground/slack/conftest.py +++ b/tests/playground/slack/conftest.py @@ -10,6 +10,7 @@ import pytest _DEFAULT_CHANNEL = os.environ.get("SLACK_TEST_CHANNEL", "#general") +_DEFAULT_CHANNEL_ID = os.environ.get("SLACK_TEST_CHANNEL_ID", "") @pytest.fixture(scope="session", autouse=True) @@ -43,6 +44,24 @@ def slack_test_channel() -> str: return _DEFAULT_CHANNEL +@pytest.fixture(scope="session") +def slack_upload_channel() -> str: + """Channel ID used for upload_file tests. + + Prefers SLACK_TEST_CHANNEL_ID (must be a bare channel ID like C0ANP6RADHU). + Falls back to SLACK_TEST_CHANNEL, but skips if that is still a name — the + Slack external-upload API requires an ID, not a name. + """ + if _DEFAULT_CHANNEL_ID: + return _DEFAULT_CHANNEL_ID + if _DEFAULT_CHANNEL and _DEFAULT_CHANNEL[0].upper() in ("C", "G", "D"): + return _DEFAULT_CHANNEL + pytest.skip( + "upload_file tests require a channel ID. " + "Set SLACK_TEST_CHANNEL_ID (e.g. C0ANP6RADHU) in .env." + ) + + @pytest.fixture(scope="session") def slack_test_user_id() -> str: """Slack user ID (U...) used as the target for send_direct_message tests. diff --git a/tests/playground/slack/test_slack_integration.py b/tests/playground/slack/test_slack_integration.py index 98b5afb..9cd75bd 100644 --- a/tests/playground/slack/test_slack_integration.py +++ b/tests/playground/slack/test_slack_integration.py @@ -18,14 +18,13 @@ from __future__ import annotations -import os import tempfile -import time from playwright.sync_api import Page, expect from tests.playground.slack.slack_page import SlackPage from tests.playground.home_page import PlaygroundHomePage +from tests.playground.utils import maybe_sleep _TIMEOUT = 25_000 # ms — 4-step pipeline with async Slack API calls @@ -37,12 +36,6 @@ def _navigate_to_slack(page: Page) -> SlackPage: return slack -def _maybe_sleep() -> None: - env = os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") - if env and env.lower().strip() in ("true", "1", "yes"): - time.sleep(3) - - # ── post_message ────────────────────────────────────────────────────────────── @@ -64,7 +57,7 @@ def test_slack_post_message_default(playground_page: Page, slack_test_channel: s expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(slack.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_slack_post_message_custom_message( @@ -87,7 +80,7 @@ def test_slack_post_message_custom_message( expect(slack.summary_text).to_contain_text(slack_test_channel) expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() def test_slack_post_message_invalid_channel(playground_page: Page) -> None: @@ -106,7 +99,7 @@ def test_slack_post_message_invalid_channel(playground_page: Page) -> None: expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Failed") expect(slack.log_terminal).to_contain_text("FAILED") - _maybe_sleep() + maybe_sleep() # ── send_direct_message ─────────────────────────────────────────────────────── @@ -134,13 +127,13 @@ def test_slack_send_direct_message( expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(slack.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() # ── upload_file ─────────────────────────────────────────────────────────────── -def test_slack_upload_file(playground_page: Page, slack_test_channel: str) -> None: +def test_slack_upload_file(playground_page: Page, slack_upload_channel: str) -> None: """Attach a temp file and upload it; all 4 steps must succeed.""" slack = _navigate_to_slack(playground_page) @@ -158,7 +151,7 @@ def test_slack_upload_file(playground_page: Page, slack_test_channel: str) -> No expect(slack.preview_name).to_contain_text("nw_slack_test_") slack.fill_upload_fields( - channel=slack_test_channel, + channel=slack_upload_channel, initial_comment="node-wire integration test upload — safe to delete.", ) slack.submit() @@ -168,11 +161,11 @@ def test_slack_upload_file(playground_page: Page, slack_test_channel: str) -> No expect(slack.final_result).to_be_visible(timeout=_TIMEOUT) expect(slack.summary_text).to_contain_text("upload_file") - expect(slack.summary_text).to_contain_text(slack_test_channel) + expect(slack.summary_text).to_contain_text(slack_upload_channel) expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(slack.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_slack_upload_remove_and_reattach(playground_page: Page) -> None: @@ -207,7 +200,7 @@ def test_slack_upload_remove_and_reattach(playground_page: Page) -> None: def test_slack_switch_post_message_then_upload( - playground_page: Page, slack_test_channel: str + playground_page: Page, slack_test_channel: str, slack_upload_channel: str ) -> None: """Run post_message then switch to upload_file on the same page — both must succeed.""" slack = _navigate_to_slack(playground_page) @@ -233,7 +226,7 @@ def test_slack_switch_post_message_then_upload( slack.file_input.set_input_files(tmp_path) expect(slack.file_chosen_preview).to_be_visible(timeout=3_000) - slack.fill_upload_fields(channel=slack_test_channel) + slack.fill_upload_fields(channel=slack_upload_channel) slack.submit() for i in range(4): @@ -242,4 +235,4 @@ def test_slack_switch_post_message_then_upload( expect(slack.summary_text).to_contain_text("upload_file") expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() diff --git a/tests/playground/stripe/test_stripe_integration.py b/tests/playground/stripe/test_stripe_integration.py index 7527692..02524bb 100644 --- a/tests/playground/stripe/test_stripe_integration.py +++ b/tests/playground/stripe/test_stripe_integration.py @@ -18,13 +18,11 @@ from __future__ import annotations -import os -import time - from playwright.sync_api import Page, expect from tests.playground.stripe.stripe_page import StripePage from tests.playground.home_page import PlaygroundHomePage +from tests.playground.utils import maybe_sleep _TIMEOUT = 20_000 # ms — all Stripe scenarios are 3-step @@ -36,12 +34,6 @@ def _navigate_to_stripe(page: Page) -> StripePage: return stripe -def _maybe_sleep() -> None: - env = os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") - if env and env.lower().strip() in ("true", "1", "yes"): - time.sleep(3) - - # ── charge ──────────────────────────────────────────────────────────────────── @@ -61,7 +53,7 @@ def test_stripe_charge_default(playground_page: Page) -> None: expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(stripe.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_stripe_charge_custom_amount(playground_page: Page) -> None: @@ -80,7 +72,7 @@ def test_stripe_charge_custom_amount(playground_page: Page) -> None: expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(stripe.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_stripe_charge_no_description(playground_page: Page) -> None: @@ -97,7 +89,7 @@ def test_stripe_charge_no_description(playground_page: Page) -> None: expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() # ── payment_intent ──────────────────────────────────────────────────────────── @@ -119,7 +111,7 @@ def test_stripe_payment_intent_default(playground_page: Page) -> None: expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(stripe.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_stripe_payment_intent_custom_amount(playground_page: Page) -> None: @@ -137,7 +129,7 @@ def test_stripe_payment_intent_custom_amount(playground_page: Page) -> None: expect(stripe.result_tag).to_contain_text("pi_") expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() def test_stripe_payment_intent_no_payment_method(playground_page: Page) -> None: @@ -154,7 +146,7 @@ def test_stripe_payment_intent_no_payment_method(playground_page: Page) -> None: expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() # ── cancel_subscription ─────────────────────────────────────────────────────── @@ -176,7 +168,7 @@ def test_stripe_cancel_subscription_invalid_id(playground_page: Page) -> None: expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Failed") expect(stripe.log_terminal).to_contain_text("FAILED") - _maybe_sleep() + maybe_sleep() def test_stripe_cancel_subscription( @@ -198,7 +190,7 @@ def test_stripe_cancel_subscription( expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(stripe.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() # ── refund ──────────────────────────────────────────────────────────────────── @@ -221,7 +213,7 @@ def test_stripe_refund_by_charge_id(playground_page: Page, real_stripe_charge_id expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") expect(stripe.log_terminal).to_contain_text("SUCCESS") - _maybe_sleep() + maybe_sleep() def test_stripe_refund_invalid_id(playground_page: Page) -> None: @@ -240,7 +232,7 @@ def test_stripe_refund_invalid_id(playground_page: Page) -> None: expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Failed") expect(stripe.log_terminal).to_contain_text("FAILED") - _maybe_sleep() + maybe_sleep() # ── cross-action switch ─────────────────────────────────────────────────────── @@ -268,4 +260,4 @@ def test_stripe_switch_charge_then_payment_intent(playground_page: Page) -> None expect(stripe.summary_text).to_contain_text("payment intent") expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") - _maybe_sleep() + maybe_sleep() diff --git a/tests/playground/utils.py b/tests/playground/utils.py new file mode 100644 index 0000000..13345e0 --- /dev/null +++ b/tests/playground/utils.py @@ -0,0 +1,15 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import os +import time + + +def maybe_sleep() -> None: + """Pause for 3 s when running headed so a developer can observe the result.""" + env = os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") + if env and env.lower().strip() in ("true", "1", "yes"): + time.sleep(3) From 71feab5a7a82addc4b3dfe9edb4600b1dae9ec50 Mon Sep 17 00:00:00 2001 From: rahul-aot Date: Tue, 26 May 2026 02:49:23 -0700 Subject: [PATCH 52/60] Remove unused import of random_email from Salesforce test helpers --- tests/playground/salesforce/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/playground/salesforce/conftest.py b/tests/playground/salesforce/conftest.py index 67f141f..cb7f4a5 100644 --- a/tests/playground/salesforce/conftest.py +++ b/tests/playground/salesforce/conftest.py @@ -7,7 +7,7 @@ import httpx import pytest -from tests.playground.salesforce.helpers import rnd as _rnd, random_email as _email +from tests.playground.salesforce.helpers import rnd as _rnd def _create_lead(api_server_url: str, last_name: str, company: str) -> str: From e6ad94e48f168f10036e2dca655c57b81fbf13da Mon Sep 17 00:00:00 2001 From: rahul-aot Date: Tue, 26 May 2026 02:55:01 -0700 Subject: [PATCH 53/60] Refactor function signatures in Slack integration tests for improved readability --- tests/playground/slack/slack_page.py | 4 +--- tests/playground/slack/test_slack_integration.py | 16 ++++------------ 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/tests/playground/slack/slack_page.py b/tests/playground/slack/slack_page.py index 782abcd..43a54c5 100644 --- a/tests/playground/slack/slack_page.py +++ b/tests/playground/slack/slack_page.py @@ -58,9 +58,7 @@ def select_action(self, action: str) -> None: """Change the action via the select element.""" self.action_select.select_option(action) - def fill_message_fields( - self, channel: str | None = None, message: str | None = None - ) -> None: + def fill_message_fields(self, channel: str | None = None, message: str | None = None) -> None: """Fill post_message / send_direct_message parameters.""" if channel is not None: self.channel.fill(channel) diff --git a/tests/playground/slack/test_slack_integration.py b/tests/playground/slack/test_slack_integration.py index 9cd75bd..21de6c1 100644 --- a/tests/playground/slack/test_slack_integration.py +++ b/tests/playground/slack/test_slack_integration.py @@ -60,9 +60,7 @@ def test_slack_post_message_default(playground_page: Page, slack_test_channel: s maybe_sleep() -def test_slack_post_message_custom_message( - playground_page: Page, slack_test_channel: str -) -> None: +def test_slack_post_message_custom_message(playground_page: Page, slack_test_channel: str) -> None: """Post a message with custom content; summary must reflect the channel.""" slack = _navigate_to_slack(playground_page) @@ -105,9 +103,7 @@ def test_slack_post_message_invalid_channel(playground_page: Page) -> None: # ── send_direct_message ─────────────────────────────────────────────────────── -def test_slack_send_direct_message( - playground_page: Page, slack_test_user_id: str -) -> None: +def test_slack_send_direct_message(playground_page: Page, slack_test_user_id: str) -> None: """Send a DM to a real user; all 4 steps must succeed.""" slack = _navigate_to_slack(playground_page) @@ -139,9 +135,7 @@ def test_slack_upload_file(playground_page: Page, slack_upload_channel: str) -> slack.select_action("upload_file") - with tempfile.NamedTemporaryFile( - suffix=".txt", delete=False, prefix="nw_slack_test_" - ) as tmp: + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False, prefix="nw_slack_test_") as tmp: tmp.write(b"node-wire Slack upload integration test - safe to delete.") tmp_path = tmp.name @@ -217,9 +211,7 @@ def test_slack_switch_post_message_then_upload( # Switch to upload_file and run slack.select_action("upload_file") - with tempfile.NamedTemporaryFile( - suffix=".txt", delete=False, prefix="nw_slack_switch_" - ) as tmp: + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False, prefix="nw_slack_switch_") as tmp: tmp.write(b"Cross-action switch test - safe to delete.") tmp_path = tmp.name From 5cb94c1cf02d0063d2b08460f627b5f0f2d37090 Mon Sep 17 00:00:00 2001 From: rahul-aot Date: Wed, 27 May 2026 22:20:38 -0700 Subject: [PATCH 54/60] Add integration tests for Cerner and Epic FHIR connectors, and HTTP connector; update environment variables and implement necessary fixtures --- .github/workflows/pytest.yml | 17 +++- tests/playground/cerner/README.md | 84 +++++++++++++++++++ tests/playground/cerner/cerner_page.py | 41 +++++++++ tests/playground/cerner/conftest.py | 35 ++++++++ .../cerner/test_cerner_integration.py | 47 +++++++++++ tests/playground/epic_fhir/README.md | 82 ++++++++++++++++++ tests/playground/epic_fhir/conftest.py | 35 ++++++++ tests/playground/epic_fhir/epic_fhir_page.py | 41 +++++++++ .../epic_fhir/test_epic_fhir_integration.py | 47 +++++++++++ tests/playground/http_connector/README.md | 68 +++++++++++++++ tests/playground/http_connector/conftest.py | 33 ++++++++ .../http_connector/http_connector_page.py | 41 +++++++++ .../test_http_connector_integration.py | 46 ++++++++++ 13 files changed, 616 insertions(+), 1 deletion(-) create mode 100644 tests/playground/cerner/README.md create mode 100644 tests/playground/cerner/cerner_page.py create mode 100644 tests/playground/cerner/conftest.py create mode 100644 tests/playground/cerner/test_cerner_integration.py create mode 100644 tests/playground/epic_fhir/README.md create mode 100644 tests/playground/epic_fhir/conftest.py create mode 100644 tests/playground/epic_fhir/epic_fhir_page.py create mode 100644 tests/playground/epic_fhir/test_epic_fhir_integration.py create mode 100644 tests/playground/http_connector/README.md create mode 100644 tests/playground/http_connector/conftest.py create mode 100644 tests/playground/http_connector/http_connector_page.py create mode 100644 tests/playground/http_connector/test_http_connector_integration.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 1abd319..21e2774 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -107,10 +107,25 @@ jobs: SLACK_TEST_USER_ID: ${{ secrets.SLACK_TEST_USER_ID }} SLACK_TEST_CHANNEL_ID: ${{ secrets.SLACK_TEST_CHANNEL_ID }} + #epic fhir + EPIC_CLIENT_ID: ${{ secrets.EPIC_CLIENT_ID }} + EPIC_PRIVATE_KEY: ${{ secrets.EPIC_PRIVATE_KEY }} + EPIC_TOKEN_URL: ${{ secrets.EPIC_TOKEN_URL }} + EPIC_KID: ${{ secrets.EPIC_KID }} + EPIC_FHIR_BASE_URL: ${{ secrets.EPIC_FHIR_BASE_URL }} + + #cerner + CERNER_CLIENT_ID: ${{ secrets.CERNER_CLIENT_ID }} + CERNER_PRIVATE_KEY: ${{ secrets.CERNER_PRIVATE_KEY }} + CERNER_TOKEN_URL: ${{ secrets.CERNER_TOKEN_URL }} + CERNER_KID: ${{ secrets.CERNER_KID }} + CERNER_FHIR_BASE_URL: ${{ secrets.CERNER_FHIR_BASE_URL }} + CERNER_SCOPES: ${{ secrets.CERNER_SCOPES }} + # Disable authentication and dotenv loading for playground tests, and restrict connectors NW_REST_AUTH_DISABLED: "true" NW_REST_LOAD_DOTENV: "false" - NW_ALLOWED_CONNECTORS: "google_drive,salesforce,stripe,slack" + NW_ALLOWED_CONNECTORS: "google_drive,salesforce,stripe,slack,fhir_epic,fhir_cerner,http_generic" run: uv run pytest tests/playground/ --no-cov -v - name: Upload Playwright traces on failure diff --git a/tests/playground/cerner/README.md b/tests/playground/cerner/README.md new file mode 100644 index 0000000..80c72b7 --- /dev/null +++ b/tests/playground/cerner/README.md @@ -0,0 +1,84 @@ + + +# Cerner Playground Integration Tests + +End-to-end Playwright tests that open the Playground UI in a real browser, +navigate to the Cerner connector panel, click the run button with the +pre-filled defaults, and assert on the rendered pipeline state. No mocking +— every test hits the real Cerner FHIR R4 Sandbox API. + +## What is tested + +| Test | Action | +|------|--------| +| `test_cerner_post_consultation_default` | Post-consultation sync — pre-filled patient Nancy Smart, all 4 steps must succeed | + +## How it works + +The test session starts a real FastAPI server on a random local port. Playwright +navigates to `/playground/`. The browser's +`fetch("/scenarios/cerner-post-consultation")` call routes to the real backend, +which authenticates via private-key JWT and calls the real Cerner FHIR R4 +Sandbox. No `page.route()` interception. + +The form is pre-filled in the HTML with a sandbox patient (`12724066`) and +encounter (`97957281`) — no field changes or dropdown selections are needed +before clicking run. + +## Running locally + +```bash +# Install Playwright browsers (once) +uv run python -m playwright install chromium + +# Run all Cerner tests +uv run pytest tests/playground/cerner/ --no-cov -v + +# Run headed (watch the browser) +PLAYGROUND_HEADED=true uv run pytest tests/playground/cerner/ --no-cov -v -s +``` + +> **Note:** Cerner tests are excluded from the default `uv run pytest` run and +> from regular CI (push/PR). They must be triggered explicitly. + +## Required environment variables + +Set these before running (`.env` is loaded automatically if present): + +| Variable | Description | +|----------|-------------| +| `CERNER_CLIENT_ID` | Cerner backend application client ID | +| `CERNER_PRIVATE_KEY` | RSA private key (PEM) used for private-key JWT auth | +| `CERNER_TOKEN_URL` | Cerner token endpoint URL | +| `CERNER_KID` | Key ID (`kid`) that matches the public key registered in Cerner | +| `CERNER_FHIR_BASE_URL` | Base FHIR R4 URL including tenant ID (defaults to the Cerner code sandbox if unset) | +| `CERNER_SCOPES` | Space-separated OAuth2 scopes (optional; defaults defined in `connectors.yaml`) | +| `NW_REST_AUTH_DISABLED` | Set to `true` to skip REST auth middleware | + +## CI / GitHub Actions + +Cerner tests run **only on manual `workflow_dispatch`** trigger +(`Actions → CI – Pytest → Run workflow`). + +Credentials are read from repository secrets: + +| Secret | Maps to env var | +|--------|----------------| +| `CERNER_CLIENT_ID` | `CERNER_CLIENT_ID` | +| `CERNER_PRIVATE_KEY` | `CERNER_PRIVATE_KEY` | +| `CERNER_TOKEN_URL` | `CERNER_TOKEN_URL` | +| `CERNER_KID` | `CERNER_KID` | +| `CERNER_FHIR_BASE_URL` | `CERNER_FHIR_BASE_URL` | + +Set these under **Settings → Secrets and variables → Actions** before triggering +the workflow. + +## Test data and cleanup + +The test uses the Cerner open Sandbox patient and encounter IDs pre-filled in the +Playground HTML. These are read-only sandbox resources — no records are created +or modified in a real Cerner environment. No cleanup is required after the session. diff --git a/tests/playground/cerner/cerner_page.py b/tests/playground/cerner/cerner_page.py new file mode 100644 index 0000000..76ca363 --- /dev/null +++ b/tests/playground/cerner/cerner_page.py @@ -0,0 +1,41 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class CernerPage: + """Page Object Model for the Cerner connector panel in the Playground.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Selector for the Cerner card inside system connectors view + self.connector_card: Locator = page.locator(".connector-card[data-mode='cerner']") + + # Panel root and header + self.panel: Locator = page.locator("#cerner-panel") + self.title: Locator = page.locator("#cerner-panel .card-title h2") + self.run_btn: Locator = page.locator("#cerner-run-btn") + self.back_to_connectors: Locator = page.locator("#back-to-connectors") + + # Output and log elements + self.final_result: Locator = page.locator("#final-result") + self.summary_text: Locator = page.locator("#human-summary") + self.result_tag: Locator = page.locator("#result-id") + self.log_terminal: Locator = page.locator("#log-terminal") + + def navigate_to_panel(self) -> None: + """Click the Cerner card in system connectors to open the panel.""" + self.connector_card.click() + + def submit(self) -> None: + """Submit the form to execute the Cerner workflow.""" + self.run_btn.click() + + def go_back(self) -> None: + """Click 'Back to All Connectors' to return to connectors selection view.""" + self.back_to_connectors.click() diff --git a/tests/playground/cerner/conftest.py b/tests/playground/cerner/conftest.py new file mode 100644 index 0000000..3ac2004 --- /dev/null +++ b/tests/playground/cerner/conftest.py @@ -0,0 +1,35 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import httpx +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def cerner_connector_available(api_server_url: str) -> None: + """Skip the entire Cerner test session if the connector returns HTTP 500. + + This happens when Cerner FHIR credentials are missing or when NW_ALLOWED_CONNECTORS + is set but does not include 'fhir_cerner'. + """ + with httpx.Client(timeout=15) as client: + resp = client.post( + f"{api_server_url}/scenarios/cerner-post-consultation", + json={ + "patient_id": "12724066", + "encounter_id": "97957281", + "patient_given": "Nancy", + "patient_family": "Smart", + "note_text": "health-check", + }, + ) + if resp.status_code == 500: + detail = resp.json().get("detail", "unknown") + pytest.skip( + f"Cerner connector not available ({detail}). " + "Ensure Cerner credentials are configured and 'fhir_cerner' is in NW_ALLOWED_CONNECTORS " + "(or leave it unset)." + ) diff --git a/tests/playground/cerner/test_cerner_integration.py b/tests/playground/cerner/test_cerner_integration.py new file mode 100644 index 0000000..8866c3d --- /dev/null +++ b/tests/playground/cerner/test_cerner_integration.py @@ -0,0 +1,47 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Cerner connector Playground real integration tests. + +Each test opens the Playground UI, navigates to the Cerner panel, +clicks the run button with pre-filled defaults, and asserts the resulting +pipeline state — no API mocking, real Cerner FHIR Sandbox calls. + +Required env vars (loaded from .env): + Cerner credentials (e.g. CERNER_CLIENT_ID, CERNER_CLIENT_SECRET, CERNER_BASE_URL) +""" + +from __future__ import annotations + +from playwright.sync_api import Page, expect + +from tests.playground.cerner.cerner_page import CernerPage +from tests.playground.home_page import PlaygroundHomePage +from tests.playground.utils import maybe_sleep + +_TIMEOUT = 25_000 # ms — 4-step pipeline with async Cerner FHIR API calls + + +def _navigate_to_cerner(page: Page) -> CernerPage: + PlaygroundHomePage(page).click_connectors() + cerner = CernerPage(page) + cerner.navigate_to_panel() + return cerner + + +def test_cerner_post_consultation_default(playground_page: Page) -> None: + """Submit a Cerner consultation with default pre-filled values; all 4 steps must succeed.""" + cerner = _navigate_to_cerner(playground_page) + cerner.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(cerner.final_result).to_be_visible(timeout=_TIMEOUT) + expect(cerner.summary_text).to_contain_text("Cerner EHR") + expect(cerner.result_tag).to_be_visible() + expect(playground_page.locator("#cerner-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(cerner.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() diff --git a/tests/playground/epic_fhir/README.md b/tests/playground/epic_fhir/README.md new file mode 100644 index 0000000..c9c78e7 --- /dev/null +++ b/tests/playground/epic_fhir/README.md @@ -0,0 +1,82 @@ + + +# Epic FHIR Playground Integration Tests + +End-to-end Playwright tests that open the Playground UI in a real browser, +navigate to the Epic FHIR (EHR) connector panel, click the run button with +the pre-filled defaults, and assert on the rendered pipeline state. No mocking +— every test hits the real Epic FHIR Sandbox API. + +## What is tested + +| Test | Action | +|------|--------| +| `test_epic_fhir_post_consultation_default` | Post-consultation sync — pre-filled patient Jason Smith, all 4 steps must succeed | + +## How it works + +The test session starts a real FastAPI server on a random local port. Playwright +navigates to `/playground/`. The browser's `fetch("/scenarios/post-consultation")` +call routes to the real backend, which authenticates via private-key JWT and calls +the real Epic FHIR Sandbox. No `page.route()` interception. + +The form is pre-filled in the HTML with a sandbox patient (`e63wRTbPfr1p8UW81d8Seiw3`) +and encounter (`ecgXt3jVqNNpsXnNXZ3KljA3`) — no field changes or dropdown +selections are needed before clicking run. + +## Running locally + +```bash +# Install Playwright browsers (once) +uv run python -m playwright install chromium + +# Run all Epic FHIR tests +uv run pytest tests/playground/epic_fhir/ --no-cov -v + +# Run headed (watch the browser) +PLAYGROUND_HEADED=true uv run pytest tests/playground/epic_fhir/ --no-cov -v -s +``` + +> **Note:** Epic FHIR tests are excluded from the default `uv run pytest` run and +> from regular CI (push/PR). They must be triggered explicitly. + +## Required environment variables + +Set these before running (`.env` is loaded automatically if present): + +| Variable | Description | +|----------|-------------| +| `EPIC_CLIENT_ID` | Epic backend application client ID | +| `EPIC_PRIVATE_KEY` | RSA private key (PEM) used for private-key JWT auth | +| `EPIC_TOKEN_URL` | Epic token endpoint URL | +| `EPIC_KID` | Key ID (`kid`) that matches the public key registered in Epic | +| `EPIC_FHIR_BASE_URL` | Base FHIR R4 URL (defaults to the Epic Sandbox if unset) | +| `NW_REST_AUTH_DISABLED` | Set to `true` to skip REST auth middleware | + +## CI / GitHub Actions + +Epic FHIR tests run **only on manual `workflow_dispatch`** trigger +(`Actions → CI – Pytest → Run workflow`). + +Credentials are read from repository secrets: + +| Secret | Maps to env var | +|--------|----------------| +| `EPIC_CLIENT_ID` | `EPIC_CLIENT_ID` | +| `EPIC_PRIVATE_KEY` | `EPIC_PRIVATE_KEY` | +| `EPIC_TOKEN_URL` | `EPIC_TOKEN_URL` | +| `EPIC_KID` | `EPIC_KID` | +| `EPIC_FHIR_BASE_URL` | `EPIC_FHIR_BASE_URL` | + +Set these under **Settings → Secrets and variables → Actions** before triggering +the workflow. + +## Test data and cleanup + +The test uses the Epic open Sandbox patient and encounter IDs pre-filled in the +Playground HTML. These are read-only sandbox resources — no records are created +or modified in a real Epic environment. No cleanup is required after the session. diff --git a/tests/playground/epic_fhir/conftest.py b/tests/playground/epic_fhir/conftest.py new file mode 100644 index 0000000..6b6cde5 --- /dev/null +++ b/tests/playground/epic_fhir/conftest.py @@ -0,0 +1,35 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import httpx +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def epic_fhir_connector_available(api_server_url: str) -> None: + """Skip the entire Epic FHIR test session if the connector returns HTTP 500. + + This happens when Epic FHIR credentials are missing or when NW_ALLOWED_CONNECTORS + is set but does not include 'fhir_epic'. + """ + with httpx.Client(timeout=15) as client: + resp = client.post( + f"{api_server_url}/scenarios/post-consultation", + json={ + "patient_id": "e63wRTbPfr1p8UW81d8Seiw3", + "encounter_id": "ecgXt3jVqNNpsXnNXZ3KljA3", + "patient_given": "Jason", + "patient_family": "Smith", + "note_text": "health-check", + }, + ) + if resp.status_code == 500: + detail = resp.json().get("detail", "unknown") + pytest.skip( + f"Epic FHIR connector not available ({detail}). " + "Ensure Epic credentials are configured and 'fhir_epic' is in NW_ALLOWED_CONNECTORS " + "(or leave it unset)." + ) diff --git a/tests/playground/epic_fhir/epic_fhir_page.py b/tests/playground/epic_fhir/epic_fhir_page.py new file mode 100644 index 0000000..0a65c4a --- /dev/null +++ b/tests/playground/epic_fhir/epic_fhir_page.py @@ -0,0 +1,41 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class EpicFhirPage: + """Page Object Model for the Epic FHIR (EHR) connector panel in the Playground.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Selector for the Epic FHIR card inside system connectors view + self.connector_card: Locator = page.locator(".connector-card[data-mode='ehr']") + + # Panel root and header + self.panel: Locator = page.locator("#ehr-panel") + self.title: Locator = page.locator("#ehr-panel .card-title h2") + self.run_btn: Locator = page.locator("#run-btn") + self.back_to_connectors: Locator = page.locator("#back-to-connectors") + + # Output and log elements + self.final_result: Locator = page.locator("#final-result") + self.summary_text: Locator = page.locator("#human-summary") + self.result_tag: Locator = page.locator("#result-id") + self.log_terminal: Locator = page.locator("#log-terminal") + + def navigate_to_panel(self) -> None: + """Click the Epic FHIR card in system connectors to open the panel.""" + self.connector_card.click() + + def submit(self) -> None: + """Submit the form to execute the Epic FHIR workflow.""" + self.run_btn.click() + + def go_back(self) -> None: + """Click 'Back to All Connectors' to return to connectors selection view.""" + self.back_to_connectors.click() diff --git a/tests/playground/epic_fhir/test_epic_fhir_integration.py b/tests/playground/epic_fhir/test_epic_fhir_integration.py new file mode 100644 index 0000000..c7b52ca --- /dev/null +++ b/tests/playground/epic_fhir/test_epic_fhir_integration.py @@ -0,0 +1,47 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Epic FHIR connector Playground real integration tests. + +Each test opens the Playground UI, navigates to the Epic FHIR panel, +clicks the run button with pre-filled defaults, and asserts the resulting +pipeline state — no API mocking, real Epic FHIR Sandbox calls. + +Required env vars (loaded from .env): + Epic FHIR credentials (e.g. EPIC_CLIENT_ID, EPIC_CLIENT_SECRET, EPIC_BASE_URL) +""" + +from __future__ import annotations + +from playwright.sync_api import Page, expect + +from tests.playground.epic_fhir.epic_fhir_page import EpicFhirPage +from tests.playground.home_page import PlaygroundHomePage +from tests.playground.utils import maybe_sleep + +_TIMEOUT = 25_000 # ms — 4-step pipeline with async Epic FHIR API calls + + +def _navigate_to_epic_fhir(page: Page) -> EpicFhirPage: + PlaygroundHomePage(page).click_connectors() + epic = EpicFhirPage(page) + epic.navigate_to_panel() + return epic + + +def test_epic_fhir_post_consultation_default(playground_page: Page) -> None: + """Submit a consultation with default pre-filled values; all 4 steps must succeed.""" + epic = _navigate_to_epic_fhir(playground_page) + epic.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(epic.final_result).to_be_visible(timeout=_TIMEOUT) + expect(epic.summary_text).to_contain_text("Epic") + expect(epic.result_tag).to_be_visible() + expect(playground_page.locator("#run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(epic.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() diff --git a/tests/playground/http_connector/README.md b/tests/playground/http_connector/README.md new file mode 100644 index 0000000..a1d8912 --- /dev/null +++ b/tests/playground/http_connector/README.md @@ -0,0 +1,68 @@ + + +# HTTP Connector Playground Integration Tests + +End-to-end Playwright tests that open the Playground UI in a real browser, +navigate to the HTTP connector (IT Ops) panel, click the run button with the +pre-filled defaults, and assert on the rendered pipeline state. No mocking — +every test makes a real HTTP POST via the `http_generic` connector. + +## What is tested + +| Test | Action | +|------|--------| +| `test_http_connector_submit_incident_default` | IT incident report — pre-filled High severity Gateway Proxy incident, all 4 steps must succeed | + +## How it works + +The test session starts a real FastAPI server on a random local port. Playwright +navigates to `/playground/`. The browser's `fetch("/scenarios/report-incident")` +call routes to the real backend, which formats an ITSM payload and dispatches it +via `http_generic` to `https://httpbin.org/post` — a public echo endpoint. No +`page.route()` interception. + +The form is pre-filled in the HTML with a sample incident (title, description, +severity, component, reporter) — no field changes or dropdown selections are +needed before clicking run. + +## Running locally + +```bash +# Install Playwright browsers (once) +uv run python -m playwright install chromium + +# Run all HTTP connector tests +uv run pytest tests/playground/http_connector/ --no-cov -v + +# Run headed (watch the browser) +PLAYGROUND_HEADED=true uv run pytest tests/playground/http_connector/ --no-cov -v -s +``` + +> **Note:** HTTP connector tests are excluded from the default `uv run pytest` run and +> from regular CI (push/PR). They must be triggered explicitly. + +## Required environment variables + +No connector credentials are needed — `http_generic` dispatches to the public +`httpbin.org` endpoint. + +| Variable | Description | +|----------|-------------| +| `NW_REST_AUTH_DISABLED` | Set to `true` to skip REST auth middleware | + +## CI / GitHub Actions + +HTTP connector tests run **only on manual `workflow_dispatch`** trigger +(`Actions → CI – Pytest → Run workflow`). + +No secrets are required for this connector beyond the standard auth bypass flag. + +## Test data and cleanup + +All requests are sent to `https://httpbin.org/post`, which echoes the payload +and discards it. No records are persisted anywhere. No cleanup is required after +the session. diff --git a/tests/playground/http_connector/conftest.py b/tests/playground/http_connector/conftest.py new file mode 100644 index 0000000..e9ae907 --- /dev/null +++ b/tests/playground/http_connector/conftest.py @@ -0,0 +1,33 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import httpx +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def http_connector_available(api_server_url: str) -> None: + """Skip the entire HTTP connector test session if the connector returns HTTP 500. + + This happens when NW_ALLOWED_CONNECTORS is set but does not include 'http_generic'. + """ + with httpx.Client(timeout=15) as client: + resp = client.post( + f"{api_server_url}/scenarios/report-incident", + json={ + "title": "health-check", + "severity": "HIGH", + "component": "Gateway Proxy", + "description": "health-check", + "reported_by": "DevOps Team Alpha", + }, + ) + if resp.status_code == 500: + detail = resp.json().get("detail", "unknown") + pytest.skip( + f"HTTP connector not available ({detail}). " + "Ensure 'http_generic' is in NW_ALLOWED_CONNECTORS (or leave it unset)." + ) diff --git a/tests/playground/http_connector/http_connector_page.py b/tests/playground/http_connector/http_connector_page.py new file mode 100644 index 0000000..1d08758 --- /dev/null +++ b/tests/playground/http_connector/http_connector_page.py @@ -0,0 +1,41 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class HttpConnectorPage: + """Page Object Model for the HTTP connector (IT Ops) panel in the Playground.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Selector for the HTTP connector card inside system connectors view + self.connector_card: Locator = page.locator(".connector-card[data-mode='itops']") + + # Panel root and header + self.panel: Locator = page.locator("#itops-panel") + self.title: Locator = page.locator("#itops-panel .card-title h2") + self.run_btn: Locator = page.locator("#itops-run-btn") + self.back_to_connectors: Locator = page.locator("#back-to-connectors") + + # Output and log elements + self.final_result: Locator = page.locator("#final-result") + self.summary_text: Locator = page.locator("#human-summary") + self.result_tag: Locator = page.locator("#result-id") + self.log_terminal: Locator = page.locator("#log-terminal") + + def navigate_to_panel(self) -> None: + """Click the HTTP connector card in system connectors to open the panel.""" + self.connector_card.click() + + def submit(self) -> None: + """Submit the form to execute the HTTP connector workflow.""" + self.run_btn.click() + + def go_back(self) -> None: + """Click 'Back to All Connectors' to return to connectors selection view.""" + self.back_to_connectors.click() diff --git a/tests/playground/http_connector/test_http_connector_integration.py b/tests/playground/http_connector/test_http_connector_integration.py new file mode 100644 index 0000000..12398ea --- /dev/null +++ b/tests/playground/http_connector/test_http_connector_integration.py @@ -0,0 +1,46 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""HTTP connector Playground real integration tests. + +Each test opens the Playground UI, navigates to the HTTP connector (IT Ops) +panel, clicks the run button with pre-filled defaults, and asserts the +resulting pipeline state — no API mocking, real HTTP calls to httpbin.org. + +No credentials required; http_generic uses a public endpoint. +""" + +from __future__ import annotations + +from playwright.sync_api import Page, expect + +from tests.playground.http_connector.http_connector_page import HttpConnectorPage +from tests.playground.home_page import PlaygroundHomePage +from tests.playground.utils import maybe_sleep + +_TIMEOUT = 20_000 # ms — 4-step pipeline with httpbin.org calls + + +def _navigate_to_http_connector(page: Page) -> HttpConnectorPage: + PlaygroundHomePage(page).click_connectors() + http = HttpConnectorPage(page) + http.navigate_to_panel() + return http + + +def test_http_connector_submit_incident_default(playground_page: Page) -> None: + """Submit an IT incident with default pre-filled values; all 4 steps must succeed.""" + http = _navigate_to_http_connector(playground_page) + http.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(http.final_result).to_be_visible(timeout=_TIMEOUT) + expect(http.summary_text).to_contain_text("IT Incident") + expect(http.result_tag).to_be_visible() + expect(playground_page.locator("#itops-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(http.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() From 1a5c56387da42577b770ec1d2ffe1a813997bd71 Mon Sep 17 00:00:00 2001 From: kesav-aot Date: Tue, 2 Jun 2026 10:20:47 +0530 Subject: [PATCH 55/60] Add Grafana telemetry and packaging notes Document how to start the Grafana stack for telemetry and add a Build Packages (Wheels) section to the README that points to scripts/build-packages.sh and docs/packaging.md. Also add a link to the Grafana telemetry docs in the docs list and clean up heading order. In docs/packaging.md move the pip prerequisites up into the Python package build lifecycle and call out using bash scripts/build-packages.sh --help for usage. --- README.md | 28 +++++++++++++++++++++++++--- docs/packaging.md | 3 ++- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 1c1bc4d..486046f 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,16 @@ copy sample.env .env ``` *(Edit `.env` and set `NW_ALLOWED_CONNECTORS=http_generic` or others)* -### 3. Run +### 3. Run Grafana/OpenTelemetry (optional) + +For telemetry visualization, start the Grafana stack before running the application: + +```bash +cd grafana && docker compose up -d +``` + + +### 4. Run **Bash (Linux/macOS):** ```bash # Using uv (recommended) @@ -63,11 +72,24 @@ $env:MODE="API"; python -m bindings_entrypoint Open [http://localhost:8000/docs](http://localhost:8000/docs) to see the Swagger UI. -## Playground +### 5. Playground + The platform includes an interactive web playground at [http://localhost:8000/playground/](http://localhost:8000/playground/) (available when the REST API is running). --- +## Build Packages (Wheels) + +Before building Docker images, build the Python packages as binary wheels: + +```bash +bash scripts/build-packages.sh +``` + +See [docs/packaging.md](docs/packaging.md) for details on the wheel build lifecycle. + +--- + ## Build MCP Server Images Use this workflow when you want Docker images for the individual MCP servers such as Google Drive, SMTP, Stripe, Salesforce, or Slack. @@ -176,6 +198,7 @@ For more detailed information, please refer to the following guides: - **[MCP Servers & Docker](docs/mcp-servers.md)** — Deploying individual connectors as MCP servers. - **[Packaging & Publishing](docs/packaging.md)** — Wheel builds and CI flow. - **[Code Quality & Compliance](docs/code-quality-compliance.md)** — Ruff, Mypy, pre-commit, REUSE, and dependency compliance. +- **[Telemetry (Grafana)](grafana/README.md)** — Grafana + Loki for telemetry visualization. ## Developer docs @@ -186,7 +209,6 @@ For more detailed information, please refer to the following guides: --- - ## License This project is licensed under the Apache License 2.0. diff --git a/docs/packaging.md b/docs/packaging.md index bb4efc4..a9c176d 100644 --- a/docs/packaging.md +++ b/docs/packaging.md @@ -28,6 +28,8 @@ Each connector's `pyproject.toml` lives at `packages/connectors//pyproject ## Python package build lifecycle +Prerequisites: `pip install build cython wheel` (and a usable `python` on the host). Run `bash scripts/build-packages.sh --help` for usage. + ### Build all packages (default) ```bash @@ -36,7 +38,6 @@ bash scripts/build-packages.sh Default mode builds each of the **seven** known package paths (see inventory above): `python -m build --wheel` on the **host**, then again inside **Docker** (`python:3.12-slim`) so you get Linux-tagged wheels suitable for containers. **Docker must be installed and the daemon running.** After each package, the script scans every produced wheel and fails if any `.py` file appears inside the archive. -Prerequisites: `pip install build cython wheel` (and a usable `python` on the host). Run `bash scripts/build-packages.sh --help` for usage. ### Artifact layout and safe command usage From 3d0af73f174864ccfc6565a86a784eeac7c63277 Mon Sep 17 00:00:00 2001 From: kesav-aot Date: Tue, 2 Jun 2026 10:24:35 +0530 Subject: [PATCH 56/60] Update README.md --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 486046f..c6b787a 100644 --- a/README.md +++ b/README.md @@ -198,8 +198,6 @@ For more detailed information, please refer to the following guides: - **[MCP Servers & Docker](docs/mcp-servers.md)** — Deploying individual connectors as MCP servers. - **[Packaging & Publishing](docs/packaging.md)** — Wheel builds and CI flow. - **[Code Quality & Compliance](docs/code-quality-compliance.md)** — Ruff, Mypy, pre-commit, REUSE, and dependency compliance. -- **[Telemetry (Grafana)](grafana/README.md)** — Grafana + Loki for telemetry visualization. - ## Developer docs - Individual connector MCP servers (ToolHive): [docs/mcp-servers.md](docs/mcp-servers.md) From ac022c2736be329a7dbc8b323f3e84e1d16354d9 Mon Sep 17 00:00:00 2001 From: kesav-aot Date: Tue, 2 Jun 2026 11:06:21 +0530 Subject: [PATCH 57/60] Update local-packages-to-images.md --- docs/local-packages-to-images.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/local-packages-to-images.md b/docs/local-packages-to-images.md index 2a369c4..b66361a 100644 --- a/docs/local-packages-to-images.md +++ b/docs/local-packages-to-images.md @@ -24,9 +24,7 @@ python -m pip install --upgrade build cython wheel Run all commands from the repository root: -```bash -cd /path/to/vinaayakh-node-wire -``` + --- From c12f0bc5beede54b8b12284e213f658368e7b696 Mon Sep 17 00:00:00 2001 From: kesav-aot Date: Thu, 4 Jun 2026 12:44:16 +0530 Subject: [PATCH 58/60] Add per-service build and NW_ALLOWED_CONNECTORS Update docker-compose.mcp.yml to add build context/dockerfile for each MCP service and set NW_ALLOWED_CONNECTORS per service. Update README.md and docs/mcp-servers.md to recommend running docker compose with --build and to note that each service pins its allowed connector so a broad .env value won't cause per-connector images to import optional dependencies they don't contain. This ensures local wheels are built and images are constrained to their connector-specific deps for local validation. --- README.md | 7 ++++--- docker-compose.mcp.yml | 36 +++++++++++++++++++++++++++++++++++- docs/mcp-servers.md | 7 ++++--- 3 files changed, 43 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 1c1bc4d..4718fed 100644 --- a/README.md +++ b/README.md @@ -149,16 +149,17 @@ Before starting the MCP containers, make sure: - Your `.env` file is populated with the credentials needed by the connectors you want to run. `docker-compose.mcp.yml` starts all MCP servers as stdio containers in one command. This is useful for local validation before configuring ToolHive. +Each service pins `NW_ALLOWED_CONNECTORS` to its own connector so a broad value in `.env` does not make per-connector images import optional dependencies they do not contain. ```bash -# Ensure your .env is populated, then: -docker compose -f docker-compose.mcp.yml up +# Ensure local wheels exist and your .env is populated, then: +docker compose -f docker-compose.mcp.yml up --build ``` To start only a specific server: ```bash -docker compose -f docker-compose.mcp.yml up nw-smartonfhir-epic +docker compose -f docker-compose.mcp.yml up --build nw-smartonfhir-epic ``` --- diff --git a/docker-compose.mcp.yml b/docker-compose.mcp.yml index 657a190..0af631d 100644 --- a/docker-compose.mcp.yml +++ b/docker-compose.mcp.yml @@ -5,50 +5,84 @@ services: nw-google-drive: image: nw-google-drive:latest + build: + context: . + dockerfile: docker/google-drive/Dockerfile env_file: .env + environment: + NW_ALLOWED_CONNECTORS: google_drive stdin_open: true tty: true restart: unless-stopped nw-smartonfhir-epic: image: nw-smartonfhir-epic:latest + build: + context: . + dockerfile: docker/fhir-epic/Dockerfile env_file: .env + environment: + NW_ALLOWED_CONNECTORS: fhir_epic stdin_open: true tty: true restart: unless-stopped nw-smartonfhir-cerner: image: nw-smartonfhir-cerner:latest + build: + context: . + dockerfile: docker/fhir-cerner/Dockerfile env_file: .env + environment: + NW_ALLOWED_CONNECTORS: fhir_cerner stdin_open: true tty: true restart: unless-stopped nw-smtp: image: nw-smtp:latest + build: + context: . + dockerfile: docker/smtp/Dockerfile env_file: .env + environment: + NW_ALLOWED_CONNECTORS: smtp stdin_open: true tty: true restart: unless-stopped nw-stripe: image: nw-stripe:latest + build: + context: . + dockerfile: docker/stripe/Dockerfile env_file: .env + environment: + NW_ALLOWED_CONNECTORS: stripe stdin_open: true tty: true restart: unless-stopped nw-salesforce: image: nw-salesforce:latest + build: + context: . + dockerfile: docker/salesforce/Dockerfile env_file: .env + environment: + NW_ALLOWED_CONNECTORS: salesforce stdin_open: true tty: true restart: unless-stopped nw-slack: image: nw-slack:latest + build: + context: . + dockerfile: docker/slack/Dockerfile env_file: .env + environment: + NW_ALLOWED_CONNECTORS: slack stdin_open: true tty: true restart: unless-stopped - diff --git a/docs/mcp-servers.md b/docs/mcp-servers.md index 8aabc7d..3c5b459 100644 --- a/docs/mcp-servers.md +++ b/docs/mcp-servers.md @@ -473,16 +473,17 @@ docker build -f docker/slack/Dockerfile -t nw-slack:latest . ## Run with docker-compose `docker-compose.mcp.yml` starts all MCP servers as stdio containers in one command. This is useful for local validation before configuring ToolHive. +Each service pins `NW_ALLOWED_CONNECTORS` to its own connector so a broad value in `.env` does not make per-connector images import optional dependencies they do not contain. ```bash -# Ensure your .env is populated, then: -docker compose -f docker-compose.mcp.yml up +# Ensure local wheels exist and your .env is populated, then: +docker compose -f docker-compose.mcp.yml up --build ``` To start only a specific server: ```bash -docker compose -f docker-compose.mcp.yml up nw-smartonfhir-epic +docker compose -f docker-compose.mcp.yml up --build nw-smartonfhir-epic ``` --- From b7b96b041275ac1400010ce2ba04e6736df9259d Mon Sep 17 00:00:00 2001 From: My Name Date: Thu, 11 Jun 2026 17:49:29 +0530 Subject: [PATCH 59/60] MCP authentication flag has inverted and updated --- docs/configuration.md | 1 + docs/mcp-servers.md | 2 +- sample.env | 4 +- src/bindings/mcp_server/auth.py | 40 +++++++++++++++-- src/bindings/mcp_server/server.py | 7 ++- tests/conftest.py | 4 +- tests/test_mcp_auth.py | 71 +++++++++++++++++++++++++------ tests/test_mcp_transport.py | 2 + 8 files changed, 108 insertions(+), 23 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 936fb46..b27645a 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -41,6 +41,7 @@ copy sample.env .env | `PORT` | Port for the REST API | `8000` | | `NW_MCP_TRANSPORT` | MCP transport mode (`stdio` or `streamable-http`) | `stdio` | | `NW_MCP_PORT` | Port for streamable-http MCP | `8080` | +| `NW_MCP_AUTH_DISABLED` | Disable MCP authentication (local dev only) | `false` | | `NW_REST_AUTH_DISABLED` | Disable REST API authentication (local dev only) | `false` | --- diff --git a/docs/mcp-servers.md b/docs/mcp-servers.md index 3c5b459..03fe2c5 100644 --- a/docs/mcp-servers.md +++ b/docs/mcp-servers.md @@ -163,7 +163,7 @@ When running in `streamable-http` mode, clients must comply with the strict MCP Use these settings for production-style posture: ```env -NW_MCP_AUTH_ENABLED=false +NW_MCP_AUTH_DISABLED=false NW_MCP_SCOPE_POLICY_DEFAULT=deny # Optional guardrail: fail startup if scope policy would be disabled NW_MCP_SCOPE_POLICY_STRICT=true diff --git a/sample.env b/sample.env index 4db2b65..5f2537b 100644 --- a/sample.env +++ b/sample.env @@ -89,9 +89,9 @@ GEMINI_MODEL=gemini-2.5-flash ANTHROPIC_API_KEY=your-anthropic-api-key ANTHROPIC_MODEL=claude-3-5-haiku-20241022 -# MCP auth — NW_MCP_AUTH_ENABLED=true means auth is **disabled** (legacy naming; local dev). +# MCP auth — set NW_MCP_AUTH_DISABLED=true only for local development (matches NW_REST_AUTH_DISABLED). # For production, omit it or set false so MCP auth is enforced when NW_MCP_API_KEY / JWT is set. -NW_MCP_AUTH_ENABLED=true +NW_MCP_AUTH_DISABLED=true NW_MCP_API_KEY=replace-with-strong-random-value # API key scopes (JSON array or space/comma-separated). Empty = no scopes; use "*" only for explicit full access. # Wildcard API keys intentionally bypass per-action scope checks. diff --git a/src/bindings/mcp_server/auth.py b/src/bindings/mcp_server/auth.py index d94dd09..e426432 100644 --- a/src/bindings/mcp_server/auth.py +++ b/src/bindings/mcp_server/auth.py @@ -1,3 +1,12 @@ +""" +MCP authentication (enterprise default: required API key or JWT). + +Environment: + NW_MCP_API_KEY — shared secret; send as ``Authorization: Bearer `` or ``X-API-Key: ``. + NW_MCP_JWT_SECRET — optional HS256 secret; if set, Bearer tokens with three segments are verified as JWTs. + NW_MCP_AUTH_DISABLED — if ``true``/``1``/``yes``, skip auth (local dev only; do not use in production). +""" + from __future__ import annotations import os @@ -77,7 +86,7 @@ def __init__(self) -> None: super().__init__( ( "MCP authentication is not configured. Set NW_MCP_API_KEY " - "(and optionally NW_MCP_JWT_SECRET), or set NW_MCP_AUTH_ENABLED=true " + "(and optionally NW_MCP_JWT_SECRET), or set NW_MCP_AUTH_DISABLED=true " "for local development only." ), status_code=503, @@ -101,7 +110,7 @@ def _bootstrap_mcp_auth_env() -> None: # Align with REST/bindings: when dotenv merge is disabled (pytest, CI, prod), # never load repo `.env` with override=True — that stomps conftest env and - # monkeypatched values (e.g. NW_ALLOWED_CONNECTORS, NW_MCP_AUTH_ENABLED). + # monkeypatched values (e.g. NW_ALLOWED_CONNECTORS, NW_MCP_AUTH_DISABLED). rest_dotenv = os.environ.get("NW_REST_LOAD_DOTENV", "true").lower() if rest_dotenv in ("0", "false", "no"): # Keys may be injected later (tests); do not mark bootstrapped so we recheck. @@ -114,7 +123,19 @@ def _bootstrap_mcp_auth_env() -> None: def mcp_auth_disabled() -> bool: - return _truthy(os.environ.get("NW_MCP_AUTH_ENABLED")) + disabled = os.environ.get("NW_MCP_AUTH_DISABLED") + if disabled is not None: + return _truthy(disabled) + + legacy_enabled = os.environ.get("NW_MCP_AUTH_ENABLED") + if legacy_enabled is not None: + logger.warning( + "NW_MCP_AUTH_ENABLED is deprecated; use NW_MCP_AUTH_DISABLED instead " + "(true disables auth). NW_MCP_AUTH_ENABLED will be removed in a future release." + ) + return _truthy(legacy_enabled) + + return False def mcp_auth_configured() -> bool: @@ -122,6 +143,19 @@ def mcp_auth_configured() -> bool: return bool(os.environ.get("NW_MCP_API_KEY") or os.environ.get("NW_MCP_JWT_SECRET")) +def log_mcp_auth_startup_state() -> None: + """Log effective MCP auth posture once at server startup.""" + _bootstrap_mcp_auth_env() + disabled = mcp_auth_disabled() + configured = mcp_auth_configured() + state = "disabled" if disabled else "enabled" + logger.info("MCP authentication %s (configured=%s)", state, configured) + if disabled: + logger.warning( + "NW_MCP_AUTH_DISABLED is set — MCP auth is OFF; do not use in production" + ) + + def _get_meta_value(meta: Mapping[str, Any] | None, keys: tuple[str, ...]) -> str | None: if not meta: return None diff --git a/src/bindings/mcp_server/server.py b/src/bindings/mcp_server/server.py index fd2804a..2ecc4ed 100644 --- a/src/bindings/mcp_server/server.py +++ b/src/bindings/mcp_server/server.py @@ -13,7 +13,11 @@ from typing import Any, Dict, List, Mapping, Optional, Tuple from bindings.factory import ConnectorFactory -from bindings.mcp_server.auth import McpAuthError, authenticate_mcp_request +from bindings.mcp_server.auth import ( + McpAuthError, + authenticate_mcp_request, + log_mcp_auth_startup_state, +) from node_wire_runtime.caller_identity import CallerIdentity from node_wire_runtime.policies.mcp_scope_policy import ( action_allowed_for_identity_scopes, @@ -130,6 +134,7 @@ def __init__( MCP_MANIFEST_CONTRACT_VERSION, _pkg_ver, ) + log_mcp_auth_startup_state() def list_tools(self, *, identity: CallerIdentity | None = None) -> List[Dict[str, Any]]: identity = self._ensure_identity(identity=identity) diff --git a/tests/conftest.py b/tests/conftest.py index b127047..e61b03d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ """Shared pytest configuration. REST API tests default to ``NW_REST_AUTH_DISABLED=true`` so existing tests do not need -headers. MCP tests default to ``NW_MCP_AUTH_ENABLED=true`` for the same reason. +headers. MCP tests default to ``NW_MCP_AUTH_DISABLED=true`` for the same reason. Tests that assert authentication behavior override these env vars. """ @@ -63,7 +63,7 @@ def _preload_connector_logic_modules() -> None: @pytest.fixture(autouse=True) def _rest_auth_disabled_for_tests(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("NW_REST_AUTH_DISABLED", "true") - monkeypatch.setenv("NW_MCP_AUTH_ENABLED", "true") + monkeypatch.setenv("NW_MCP_AUTH_DISABLED", "true") monkeypatch.setenv("NW_RATE_LIMIT_BURST", "1000") # Increase for tests monkeypatch.setenv("NW_RATE_LIMIT_REFILL_RATE", "100.0") # Increase for tests monkeypatch.setenv("NW_RATE_LIMIT_DISABLED", "true") # Disable rate limiting for tests diff --git a/tests/test_mcp_auth.py b/tests/test_mcp_auth.py index 4689678..65a3f80 100644 --- a/tests/test_mcp_auth.py +++ b/tests/test_mcp_auth.py @@ -28,7 +28,7 @@ def _mcp_auth_clear_allowlist_from_host_env(monkeypatch: pytest.MonkeyPatch) -> def test_mcp_auth_missing_token_returns_401(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) @@ -39,7 +39,7 @@ def test_mcp_auth_missing_token_returns_401(monkeypatch: pytest.MonkeyPatch) -> def test_mcp_auth_invalid_token_returns_403(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) @@ -50,7 +50,7 @@ def test_mcp_auth_invalid_token_returns_403(monkeypatch: pytest.MonkeyPatch) -> def test_mcp_auth_valid_token_allows_tools_list(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) @@ -64,7 +64,7 @@ def test_mcp_auth_valid_token_allows_tools_list(monkeypatch: pytest.MonkeyPatch) @pytest.mark.asyncio async def test_mcp_authz_denies_tool_without_scope(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) monkeypatch.delenv("NW_MCP_API_KEY", raising=False) monkeypatch.setenv("NW_MCP_JWT_SECRET", "jwt-secret") monkeypatch.setenv( @@ -101,7 +101,7 @@ async def test_mcp_authz_denies_tool_without_scope(monkeypatch: pytest.MonkeyPat async def test_mcp_execution_passes_principal_and_tenant( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) monkeypatch.delenv("NW_MCP_API_KEY", raising=False) monkeypatch.setenv("NW_MCP_JWT_SECRET", "jwt-secret") monkeypatch.delenv("NW_MCP_ACTION_SCOPE_MAP_JSON", raising=False) @@ -151,7 +151,7 @@ async def fake_run(raw_input, *, principal=None, tenant_id=None, scopes=None): def test_mcp_api_key_scopes_filter_tools_list(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") monkeypatch.setenv( "NW_MCP_ACTION_SCOPE_MAP_JSON", @@ -169,7 +169,7 @@ def test_mcp_api_key_scopes_filter_tools_list(monkeypatch: pytest.MonkeyPatch) - def test_mcp_jwt_scopes_filter_tools_list(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) monkeypatch.delenv("NW_MCP_API_KEY", raising=False) monkeypatch.setenv("NW_MCP_JWT_SECRET", "jwt-secret") monkeypatch.setenv( @@ -192,7 +192,7 @@ def test_mcp_jwt_scopes_filter_tools_list(monkeypatch: pytest.MonkeyPatch) -> No async def test_mcp_default_deny_fallback_scope_invokes_tool( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) monkeypatch.delenv("NW_MCP_API_KEY", raising=False) monkeypatch.setenv("NW_MCP_JWT_SECRET", "jwt-secret") monkeypatch.delenv("NW_MCP_ACTION_SCOPE_MAP_JSON", raising=False) @@ -241,7 +241,7 @@ async def fake_run(raw_input, *, principal=None, tenant_id=None, scopes=None): async def test_mcp_default_deny_denies_without_fallback_scope( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) monkeypatch.delenv("NW_MCP_API_KEY", raising=False) monkeypatch.setenv("NW_MCP_JWT_SECRET", "jwt-secret") monkeypatch.delenv("NW_MCP_ACTION_SCOPE_MAP_JSON", raising=False) @@ -273,7 +273,7 @@ async def test_mcp_default_deny_denies_without_fallback_scope( def test_mcp_api_key_explicit_star_scope_lists_tool(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") monkeypatch.setenv( "NW_MCP_ACTION_SCOPE_MAP_JSON", @@ -298,7 +298,7 @@ async def handle_request(self, scope, receive, send): def test_streamable_http_edge_auth_rejects_missing_token(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) @@ -315,7 +315,7 @@ def test_streamable_http_edge_auth_rejects_missing_token(monkeypatch: pytest.Mon def test_streamable_http_edge_auth_rejects_invalid_token(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) @@ -336,7 +336,7 @@ def test_streamable_http_edge_auth_rejects_invalid_token(monkeypatch: pytest.Mon def test_streamable_http_edge_auth_accepts_valid_token(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) @@ -360,7 +360,7 @@ def test_streamable_http_edge_auth_accepts_valid_token(monkeypatch: pytest.Monke async def test_streamable_http_identity_context_is_used_by_mcp_server( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) @@ -378,3 +378,46 @@ async def test_streamable_http_identity_context_is_used_by_mcp_server( assert resolved is not None assert resolved.principal == "api-key-user" + + +def test_mcp_auth_enforced_when_not_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + with pytest.raises(McpAuthRequiredError): + authenticate_mcp_request() + + +def test_mcp_auth_enforced_when_disabled_false(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_MCP_AUTH_DISABLED", "false") + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + with pytest.raises(McpAuthRequiredError): + authenticate_mcp_request() + + +def test_mcp_auth_skipped_when_disabled_true(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_MCP_AUTH_DISABLED", "true") + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + assert authenticate_mcp_request() is None + + +def test_mcp_auth_legacy_enabled_env_disables_with_warning( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_MCP_AUTH_ENABLED", "true") + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + with caplog.at_level("WARNING"): + assert authenticate_mcp_request() is None + + assert "NW_MCP_AUTH_ENABLED is deprecated" in caplog.text diff --git a/tests/test_mcp_transport.py b/tests/test_mcp_transport.py index 74784e2..3f9a93b 100644 --- a/tests/test_mcp_transport.py +++ b/tests/test_mcp_transport.py @@ -149,6 +149,7 @@ async def test_mcp_http_tools_list_success(): @pytest.mark.anyio async def test_mcp_http_tools_list_accepts_authorization_header(monkeypatch): monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) @@ -208,6 +209,7 @@ async def test_mcp_http_tools_list_accepts_authorization_header(monkeypatch): @pytest.mark.anyio async def test_mcp_http_tools_list_accepts_x_api_key_header(monkeypatch): monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) From 4cd5805e29de4325167d99a6fb952acfa101c2e0 Mon Sep 17 00:00:00 2001 From: My Name Date: Fri, 12 Jun 2026 09:18:14 +0530 Subject: [PATCH 60/60] linitng issue resolved --- src/bindings/mcp_server/auth.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/bindings/mcp_server/auth.py b/src/bindings/mcp_server/auth.py index e426432..9a69ec9 100644 --- a/src/bindings/mcp_server/auth.py +++ b/src/bindings/mcp_server/auth.py @@ -151,9 +151,7 @@ def log_mcp_auth_startup_state() -> None: state = "disabled" if disabled else "enabled" logger.info("MCP authentication %s (configured=%s)", state, configured) if disabled: - logger.warning( - "NW_MCP_AUTH_DISABLED is set — MCP auth is OFF; do not use in production" - ) + logger.warning("NW_MCP_AUTH_DISABLED is set — MCP auth is OFF; do not use in production") def _get_meta_value(meta: Mapping[str, Any] | None, keys: tuple[str, ...]) -> str | None: