From a7deb3904b73f21775309796a2b1f04a0a240e20 Mon Sep 17 00:00:00 2001 From: Windsor Date: Tue, 10 Feb 2026 15:07:35 -0800 Subject: [PATCH 1/3] fix(_compat): remove duplicate by_alias keyword arg Stainless regen (2026-02-07) added `by_alias=by_alias` but a manually added `by_alias=True` from Nov 2025 was still present. Duplicate keyword arguments are a SyntaxError on Python 3.9. --- src/dedalus_labs/_compat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dedalus_labs/_compat.py b/src/dedalus_labs/_compat.py index bf1f1b9..76d017b 100644 --- a/src/dedalus_labs/_compat.py +++ b/src/dedalus_labs/_compat.py @@ -144,7 +144,6 @@ def model_dump( if (not PYDANTIC_V1) or hasattr(model, "model_dump"): return model.model_dump( mode=mode, - by_alias=True, exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, From e3a55da85cc4deed8c5122193084a8a4675818c5 Mon Sep 17 00:00:00 2001 From: Windsor Date: Tue, 10 Feb 2026 15:07:42 -0800 Subject: [PATCH 2/3] feat(stream): return StreamResult from stream_async / stream_sync stream_async and stream_sync now accumulate content and tool-call deltas into a StreamResult dataclass and return it. The accumulation logic is extracted into a shared accumulate_tool_call() function that DedalusRunner.core also delegates to (DRY). Exports StreamResult and accumulate_tool_call from dedalus_labs.utils.stream. Adds tests covering content, tool-call reassembly, thought_signature, parallel calls, empty streams, and result independence. --- src/dedalus_labs/lib/runner/core.py | 25 +--- src/dedalus_labs/lib/utils/_stream.py | 190 ++++++++++++++++--------- src/dedalus_labs/utils/stream.py | 4 +- tests/test_stream_helpers.py | 194 ++++++++++++++++++++++++++ 4 files changed, 321 insertions(+), 92 deletions(-) create mode 100644 tests/test_stream_helpers.py diff --git a/src/dedalus_labs/lib/runner/core.py b/src/dedalus_labs/lib/runner/core.py index 0e44600..a295f09 100644 --- a/src/dedalus_labs/lib/runner/core.py +++ b/src/dedalus_labs/lib/runner/core.py @@ -30,6 +30,7 @@ from .types import Message, ToolCall, JsonValue, ToolResult, PolicyInput, PolicyContext from ..._client import Dedalus, AsyncDedalus from ...types.shared import MCPToolResult +from ..utils._stream import accumulate_tool_call # Type alias for mcp_servers parameter - accepts strings, server objects, or mixed lists MCPServersInput = Union[ @@ -1240,29 +1241,7 @@ def _execute_tool_calls_sync( def _accumulate_tool_calls(self, deltas, acc: list[ToolCall]) -> None: """Accumulate streaming tool call deltas.""" for delta in deltas: - index = getattr(delta, "index", 0) - - # Ensure we have enough entries in acc - while len(acc) <= index: - acc.append( - { - "id": "", - "type": "function", - "function": {"name": "", "arguments": ""}, - } - ) - - if hasattr(delta, "id") and delta.id: - acc[index]["id"] = delta.id - if hasattr(delta, "function"): - fn = delta.function - if hasattr(fn, "name") and fn.name: - acc[index]["function"]["name"] = fn.name - if hasattr(fn, "arguments") and fn.arguments: - acc[index]["function"]["arguments"] += fn.arguments - thought_sig = getattr(delta, "thought_signature", None) - if thought_sig: - acc[index]["thought_signature"] = thought_sig + accumulate_tool_call(acc, delta) @staticmethod def _mk_kwargs(mc: _ModelConfig) -> Dict[str, Any]: diff --git a/src/dedalus_labs/lib/utils/_stream.py b/src/dedalus_labs/lib/utils/_stream.py index ac4c305..302b04b 100644 --- a/src/dedalus_labs/lib/utils/_stream.py +++ b/src/dedalus_labs/lib/utils/_stream.py @@ -7,23 +7,117 @@ from __future__ import annotations import os - +from dataclasses import dataclass, field from collections.abc import AsyncIterator, Iterator -from typing import TYPE_CHECKING +from typing import Any, Dict, List, TYPE_CHECKING if TYPE_CHECKING: from ...types.chat.stream_chunk import StreamChunk __all__ = [ + "StreamResult", + "accumulate_tool_call", "stream_async", "stream_sync", ] -async def stream_async(stream: AsyncIterator[StreamChunk] | object) -> None: +@dataclass +class StreamResult: + """Collected data from a consumed stream. + + Returned by :func:`stream_async` and :func:`stream_sync` so callers + can inspect what happened after the stream is exhausted. + """ + + content: str = "" + """Concatenated text content from all ``delta.content`` fragments.""" + + tool_calls: List[Dict[str, Any]] = field(default_factory=list) + """Reassembled tool calls accumulated from streaming deltas. + + Each entry has the shape ``{"id": str, "function": {"name": str, "arguments": str}}``. + """ + + +def accumulate_tool_call( + tool_calls: List[Dict[str, Any]], + tc_delta: object, +) -> None: + """Accumulate a single streaming tool-call delta into *tool_calls*. + + Reassembles the incremental fragments (id, function name, argument + chunks, thought_signature) that arrive across multiple SSE chunks + into complete tool-call dicts keyed by ``tc_delta.index``. + + This is the canonical implementation shared by both the stream + helpers and :class:`~dedalus_labs.lib.runner.core.DedalusRunner`. + """ + idx: int = getattr(tc_delta, "index", 0) + while len(tool_calls) <= idx: + tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}}) + entry = tool_calls[idx] + + tc_id = getattr(tc_delta, "id", None) + if tc_id: + entry["id"] = tc_id + + fn = getattr(tc_delta, "function", None) + if fn is not None: + if getattr(fn, "name", None): + entry["function"]["name"] = fn.name + if getattr(fn, "arguments", None): + entry["function"]["arguments"] += fn.arguments + + thought_sig = getattr(tc_delta, "thought_signature", None) + if thought_sig: + entry["thought_signature"] = thought_sig + + +def _process_chunk( + chunk: object, + result: StreamResult, + verbose: bool, +) -> None: + """Extract content and tool-call deltas from a single StreamChunk.""" + if verbose: + extra = getattr(chunk, "__pydantic_extra__", None) + if isinstance(extra, dict): + meta = extra.get("dedalus_event") + if isinstance(meta, dict): + print(f"\n[EVENT] {meta}") + + choices = getattr(chunk, "choices", None) + if not choices: + return + + choice = choices[0] + delta = choice.delta + + # Tool-call deltas + for tc in getattr(delta, "tool_calls", None) or []: + if verbose: + fn = getattr(tc, "function", None) + print(f"\n[TOOL_CALL] name={getattr(fn, 'name', None)} id={getattr(tc, 'id', None)}") + accumulate_tool_call(result.tool_calls, tc) + + # Content + if delta.content: + print(delta.content, end="", flush=True) + result.content += delta.content + + # Finish reason (verbose-only) + if verbose and getattr(choice, "finish_reason", None): + print(f"\n[FINISH] reason={choice.finish_reason}") + + +async def stream_async(stream: AsyncIterator[StreamChunk] | object) -> StreamResult: """Stream text content from an async streaming response. + Prints content tokens to stdout as they arrive **and** returns a + :class:`StreamResult` with the accumulated content and tool calls. + Supports both: - Raw StreamChunk iterator from .create(stream=True) or DedalusRunner.run(stream=True) - ChatCompletionStreamManager from .stream() (Pydantic models with event API) @@ -31,16 +125,16 @@ async def stream_async(stream: AsyncIterator[StreamChunk] | object) -> None: Args: stream: An async iterator of StreamChunk or a ChatCompletionStreamManager - Example: - >>> # With .create(stream=True) - >>> stream = await client.chat.completions.create(stream=True, ...) - >>> await stream_async(stream) + Returns: + A :class:`StreamResult` containing the full content and any tool calls. - >>> # With .stream() (Pydantic models) - >>> stream = client.chat.completions.stream(response_format=Model, ...) - >>> await stream_async(stream) + Example: + >>> stream = runner.run(input="...", model="...", stream=True) + >>> result = await stream_async(stream) + >>> print(result.tool_calls) """ verbose = os.environ.get("DEDALUS_SDK_VERBOSE", "").lower() in ("1", "true", "yes", "on", "debug") + result = StreamResult() # Stream manager (event API) vs raw AsyncStream: discriminate via __aenter__ without __aiter__ if hasattr(stream, "__aenter__") and not hasattr(stream, "__aiter__"): @@ -48,43 +142,25 @@ async def stream_async(stream: AsyncIterator[StreamChunk] | object) -> None: async for event in event_stream: if event.type == "content.delta": print(event.delta, end="", flush=True) + result.content += event.delta elif verbose and event.type == "content.done" and hasattr(event, "parsed") and event.parsed: print(f"\n[PARSED] {type(event.parsed).__name__}") print() # Final newline - return + return result # Simple StreamChunk iterator case async for chunk in stream: - # Print server-side metadata events if present (verbose-only) - if verbose: - extra = getattr(chunk, "__pydantic_extra__", None) - if isinstance(extra, dict): - meta = extra.get("dedalus_event") - if isinstance(meta, dict): - print(f"\n[EVENT] {meta}") - - if chunk.choices: - choice = chunk.choices[0] - delta = choice.delta - # Print tool-call deltas as debug (verbose-only) - if verbose and getattr(delta, "tool_calls", None): - for tc in delta.tool_calls: - fn = getattr(tc, "function", None) - name = getattr(fn, "name", None) - tcid = getattr(tc, "id", None) - print(f"\n[TOOL_CALL] name={name} id={tcid}") - # Always print content - if delta.content: - print(delta.content, end="", flush=True) - # Print finish reason (verbose-only) - if verbose and getattr(choice, "finish_reason", None): - print(f"\n[FINISH] reason={choice.finish_reason}") + _process_chunk(chunk, result, verbose) print() # Final newline + return result -def stream_sync(stream: Iterator[StreamChunk] | object) -> None: +def stream_sync(stream: Iterator[StreamChunk] | object) -> StreamResult: """Stream text content from a streaming response. + Prints content tokens to stdout as they arrive **and** returns a + :class:`StreamResult` with the accumulated content and tool calls. + Supports both: - Raw StreamChunk iterator from .create(stream=True) or DedalusRunner.run(stream=True) - ChatCompletionStreamManager from .stream() (Pydantic models with event API) @@ -92,16 +168,16 @@ def stream_sync(stream: Iterator[StreamChunk] | object) -> None: Args: stream: An iterator of StreamChunk or a ChatCompletionStreamManager - Example: - >>> # With .create(stream=True) - >>> stream = client.chat.completions.create(stream=True, ...) - >>> stream_sync(stream) + Returns: + A :class:`StreamResult` containing the full content and any tool calls. - >>> # With .stream() (Pydantic models) - >>> stream = client.chat.completions.stream(response_format=Model, ...) - >>> stream_sync(stream) + Example: + >>> stream = runner.run(input="...", model="...", stream=True) + >>> result = stream_sync(stream) + >>> print(result.tool_calls) """ verbose = os.environ.get("DEDALUS_SDK_VERBOSE", "").lower() in ("1", "true", "yes", "on", "debug") + result = StreamResult() # Stream manager (event API) vs raw Stream: discriminate via __enter__ without __iter__ if hasattr(stream, "__enter__") and not hasattr(stream, "__iter__"): @@ -109,34 +185,14 @@ def stream_sync(stream: Iterator[StreamChunk] | object) -> None: for event in event_stream: if event.type == "content.delta": print(event.delta, end="", flush=True) + result.content += event.delta elif verbose and event.type == "content.done" and hasattr(event, "parsed") and event.parsed: print(f"\n[PARSED] {type(event.parsed).__name__}") print() # Final newline - return + return result # Simple StreamChunk iterator case for chunk in stream: - # Print server-side metadata events if present (verbose-only) - if verbose: - extra = getattr(chunk, "__pydantic_extra__", None) - if isinstance(extra, dict): - meta = extra.get("dedalus_event") - if isinstance(meta, dict): - print(f"\n[EVENT] {meta}") - - if chunk.choices: - choice = chunk.choices[0] - delta = choice.delta - # Print tool-call deltas as debug (verbose-only) - if verbose and getattr(delta, "tool_calls", None): - for tc in delta.tool_calls: - fn = getattr(tc, "function", None) - name = getattr(fn, "name", None) - tcid = getattr(tc, "id", None) - print(f"\n[TOOL_CALL] name={name} id={tcid}") - # Always print content - if delta.content: - print(delta.content, end="", flush=True) - if verbose and getattr(choice, "finish_reason", None): - print(f"\n[FINISH] reason={choice.finish_reason}") + _process_chunk(chunk, result, verbose) print() # Final newline + return result diff --git a/src/dedalus_labs/utils/stream.py b/src/dedalus_labs/utils/stream.py index a82ecd0..ec9f9b8 100644 --- a/src/dedalus_labs/utils/stream.py +++ b/src/dedalus_labs/utils/stream.py @@ -6,6 +6,6 @@ """Stream utilities for printing model responses in real-time.""" -from ..lib.utils._stream import stream_async, stream_sync +from ..lib.utils._stream import StreamResult, stream_async, stream_sync -__all__ = ["stream_async", "stream_sync"] +__all__ = ["StreamResult", "stream_async", "stream_sync"] diff --git a/tests/test_stream_helpers.py b/tests/test_stream_helpers.py new file mode 100644 index 0000000..2c2d89a --- /dev/null +++ b/tests/test_stream_helpers.py @@ -0,0 +1,194 @@ +# ============================================================================== +# © 2025 Dedalus Labs, Inc. and affiliates +# Licensed under MIT +# github.com/dedalus-labs/dedalus-sdk-python/LICENSE +# ============================================================================== + +"""Tests for stream helpers: accumulate_tool_call, stream_async, stream_sync.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, AsyncIterator, List + +import pytest + +from dedalus_labs.lib.utils._stream import StreamResult, accumulate_tool_call, stream_async, stream_sync + + +# --- Test helpers --- + + +def _chunk( + *, + content: str | None = None, + tool_calls: List[Any] | None = None, + finish_reason: str | None = None, +) -> SimpleNamespace: + delta = SimpleNamespace(content=content, tool_calls=tool_calls) + choice = SimpleNamespace(delta=delta, finish_reason=finish_reason) + return SimpleNamespace(choices=[choice]) + + +def _tc( + index: int, + *, + tc_id: str | None = None, + name: str | None = None, + arguments: str | None = None, + thought_signature: str | None = None, +) -> SimpleNamespace: + fn = SimpleNamespace(name=name, arguments=arguments) + ns = SimpleNamespace(index=index, id=tc_id, function=fn) + if thought_signature is not None: + ns.thought_signature = thought_signature + return ns + + +async def _aiter(items: list) -> AsyncIterator: + for item in items: + yield item + + +# --- accumulate_tool_call --- + + +def test_accumulate_creates_entry(): + acc: list[dict] = [] + accumulate_tool_call(acc, _tc(0, tc_id="call_1", name="search")) + assert acc == [{"id": "call_1", "type": "function", "function": {"name": "search", "arguments": ""}}] + + +def test_accumulate_appends_arguments(): + acc: list[dict] = [] + accumulate_tool_call(acc, _tc(0, tc_id="c", name="fn")) + accumulate_tool_call(acc, _tc(0, arguments='{"a":')) + accumulate_tool_call(acc, _tc(0, arguments=" 1}")) + assert acc[0]["function"]["arguments"] == '{"a": 1}' + + +def test_accumulate_parallel_indices(): + acc: list[dict] = [] + accumulate_tool_call(acc, _tc(0, tc_id="c0", name="alpha")) + accumulate_tool_call(acc, _tc(1, tc_id="c1", name="beta")) + assert len(acc) == 2 + assert acc[0]["function"]["name"] == "alpha" + assert acc[1]["function"]["name"] == "beta" + + +def test_accumulate_sparse_index_pads(): + acc: list[dict] = [] + accumulate_tool_call(acc, _tc(2, tc_id="c2", name="gamma")) + assert len(acc) == 3 + assert acc[2]["function"]["name"] == "gamma" + assert acc[0]["function"]["name"] == "" + + +def test_accumulate_thought_signature(): + acc: list[dict] = [] + accumulate_tool_call(acc, _tc(0, tc_id="c", name="fn", thought_signature="sig_abc")) + assert acc[0]["thought_signature"] == "sig_abc" + + +def test_accumulate_no_thought_signature(): + acc: list[dict] = [] + accumulate_tool_call(acc, _tc(0, tc_id="c", name="fn")) + assert "thought_signature" not in acc[0] + + +# --- stream_async --- + + +@pytest.mark.asyncio +async def test_async_content_only(capsys: pytest.CaptureFixture[str]): + result = await stream_async(_aiter([_chunk(content="Hello"), _chunk(content=" world")])) + assert isinstance(result, StreamResult) + assert result.content == "Hello world" + assert result.tool_calls == [] + assert "Hello world" in capsys.readouterr().out + + +@pytest.mark.asyncio +async def test_async_empty(): + result = await stream_async(_aiter([])) + assert result.content == "" + assert result.tool_calls == [] + + +@pytest.mark.asyncio +async def test_async_reassembles_tool_call(): + chunks = [ + _chunk(tool_calls=[_tc(0, tc_id="call_abc", name="gmail_send")]), + _chunk(tool_calls=[_tc(0, arguments='{"to":')]), + _chunk(tool_calls=[_tc(0, arguments=' "a@b.com"}')]), + _chunk(content="Sent.", finish_reason="stop"), + ] + result = await stream_async(_aiter(chunks)) + assert result.content == "Sent." + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["id"] == "call_abc" + assert result.tool_calls[0]["type"] == "function" + assert result.tool_calls[0]["function"] == {"name": "gmail_send", "arguments": '{"to": "a@b.com"}'} + + +@pytest.mark.asyncio +async def test_async_parallel_tool_calls(): + chunks = [ + _chunk(tool_calls=[_tc(0, tc_id="c1", name="search"), _tc(1, tc_id="c2", name="fetch")]), + _chunk(tool_calls=[_tc(0, arguments='{"q":"x"}')]), + _chunk(tool_calls=[_tc(1, arguments='{"url":"y"}')]), + ] + result = await stream_async(_aiter(chunks)) + assert len(result.tool_calls) == 2 + assert result.tool_calls[0]["function"]["name"] == "search" + assert result.tool_calls[1]["function"]["name"] == "fetch" + + +@pytest.mark.asyncio +async def test_async_skips_empty_choices(): + chunks = [SimpleNamespace(choices=[]), _chunk(content="ok"), SimpleNamespace(choices=None)] + result = await stream_async(_aiter(chunks)) + assert result.content == "ok" + + +@pytest.mark.asyncio +async def test_async_independent_results(): + """Each call returns a fresh StreamResult.""" + a = await stream_async(_aiter([_chunk(tool_calls=[_tc(0, tc_id="a", name="fa")])])) + b = await stream_async(_aiter([_chunk(tool_calls=[_tc(0, tc_id="b", name="fb")])])) + assert a.tool_calls[0]["id"] == "a" + assert b.tool_calls[0]["id"] == "b" + assert a.tool_calls is not b.tool_calls + + +# --- stream_sync --- + + +def test_sync_content_only(capsys: pytest.CaptureFixture[str]): + result = stream_sync(iter([_chunk(content="Hello"), _chunk(content=" world")])) + assert result.content == "Hello world" + assert result.tool_calls == [] + assert "Hello world" in capsys.readouterr().out + + +def test_sync_empty(): + result = stream_sync(iter([])) + assert result.content == "" + assert result.tool_calls == [] + + +def test_sync_reassembles_tool_call(): + chunks = [ + _chunk(tool_calls=[_tc(0, tc_id="call_abc", name="gmail_send")]), + _chunk(tool_calls=[_tc(0, arguments='{"to": "a@b.com"}')]), + _chunk(content="Sent.", finish_reason="stop"), + ] + result = stream_sync(iter(chunks)) + assert result.tool_calls[0]["function"] == {"name": "gmail_send", "arguments": '{"to": "a@b.com"}'} + assert result.content == "Sent." + + +def test_sync_thought_signature(): + chunks = [_chunk(tool_calls=[_tc(0, tc_id="c", name="fn", thought_signature="sig")])] + result = stream_sync(iter(chunks)) + assert result.tool_calls[0]["thought_signature"] == "sig" From 57dc1818c864459b2e1da71487eb3adf94dbf668 Mon Sep 17 00:00:00 2001 From: Windsor Date: Tue, 10 Feb 2026 15:07:47 -0800 Subject: [PATCH 3/3] style(tests): flatten test classes to module-level functions Standardise hand-written test files (test_encryption, test_mcp_wire, test_mcp_wire_connections, test__bug_report) to match the module-level function style used in test_local_scheduler. Replaces class wrappers and multi-line banners with single-line section comments. --- tests/test__bug_report.py | 230 ++++++++------------ tests/test_encryption.py | 267 ++++++++++------------- tests/test_mcp_wire.py | 327 +++++++++++++---------------- tests/test_mcp_wire_connections.py | 314 ++++++++++----------------- 4 files changed, 453 insertions(+), 685 deletions(-) diff --git a/tests/test__bug_report.py b/tests/test__bug_report.py index a3b8e3e..6e570dd 100644 --- a/tests/test__bug_report.py +++ b/tests/test__bug_report.py @@ -11,186 +11,128 @@ from urllib.parse import parse_qs, urlparse import httpx -import pytest -from dedalus_labs._exceptions import APIError, APIStatusError, BadRequestError +from dedalus_labs._exceptions import APIError, BadRequestError from dedalus_labs.lib._bug_report import generate_bug_report_url, get_bug_report_url_from_error -class TestGenerateBugReportUrl: - """Tests for generate_bug_report_url function.""" +# --- generate_bug_report_url --- - def test_minimal_parameters(self): - """URL generation with no params includes auto-populated system info.""" - url = generate_bug_report_url() - parsed = urlparse(url) - params = parse_qs(parsed.query) +def test_minimal_parameters(): + url = generate_bug_report_url() + parsed = urlparse(url) + params = parse_qs(parsed.query) - assert parsed.netloc == "github.com" - assert parsed.path == "/dedalus-labs/dedalus-sdk-python/issues/new" - assert params["template"] == ["bug-report.yml"] - assert params["component"] == ["Python SDK"] - assert "python_version" in params - assert "platform" in params + assert parsed.netloc == "github.com" + assert parsed.path == "/dedalus-labs/dedalus-sdk-python/issues/new" + assert params["template"] == ["bug-report.yml"] + assert params["component"] == ["Python SDK"] + assert "python_version" in params + assert "platform" in params - def test_all_parameters(self): - """URL generation with all parameters populates fields correctly.""" - url = generate_bug_report_url( - version="0.0.1", - error_type="APIError", - error_message="Connection timeout", - environment="dev", - request_id="req-123", - endpoint="/v1/chat/completions", - method="POST", - ) - parsed = urlparse(url) - params = parse_qs(parsed.query) +def test_all_parameters(): + url = generate_bug_report_url( + version="0.0.1", + error_type="APIError", + error_message="Connection timeout", + environment="dev", + request_id="req-123", + endpoint="/v1/chat/completions", + method="POST", + ) + params = parse_qs(urlparse(url).query) - assert params["version"] == ["0.0.1"] - assert params["error_type"] == ["APIError"] - assert params["actual"] == ["Connection timeout"] - assert params["environment"] == ["dev"] - assert params["notes"][0] == "Request ID: req-123\nEndpoint: POST /v1/chat/completions" + assert params["version"] == ["0.0.1"] + assert params["error_type"] == ["APIError"] + assert params["actual"] == ["Connection timeout"] + assert params["environment"] == ["dev"] + assert params["notes"][0] == "Request ID: req-123\nEndpoint: POST /v1/chat/completions" - def test_request_id_in_notes(self): - """Request ID is included in notes field.""" - url = generate_bug_report_url(request_id="req-abc-123") - parsed = urlparse(url) - params = parse_qs(parsed.query) +def test_request_id_in_notes(): + params = parse_qs(urlparse(generate_bug_report_url(request_id="req-abc-123")).query) + assert "Request ID: req-abc-123" in params["notes"][0] - assert "notes" in params - assert "Request ID: req-abc-123" in params["notes"][0] - def test_custom_template(self): - """Custom template name is respected.""" - url = generate_bug_report_url(template="custom.yml") +def test_custom_template(): + params = parse_qs(urlparse(generate_bug_report_url(template="custom.yml")).query) + assert params["template"] == ["custom.yml"] - parsed = urlparse(url) - params = parse_qs(parsed.query) - assert params["template"] == ["custom.yml"] +# --- get_bug_report_url_from_error --- -class TestGetBugReportUrlFromError: - """Tests for get_bug_report_url_from_error function.""" +def test_basic_api_error(): + request = httpx.Request("POST", "https://api.dedalus.ai/v1/chat/completions") + error = APIError("Request failed", request, body=None) + params = parse_qs(urlparse(get_bug_report_url_from_error(error)).query) - def test_basic_api_error(self): - """Generates URL from basic APIError instance.""" - request = httpx.Request("POST", "https://api.dedalus.ai/v1/chat/completions") - error = APIError("Request failed", request, body=None) + assert params["error_type"] == ["APIError"] + assert params["actual"] == ["Request failed"] + assert "version" in params - url = get_bug_report_url_from_error(error) - parsed = urlparse(url) - params = parse_qs(parsed.query) +def test_status_error_with_code(): + request = httpx.Request("POST", "https://api.dedalus.ai/v1/chat/completions") + response = httpx.Response(400, request=request) + error = BadRequestError("Invalid request", response=response, body=None) + params = parse_qs(urlparse(get_bug_report_url_from_error(error)).query) - assert params["error_type"] == ["APIError"] - assert params["actual"] == ["Request failed"] - assert "version" in params + assert params["error_type"] == ["BadRequestError"] + assert "[400]" in params["actual"][0] + assert "Invalid request" in params["actual"][0] - def test_api_status_error_with_code(self): - """Status code is included in error message for APIStatusError.""" - request = httpx.Request("POST", "https://api.dedalus.ai/v1/chat/completions") - response = httpx.Response(400, request=request) - error = BadRequestError("Invalid request", response=response, body=None) - url = get_bug_report_url_from_error(error) +def test_error_with_request_id(): + request = httpx.Request("POST", "https://api.dedalus.ai/v1/chat/completions") + error = APIError("Test error", request, body=None) + params = parse_qs(urlparse(get_bug_report_url_from_error(error, request_id="req-456")).query) + assert "Request ID: req-456" in params["notes"][0] - parsed = urlparse(url) - params = parse_qs(parsed.query) - assert params["error_type"] == ["BadRequestError"] - assert "[400]" in params["actual"][0] - assert "Invalid request" in params["actual"][0] +def test_includes_sdk_version(): + request = httpx.Request("POST", "https://api.dedalus.ai/v1/chat/completions") + error = APIError("Test error", request, body=None) + params = parse_qs(urlparse(get_bug_report_url_from_error(error)).query) + assert len(params["version"][0]) > 0 - def test_with_request_id(self): - """Request ID parameter is included when provided.""" - request = httpx.Request("POST", "https://api.dedalus.ai/v1/chat/completions") - error = APIError("Test error", request, body=None) - url = get_bug_report_url_from_error(error, request_id="req-456") +# --- Platform info --- - parsed = urlparse(url) - params = parse_qs(parsed.query) - assert "notes" in params - assert "Request ID: req-456" in params["notes"][0] +def test_platform_info_format(): + params = parse_qs(urlparse(generate_bug_report_url()).query) + parts = params["platform"][0].split() + assert len(parts) >= 2 - def test_includes_sdk_version(self): - """SDK version is automatically included from __version__.""" - request = httpx.Request("POST", "https://api.dedalus.ai/v1/chat/completions") - error = APIError("Test error", request, body=None) - url = get_bug_report_url_from_error(error) +def test_python_version_format(): + params = parse_qs(urlparse(generate_bug_report_url()).query) + python_version = params["python_version"][0] + assert python_version.startswith("Python ") + assert python_version.replace("Python ", "")[0].isdigit() - parsed = urlparse(url) - params = parse_qs(parsed.query) - # Should have version parameter populated - assert "version" in params - # Should be non-empty - assert len(params["version"][0]) > 0 +# --- URL encoding --- -class TestPlatformInfo: - """Tests for platform info collection.""" +def test_special_chars_encoded(): + url = generate_bug_report_url( + error_message="Error @ 127.0.0.1:8080 #fail", + request_id="req/test#123", + ) + query_string = url.split("?")[1] + assert "@" not in query_string + assert "#" not in query_string - def test_platform_info_format(self): - """Platform info has expected format.""" - url = generate_bug_report_url() + params = parse_qs(urlparse(url).query) + assert "@" in params["actual"][0] + assert "#" in params["notes"][0] - parsed = urlparse(url) - params = parse_qs(parsed.query) - platform_info = params["platform"][0] - # Format: "System Release Machine" - parts = platform_info.split() - assert len(parts) >= 2 - - def test_python_version_format(self): - """Python version has expected format.""" - url = generate_bug_report_url() - - parsed = urlparse(url) - params = parse_qs(parsed.query) - python_version = params["python_version"][0] - - assert python_version.startswith("Python ") - version_part = python_version.replace("Python ", "") - assert len(version_part) > 0 - assert version_part[0].isdigit() - - -class TestUrlEncoding: - """Tests for URL encoding edge cases.""" - - def test_special_chars_encoded(self): - """Special characters are properly URL-encoded.""" - url = generate_bug_report_url( - error_message="Error @ 127.0.0.1:8080 #fail", - request_id="req/test#123", - ) - - # URL query string should not contain raw special chars - query_string = url.split("?")[1] - assert "@" not in query_string - assert "#" not in query_string - - # But decoded params should contain them - parsed = urlparse(url) - params = parse_qs(parsed.query) - assert "@" in params["actual"][0] - assert "#" in params["notes"][0] - - def test_unicode_handling(self): - """Unicode characters are properly encoded.""" - url = generate_bug_report_url(error_message="Error: 数据库连接失败") - - # Should not raise and should produce valid URL - parsed = urlparse(url) - params = parse_qs(parsed.query) - assert "数据库连接失败" in params["actual"][0] +def test_unicode_handling(): + url = generate_bug_report_url(error_message="Error: 数据库连接失败") + params = parse_qs(urlparse(url).query) + assert "数据库连接失败" in params["actual"][0] diff --git a/tests/test_encryption.py b/tests/test_encryption.py index c1e09ea..ac5af38 100644 --- a/tests/test_encryption.py +++ b/tests/test_encryption.py @@ -28,56 +28,66 @@ ) -# Envelope v1 constants (must match encryption.py) +# --- Constants (must match encryption.py) --- + _ENVELOPE_VERSION = 0x01 _NONCE_LEN = 12 _TAG_LEN = 16 +# --- Test helpers --- + + def _b64url_encode(data: bytes) -> str: - """Base64url encode without padding (test helper).""" return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") def _b64url_decode(s: str) -> bytes: - """Base64url decode with padding restoration (test helper).""" pad = 4 - len(s) % 4 if pad != 4: s += "=" * pad return base64.urlsafe_b64decode(s) +def _decrypt_envelope_v1(private_key: Any, envelope: bytes) -> bytes: + """Decrypt envelope v1 format.""" + key_size = private_key.key_size // 8 + assert envelope[0] == _ENVELOPE_VERSION + wrapped_key = envelope[1 : 1 + key_size] + nonce = envelope[1 + key_size : 1 + key_size + _NONCE_LEN] + ciphertext_with_tag = envelope[1 + key_size + _NONCE_LEN :] + aes_key = private_key.decrypt( + wrapped_key, + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None, + ), + ) + return AESGCM(aes_key).decrypt(nonce, ciphertext_with_tag, None) + + +# --- Fixtures --- + + @pytest.fixture def rsa_keypair() -> tuple[Any, Any]: - """Generate RSA keypair for testing.""" - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - backend=default_backend(), - ) + private_key = rsa.generate_private_key(65537, 2048, default_backend()) return private_key, private_key.public_key() @pytest.fixture def rsa_keypair_3072() -> tuple[Any, Any]: - """Generate 3072-bit RSA keypair (production size).""" - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=3072, - backend=default_backend(), - ) + private_key = rsa.generate_private_key(65537, 3072, default_backend()) return private_key, private_key.public_key() @pytest.fixture def rsa_jwk(rsa_keypair: tuple[Any, Any]) -> dict[str, Any]: - """Create JWK from keypair.""" _, public_key = rsa_keypair numbers = public_key.public_numbers() - n_bytes = numbers.n.to_bytes((numbers.n.bit_length() + 7) // 8, "big") e_bytes = numbers.e.to_bytes((numbers.e.bit_length() + 7) // 8, "big") - return { "kty": "RSA", "use": "enc", @@ -87,170 +97,113 @@ def rsa_jwk(rsa_keypair: tuple[Any, Any]) -> dict[str, Any]: } -def decrypt_envelope_v1(private_key: Any, envelope: bytes) -> bytes: - """Decrypt envelope v1 format (test helper).""" - key_size = private_key.key_size // 8 - - version = envelope[0] - assert version == _ENVELOPE_VERSION, f"Expected version 0x01, got 0x{version:02x}" - - wrapped_key = envelope[1 : 1 + key_size] - nonce = envelope[1 + key_size : 1 + key_size + _NONCE_LEN] - ciphertext_with_tag = envelope[1 + key_size + _NONCE_LEN :] - - aes_key = private_key.decrypt( - wrapped_key, - padding.OAEP( - mgf=padding.MGF1(algorithm=hashes.SHA256()), - algorithm=hashes.SHA256(), - label=None, - ), - ) - - return AESGCM(aes_key).decrypt(nonce, ciphertext_with_tag, None) - - -class TestJwkToPublicKey: - """Test JWK to public key conversion.""" - - def test_valid_jwk(self, rsa_jwk: dict[str, Any], rsa_keypair: tuple[Any, Any]) -> None: - """Convert valid JWK to public key.""" - _, expected_public = rsa_keypair - public_key = jwk_to_public_key(rsa_jwk) - - assert public_key.public_numbers().n == expected_public.public_numbers().n - assert public_key.public_numbers().e == expected_public.public_numbers().e - - def test_wrong_kty_raises(self) -> None: - """Raise on non-RSA key type.""" - with pytest.raises(ValueError, match="expected RSA key type"): - jwk_to_public_key({"kty": "EC", "n": "xxx", "e": "xxx"}) - - def test_missing_n_raises(self, rsa_jwk: dict[str, Any]) -> None: - """Raise on missing n parameter.""" - del rsa_jwk["n"] - with pytest.raises(ValueError, match="missing required JWK field"): - jwk_to_public_key(rsa_jwk) +# --- jwk_to_public_key --- - def test_small_key_rejected(self) -> None: - """Reject keys smaller than minimum size.""" - small_key = rsa.generate_private_key(65537, 1024, default_backend()) - numbers = small_key.public_key().public_numbers() - n_bytes = numbers.n.to_bytes((numbers.n.bit_length() + 7) // 8, "big") - e_bytes = numbers.e.to_bytes((numbers.e.bit_length() + 7) // 8, "big") - jwk = {"kty": "RSA", "n": _b64url_encode(n_bytes), "e": _b64url_encode(e_bytes)} +def test_jwk_valid(rsa_jwk: dict[str, Any], rsa_keypair: tuple[Any, Any]): + _, expected_public = rsa_keypair + public_key = jwk_to_public_key(rsa_jwk) + assert public_key.public_numbers().n == expected_public.public_numbers().n + assert public_key.public_numbers().e == expected_public.public_numbers().e - with pytest.raises(ValueError, match="below minimum"): - jwk_to_public_key(jwk, min_key_size=2048) +def test_jwk_wrong_kty_raises(): + with pytest.raises(ValueError, match="expected RSA key type"): + jwk_to_public_key({"kty": "EC", "n": "xxx", "e": "xxx"}) -class TestEncryptCredentials: - """Test credential encryption (envelope v1).""" - def test_envelope_format(self, rsa_keypair: tuple[Any, Any]) -> None: - """Encrypt produces valid envelope v1 format.""" - private_key, public_key = rsa_keypair - credentials = {"token": "ghp_xxx123"} +def test_jwk_missing_n_raises(rsa_jwk: dict[str, Any]): + del rsa_jwk["n"] + with pytest.raises(ValueError, match="missing required JWK field"): + jwk_to_public_key(rsa_jwk) - ciphertext_b64 = encrypt_credentials(public_key, credentials) - envelope = _b64url_decode(ciphertext_b64) - key_size = private_key.key_size // 8 - min_len = 1 + key_size + _NONCE_LEN + _TAG_LEN - assert len(envelope) >= min_len - assert envelope[0] == _ENVELOPE_VERSION - - def test_roundtrip(self, rsa_keypair: tuple[Any, Any]) -> None: - """Encrypted credentials can be decrypted with private key.""" - private_key, public_key = rsa_keypair - credentials = {"api_key": "sk_test_123", "org_id": "org_456"} - - ciphertext_b64 = encrypt_credentials(public_key, credentials) - envelope = _b64url_decode(ciphertext_b64) - plaintext = decrypt_envelope_v1(private_key, envelope) - - assert json.loads(plaintext) == credentials - - def test_large_payload(self, rsa_keypair: tuple[Any, Any]) -> None: - """Envelope v1 handles payloads larger than RSA limit.""" - private_key, public_key = rsa_keypair - credentials = {"large_token": "x" * 1000, "another": "y" * 500} - - ciphertext_b64 = encrypt_credentials(public_key, credentials) - envelope = _b64url_decode(ciphertext_b64) - plaintext = decrypt_envelope_v1(private_key, envelope) - - assert json.loads(plaintext) == credentials - - def test_randomized(self, rsa_keypair: tuple[Any, Any]) -> None: - """Same plaintext produces different ciphertext each time.""" - _, public_key = rsa_keypair - credentials = {"token": "same_value"} - - ct1 = encrypt_credentials(public_key, credentials) - ct2 = encrypt_credentials(public_key, credentials) +def test_jwk_small_key_rejected(): + small_key = rsa.generate_private_key(65537, 1024, default_backend()) + numbers = small_key.public_key().public_numbers() + n_bytes = numbers.n.to_bytes((numbers.n.bit_length() + 7) // 8, "big") + e_bytes = numbers.e.to_bytes((numbers.e.bit_length() + 7) // 8, "big") + jwk = {"kty": "RSA", "n": _b64url_encode(n_bytes), "e": _b64url_encode(e_bytes)} + with pytest.raises(ValueError, match="below minimum"): + jwk_to_public_key(jwk, min_key_size=2048) - assert ct1 != ct2 - def test_with_3072_key(self, rsa_keypair_3072: tuple[Any, Any]) -> None: - """Works with production-size 3072-bit keys.""" - private_key, public_key = rsa_keypair_3072 - credentials = {"token": "production_token"} +# --- encrypt_credentials (envelope v1) --- - ciphertext_b64 = encrypt_credentials(public_key, credentials) - envelope = _b64url_decode(ciphertext_b64) - plaintext = decrypt_envelope_v1(private_key, envelope) - assert json.loads(plaintext) == credentials +def test_encrypt_envelope_format(rsa_keypair: tuple[Any, Any]): + private_key, public_key = rsa_keypair + ciphertext_b64 = encrypt_credentials(public_key, {"token": "ghp_xxx123"}) + envelope = _b64url_decode(ciphertext_b64) + key_size = private_key.key_size // 8 + assert len(envelope) >= 1 + key_size + _NONCE_LEN + _TAG_LEN + assert envelope[0] == _ENVELOPE_VERSION -class TestSecurityInvariants: - """Verify security properties.""" +def test_encrypt_roundtrip(rsa_keypair: tuple[Any, Any]): + private_key, public_key = rsa_keypair + credentials = {"api_key": "sk_test_123", "org_id": "org_456"} + ciphertext_b64 = encrypt_credentials(public_key, credentials) + plaintext = _decrypt_envelope_v1(private_key, _b64url_decode(ciphertext_b64)) + assert json.loads(plaintext) == credentials - def test_plaintext_not_in_ciphertext(self, rsa_keypair: tuple[Any, Any]) -> None: - """Plaintext must not appear in ciphertext.""" - _, public_key = rsa_keypair - secret = "ghp_super_secret_token_12345" - ciphertext = encrypt_credentials(public_key, {"token": secret}) +def test_encrypt_large_payload(rsa_keypair: tuple[Any, Any]): + private_key, public_key = rsa_keypair + credentials = {"large_token": "x" * 1000, "another": "y" * 500} + ciphertext_b64 = encrypt_credentials(public_key, credentials) + plaintext = _decrypt_envelope_v1(private_key, _b64url_decode(ciphertext_b64)) + assert json.loads(plaintext) == credentials - assert secret not in ciphertext - assert "ghp_" not in ciphertext - def test_wrong_key_fails(self, rsa_keypair: tuple[Any, Any]) -> None: - """Decryption fails with wrong private key.""" - _, public_key = rsa_keypair - attacker_key = rsa.generate_private_key(65537, 2048, default_backend()) +def test_encrypt_randomized(rsa_keypair: tuple[Any, Any]): + """Same plaintext produces different ciphertext each time.""" + _, public_key = rsa_keypair + credentials = {"token": "same_value"} + assert encrypt_credentials(public_key, credentials) != encrypt_credentials(public_key, credentials) - ciphertext_b64 = encrypt_credentials(public_key, {"token": "secret"}) - envelope = _b64url_decode(ciphertext_b64) - with pytest.raises(Exception): - decrypt_envelope_v1(attacker_key, envelope) +def test_encrypt_with_3072_key(rsa_keypair_3072: tuple[Any, Any]): + private_key, public_key = rsa_keypair_3072 + credentials = {"token": "production_token"} + ciphertext_b64 = encrypt_credentials(public_key, credentials) + plaintext = _decrypt_envelope_v1(private_key, _b64url_decode(ciphertext_b64)) + assert json.loads(plaintext) == credentials - def test_tampered_ciphertext_fails(self, rsa_keypair: tuple[Any, Any]) -> None: - """GCM authentication rejects tampered ciphertext.""" - private_key, public_key = rsa_keypair - ciphertext_b64 = encrypt_credentials(public_key, {"token": "test"}) - envelope = bytearray(_b64url_decode(ciphertext_b64)) +# --- Security invariants --- - # Tamper with ciphertext portion - key_size = private_key.key_size // 8 - envelope[1 + key_size + _NONCE_LEN + 5] ^= 0xFF - with pytest.raises(Exception): - decrypt_envelope_v1(private_key, bytes(envelope)) +def test_plaintext_not_in_ciphertext(rsa_keypair: tuple[Any, Any]): + _, public_key = rsa_keypair + secret = "ghp_super_secret_token_12345" + ciphertext = encrypt_credentials(public_key, {"token": secret}) + assert secret not in ciphertext + assert "ghp_" not in ciphertext - def test_tampered_wrapped_key_fails(self, rsa_keypair: tuple[Any, Any]) -> None: - """Tampered wrapped key is rejected.""" - private_key, public_key = rsa_keypair - ciphertext_b64 = encrypt_credentials(public_key, {"token": "test"}) - envelope = bytearray(_b64url_decode(ciphertext_b64)) +def test_wrong_key_fails(rsa_keypair: tuple[Any, Any]): + _, public_key = rsa_keypair + attacker_key = rsa.generate_private_key(65537, 2048, default_backend()) + ciphertext_b64 = encrypt_credentials(public_key, {"token": "secret"}) + with pytest.raises(Exception): + _decrypt_envelope_v1(attacker_key, _b64url_decode(ciphertext_b64)) - envelope[10] ^= 0xFF - with pytest.raises(Exception): - decrypt_envelope_v1(private_key, bytes(envelope)) +def test_tampered_ciphertext_fails(rsa_keypair: tuple[Any, Any]): + private_key, public_key = rsa_keypair + ciphertext_b64 = encrypt_credentials(public_key, {"token": "test"}) + envelope = bytearray(_b64url_decode(ciphertext_b64)) + key_size = private_key.key_size // 8 + envelope[1 + key_size + _NONCE_LEN + 5] ^= 0xFF + with pytest.raises(Exception): + _decrypt_envelope_v1(private_key, bytes(envelope)) + + +def test_tampered_wrapped_key_fails(rsa_keypair: tuple[Any, Any]): + private_key, public_key = rsa_keypair + ciphertext_b64 = encrypt_credentials(public_key, {"token": "test"}) + envelope = bytearray(_b64url_decode(ciphertext_b64)) + envelope[10] ^= 0xFF + with pytest.raises(Exception): + _decrypt_envelope_v1(private_key, bytes(envelope)) diff --git a/tests/test_mcp_wire.py b/tests/test_mcp_wire.py index f7eb810..f40999d 100644 --- a/tests/test_mcp_wire.py +++ b/tests/test_mcp_wire.py @@ -22,7 +22,7 @@ ) -# --- Fixtures ---------------------------------------------------------------- +# --- Test helpers --- class FakeMCPServer: @@ -45,247 +45,204 @@ def serve(self, *args: Any, **kwargs: Any) -> None: class IncompleteServer: - """Missing required protocol attributes (no url, no serve).""" + """Missing required protocol attributes.""" def __init__(self) -> None: self.name = "incomplete" -# --- MCPServerWireSpec Construction ------------------------------------------ +# --- MCPServerWireSpec construction --- -class TestMCPServerWireSpecConstruction: - """Factory methods for creating wire specs.""" +def test_from_slug_simple(): + spec = MCPServerWireSpec.from_slug("dedalus-labs/example-server") + assert spec.slug == "dedalus-labs/example-server" + assert spec.version is None - def test_from_slug_simple(self) -> None: - """Simple marketplace slug.""" - spec = MCPServerWireSpec.from_slug("dedalus-labs/example-server") - assert spec.slug == "dedalus-labs/example-server" - assert spec.version is None - def test_from_slug_with_version(self) -> None: - """Slug with explicit version parameter.""" - spec = MCPServerWireSpec.from_slug("dedalus-labs/example-server", version="v1.2.0") - assert spec.version == "v1.2.0" +def test_from_slug_with_version(): + spec = MCPServerWireSpec.from_slug("dedalus-labs/example-server", version="v1.2.0") + assert spec.version == "v1.2.0" - def test_from_slug_with_embedded_version(self) -> None: - """Slug@version syntax parsed correctly.""" - spec = MCPServerWireSpec.from_slug("dedalus-labs/example-server@v2") - assert spec.slug == "dedalus-labs/example-server" - assert spec.version == "v2" - def test_from_url(self) -> None: - """Direct URL.""" - spec = MCPServerWireSpec.from_url(url="http://127.0.0.1:8000/mcp") - assert spec.url == "http://127.0.0.1:8000/mcp" +def test_from_slug_with_embedded_version(): + spec = MCPServerWireSpec.from_slug("dedalus-labs/example-server@v2") + assert spec.slug == "dedalus-labs/example-server" + assert spec.version == "v2" -# --- MCPServerWireSpec Validation -------------------------------------------- +def test_from_url(): + spec = MCPServerWireSpec.from_url(url="http://127.0.0.1:8000/mcp") + assert spec.url == "http://127.0.0.1:8000/mcp" -class TestMCPServerWireSpecValidation: - """Pydantic validation rules for wire specs.""" +# --- MCPServerWireSpec validation --- - def test_requires_slug_or_url(self) -> None: - """Must provide either slug or url.""" - with pytest.raises(ValidationError) as exc_info: - MCPServerWireSpec() - assert "requires either 'slug' or 'url'" in str(exc_info.value) - def test_rejects_both_slug_and_url(self) -> None: - """Cannot provide both slug and url.""" - with pytest.raises(ValidationError) as exc_info: - MCPServerWireSpec( - slug="dedalus-labs/example-server", - url="http://localhost:8000/mcp", - ) - assert "cannot have both" in str(exc_info.value) +def test_requires_slug_or_url(): + with pytest.raises(ValidationError, match="requires either"): + MCPServerWireSpec() - def test_url_must_start_with_http(self) -> None: - """URL must have http:// or https:// scheme.""" - with pytest.raises(ValidationError) as exc_info: - MCPServerWireSpec(url="localhost:8000/mcp") - assert "must start with http://" in str(exc_info.value) - def test_https_url_accepted(self) -> None: - """HTTPS URLs are valid.""" - spec = MCPServerWireSpec(url="https://mcp.dedaluslabs.ai/acme/my-server/mcp") - assert spec.url == "https://mcp.dedaluslabs.ai/acme/my-server/mcp" +def test_rejects_both_slug_and_url(): + with pytest.raises(ValidationError, match="cannot have both"): + MCPServerWireSpec(slug="dedalus-labs/example-server", url="http://localhost:8000/mcp") - def test_localhost_url_accepted(self) -> None: - """Localhost URLs are valid for dev.""" - spec = MCPServerWireSpec(url="http://127.0.0.1:8000/mcp") - assert spec.url == "http://127.0.0.1:8000/mcp" - def test_slug_format_validation(self) -> None: - """Slug must match org/project pattern.""" - MCPServerWireSpec(slug="dedalus-labs/example-server") - MCPServerWireSpec(slug="org_123/project_456") - MCPServerWireSpec(slug="a/b") +def test_url_must_start_with_http(): + with pytest.raises(ValidationError, match="must start with http://"): + MCPServerWireSpec(url="localhost:8000/mcp") - with pytest.raises(ValidationError): - MCPServerWireSpec(slug="invalid-no-slash") - with pytest.raises(ValidationError): - MCPServerWireSpec(slug="too/many/slashes") +def test_https_url_accepted(): + spec = MCPServerWireSpec(url="https://mcp.dedaluslabs.ai/acme/my-server/mcp") + assert spec.url == "https://mcp.dedaluslabs.ai/acme/my-server/mcp" - def test_slug_with_at_sign_rejected_by_pattern(self) -> None: - """Slug pattern doesn't allow @ - use from_slug() for version parsing.""" - with pytest.raises(ValidationError) as exc_info: - MCPServerWireSpec(slug="org/project@v1", version="v2") - assert "string_pattern_mismatch" in str(exc_info.value).lower() - # Correct way: use from_slug() which parses the version - spec = MCPServerWireSpec.from_slug("org/project@v1") - assert spec.slug == "org/project" - assert spec.version == "v1" +def test_localhost_url_accepted(): + spec = MCPServerWireSpec(url="http://127.0.0.1:8000/mcp") + assert spec.url == "http://127.0.0.1:8000/mcp" - def test_extra_fields_forbidden(self) -> None: - """Extra fields rejected (ConfigDict extra='forbid').""" - with pytest.raises(ValidationError) as exc_info: - MCPServerWireSpec(slug="org/test", unknown_field="value") # type: ignore[call-arg] - assert "extra" in str(exc_info.value).lower() +def test_slug_format_validation(): + MCPServerWireSpec(slug="dedalus-labs/example-server") + MCPServerWireSpec(slug="org_123/project_456") + MCPServerWireSpec(slug="a/b") -# --- MCPServerWireSpec Serialization ----------------------------------------- + with pytest.raises(ValidationError): + MCPServerWireSpec(slug="invalid-no-slash") + with pytest.raises(ValidationError): + MCPServerWireSpec(slug="too/many/slashes") -class TestMCPServerWireSpecSerialization: - """to_wire() output for different spec types.""" - def test_simple_slug_serializes_to_string(self) -> None: - """Simple slug-only specs serialize to plain string (efficient).""" - spec = MCPServerWireSpec.from_slug("dedalus-labs/example-server") - wire = spec.to_wire() - assert wire == "dedalus-labs/example-server" - assert isinstance(wire, str) +def test_slug_with_at_sign_rejected(): + """Slug pattern doesn't allow @. Use from_slug() for version parsing.""" + with pytest.raises(ValidationError): + MCPServerWireSpec(slug="org/project@v1", version="v2") - def test_versioned_slug_serializes_to_dict(self) -> None: - """Slug with version serializes to dict.""" - spec = MCPServerWireSpec.from_slug("dedalus-labs/example-server", version="v1.0.0") - wire = spec.to_wire() - assert wire == {"slug": "dedalus-labs/example-server", "version": "v1.0.0"} + spec = MCPServerWireSpec.from_slug("org/project@v1") + assert spec.slug == "org/project" + assert spec.version == "v1" - def test_url_spec_serializes_to_dict(self) -> None: - """URL-based specs serialize to dict with just url.""" - spec = MCPServerWireSpec.from_url(url="http://127.0.0.1:8000/mcp") - wire = spec.to_wire() - assert wire == {"url": "http://127.0.0.1:8000/mcp"} - def test_serialization_is_json_compatible(self) -> None: - """Wire format round-trips through JSON.""" - spec = MCPServerWireSpec.from_url(url="http://127.0.0.1:8000/mcp") - json_str = json.dumps(spec.to_wire()) - assert '"url": "http://127.0.0.1:8000/mcp"' in json_str +def test_extra_fields_forbidden(): + with pytest.raises(ValidationError): + MCPServerWireSpec(slug="org/test", unknown_field="value") # type: ignore[call-arg] -# --- MCPServerProtocol ------------------------------------------------------- +# --- MCPServerWireSpec serialization --- -class TestMCPServerProtocol: - """Structural typing for MCP server objects.""" - - def test_fake_server_satisfies_protocol(self) -> None: - """FakeMCPServer satisfies MCPServerProtocol.""" - server = FakeMCPServer(name="test", url="http://localhost:8000/mcp") - assert is_mcp_server(server) - assert isinstance(server, MCPServerProtocol) - - def test_string_does_not_satisfy_protocol(self) -> None: - """Plain strings are not MCPServerProtocol.""" - assert not is_mcp_server("dedalus-labs/example-server") - - def test_dict_does_not_satisfy_protocol(self) -> None: - """Dicts are not MCPServerProtocol.""" - assert not is_mcp_server({"name": "test", "url": "http://localhost/mcp"}) +def test_simple_slug_serializes_to_string(): + wire = MCPServerWireSpec.from_slug("dedalus-labs/example-server").to_wire() + assert wire == "dedalus-labs/example-server" + assert isinstance(wire, str) - def test_incomplete_server_does_not_satisfy(self) -> None: - """Missing attributes means protocol not satisfied.""" - assert not is_mcp_server(IncompleteServer()) +def test_versioned_slug_serializes_to_dict(): + wire = MCPServerWireSpec.from_slug("dedalus-labs/example-server", version="v1.0.0").to_wire() + assert wire == {"slug": "dedalus-labs/example-server", "version": "v1.0.0"} -# --- serialize_mcp_servers --------------------------------------------------- +def test_url_spec_serializes_to_dict(): + wire = MCPServerWireSpec.from_url(url="http://127.0.0.1:8000/mcp").to_wire() + assert wire == {"url": "http://127.0.0.1:8000/mcp"} -class TestSerializeMCPServers: - """End-to-end serialization of mixed mcp_servers input.""" - def test_none_returns_empty_list(self) -> None: - """None input returns empty list.""" - assert serialize_mcp_servers(None) == [] +def test_serialization_json_roundtrip(): + wire = MCPServerWireSpec.from_url(url="http://127.0.0.1:8000/mcp").to_wire() + assert '"url": "http://127.0.0.1:8000/mcp"' in json.dumps(wire) - def test_single_string_slug(self) -> None: - """Single slug string passes through.""" - assert serialize_mcp_servers("dedalus-labs/example-server") == ["dedalus-labs/example-server"] - def test_single_string_url(self) -> None: - """Single URL string passes through.""" - assert serialize_mcp_servers("http://localhost:8000/mcp") == ["http://localhost:8000/mcp"] +# --- MCPServerProtocol --- - def test_single_mcp_server_object(self) -> None: - """MCPServerProtocol object serializes to URL dict.""" - server = FakeMCPServer(name="calculator", url="http://127.0.0.1:8000/mcp") - result = serialize_mcp_servers(server) - assert result == [{"url": "http://127.0.0.1:8000/mcp"}] - def test_list_of_slugs(self) -> None: - """List of slug strings.""" - result = serialize_mcp_servers(["dedalus-labs/example-server", "dedalus-labs/weather"]) - assert result == ["dedalus-labs/example-server", "dedalus-labs/weather"] +def test_fake_server_satisfies_protocol(): + server = FakeMCPServer(name="test", url="http://localhost:8000/mcp") + assert is_mcp_server(server) + assert isinstance(server, MCPServerProtocol) - def test_versioned_slug_in_list(self) -> None: - """Slug@version syntax expands to dict.""" - result = serialize_mcp_servers(["dedalus-labs/example-server@v2"]) - assert result == [{"slug": "dedalus-labs/example-server", "version": "v2"}] - def test_mixed_list(self) -> None: - """Mixed list of slugs, URLs, and server objects.""" - server = FakeMCPServer(name="local", url="http://127.0.0.1:8000/mcp") - result = serialize_mcp_servers([server, "dedalus-labs/example-server", "dedalus-labs/weather@v2"]) +def test_string_not_protocol(): + assert not is_mcp_server("dedalus-labs/example-server") - assert len(result) == 3 - assert result[0] == {"url": "http://127.0.0.1:8000/mcp"} - assert result[1] == "dedalus-labs/example-server" - assert result[2] == {"slug": "dedalus-labs/weather", "version": "v2"} - def test_server_without_url_uses_name_as_slug(self) -> None: - """Server object without URL returns name as slug.""" - server = FakeMCPServer(name="org/my-server", url=None) - result = serialize_mcp_servers(server) - assert result == ["org/my-server"] - - def test_dict_input_validated(self) -> None: - """Dict inputs pass through MCPServerWireSpec validation.""" - result = serialize_mcp_servers([{"slug": "dedalus-labs/test"}]) - assert result == ["dedalus-labs/test"] +def test_dict_not_protocol(): + assert not is_mcp_server({"name": "test", "url": "http://localhost/mcp"}) -# --- JSON Compatibility ------------------------------------------------------ +def test_incomplete_server_not_protocol(): + assert not is_mcp_server(IncompleteServer()) -class TestJSONCompatibility: - """Wire format is JSON-serializable and API-compatible.""" +# --- serialize_mcp_servers --- - def test_full_payload_structure(self) -> None: - """Complete API payload round-trips through JSON.""" - server = FakeMCPServer(name="calculator", url="http://127.0.0.1:8000/mcp") - wire_data = serialize_mcp_servers([server, "dedalus-labs/example-server", "dedalus-labs/weather@v2"]) - payload = { - "model": "openai/gpt-5-nano", - "messages": [{"role": "user", "content": "What is 2 + 2?"}], - "mcp_servers": wire_data, - } +def test_serialize_none(): + assert serialize_mcp_servers(None) == [] - parsed = json.loads(json.dumps(payload)) - assert parsed["mcp_servers"][0] == {"url": "http://127.0.0.1:8000/mcp"} - assert parsed["mcp_servers"][1] == "dedalus-labs/example-server" - assert parsed["mcp_servers"][2]["slug"] == "dedalus-labs/weather" - def test_unicode_in_url(self) -> None: - """Unicode in URL paths are handled.""" - # Dedalus-hosted URL with unicode in path token - spec = MCPServerWireSpec(url="http://mcp.dedaluslabs.ai/acme/計算機/mcp") - result = spec.to_wire() - json_str = json.dumps(result, ensure_ascii=False) - assert "計算機" in json_str +def test_serialize_single_slug(): + assert serialize_mcp_servers("dedalus-labs/example-server") == ["dedalus-labs/example-server"] + + +def test_serialize_single_url(): + assert serialize_mcp_servers("http://localhost:8000/mcp") == ["http://localhost:8000/mcp"] + + +def test_serialize_server_object(): + server = FakeMCPServer(name="calculator", url="http://127.0.0.1:8000/mcp") + assert serialize_mcp_servers(server) == [{"url": "http://127.0.0.1:8000/mcp"}] + + +def test_serialize_slug_list(): + result = serialize_mcp_servers(["dedalus-labs/example-server", "dedalus-labs/weather"]) + assert result == ["dedalus-labs/example-server", "dedalus-labs/weather"] + + +def test_serialize_versioned_slug(): + result = serialize_mcp_servers(["dedalus-labs/example-server@v2"]) + assert result == [{"slug": "dedalus-labs/example-server", "version": "v2"}] + + +def test_serialize_mixed_list(): + server = FakeMCPServer(name="local", url="http://127.0.0.1:8000/mcp") + result = serialize_mcp_servers([server, "dedalus-labs/example-server", "dedalus-labs/weather@v2"]) + + assert len(result) == 3 + assert result[0] == {"url": "http://127.0.0.1:8000/mcp"} + assert result[1] == "dedalus-labs/example-server" + assert result[2] == {"slug": "dedalus-labs/weather", "version": "v2"} + + +def test_serialize_server_without_url(): + server = FakeMCPServer(name="org/my-server", url=None) + assert serialize_mcp_servers(server) == ["org/my-server"] + + +def test_serialize_dict_input(): + result = serialize_mcp_servers([{"slug": "dedalus-labs/test"}]) + assert result == ["dedalus-labs/test"] + + +# --- JSON compatibility --- + + +def test_full_payload_json_roundtrip(): + server = FakeMCPServer(name="calculator", url="http://127.0.0.1:8000/mcp") + wire_data = serialize_mcp_servers([server, "dedalus-labs/example-server", "dedalus-labs/weather@v2"]) + payload = { + "model": "openai/gpt-5-nano", + "messages": [{"role": "user", "content": "What is 2 + 2?"}], + "mcp_servers": wire_data, + } + parsed = json.loads(json.dumps(payload)) + assert parsed["mcp_servers"][0] == {"url": "http://127.0.0.1:8000/mcp"} + assert parsed["mcp_servers"][1] == "dedalus-labs/example-server" + assert parsed["mcp_servers"][2]["slug"] == "dedalus-labs/weather" + + +def test_unicode_in_url(): + spec = MCPServerWireSpec(url="http://mcp.dedaluslabs.ai/acme/計算機/mcp") + assert "計算機" in json.dumps(spec.to_wire(), ensure_ascii=False) diff --git a/tests/test_mcp_wire_connections.py b/tests/test_mcp_wire_connections.py index f71d8de..b0bdae2 100644 --- a/tests/test_mcp_wire_connections.py +++ b/tests/test_mcp_wire_connections.py @@ -20,18 +20,11 @@ ) -# --- Mock objects for testing --- +# --- Test helpers --- class MockConnection: - """Mock Connection object implementing the protocol.""" - - def __init__( - self, - name: str, - base_url: str | None = None, - timeout_ms: int = 30000, - ) -> None: + def __init__(self, name: str, base_url: str | None = None, timeout_ms: int = 30000) -> None: self._name = name self._base_url = base_url self._timeout_ms = timeout_ms @@ -58,8 +51,6 @@ def to_dict(self) -> dict[str, Any]: class MockCredential: - """Mock Secret object implementing the protocol.""" - def __init__(self, connection: MockConnection, **values: Any) -> None: self._connection = connection self._values = values @@ -73,236 +64,161 @@ def values(self) -> dict[str, Any]: return dict(self._values) def to_dict(self) -> dict[str, Any]: - return { - "connection_name": self._connection.name, - "values": dict(self._values), - } + return {"connection_name": self._connection.name, "values": dict(self._values)} def values_for_encryption(self) -> dict[str, Any]: return dict(self._values) -class TestSerializeConnection: - """Test serialize_connection helper.""" - - def test_with_connection_object(self) -> None: - """Serialize Connection object with to_dict().""" - conn = MockConnection("github", "https://api.github.com", 60000) - - result = serialize_connection(conn) - - assert result["name"] == "github" - assert result["base_url"] == "https://api.github.com" - assert result["timeout_ms"] == 60000 - - def test_with_dict(self) -> None: - """Pass-through for dict input.""" - data = {"name": "dedalus", "base_url": "https://api.dedaluslabs.ai/v1"} - - result = serialize_connection(data) - - assert result == data - - def test_duck_type_extraction(self) -> None: - """Extract fields from object without to_dict().""" - - class BareConnection: - name = "bare" - base_url = "https://bare.api.com" - timeout_ms = 15000 - - result = serialize_connection(BareConnection()) - - assert result["name"] == "bare" - assert result["base_url"] == "https://bare.api.com" - assert result["timeout_ms"] == 15000 - - -class TestMatchSecretsToConnections: - """Test match_credentials_to_connections helper.""" - - def test_basic_matching(self) -> None: - """Match secrets to connections by name.""" - github = MockConnection("github") - dedalus = MockConnection("dedalus") - - github_secret = MockCredential(github, token="ghp_xxx") - dedalus_secret = MockCredential(dedalus, api_key="sk_xxx") - - pairs = match_credentials_to_connections( - [github, dedalus], - [dedalus_secret, github_secret], # Different order - ) - - assert len(pairs) == 2 - # Pairs should be in connection order - assert pairs[0][0].name == "github" - assert pairs[0][1].values == {"token": "ghp_xxx"} - assert pairs[1][0].name == "dedalus" - assert pairs[1][1].values == {"api_key": "sk_xxx"} - - def test_missing_secret_raises(self) -> None: - """Raise ValueError if connection has no secret.""" - github = MockConnection("github") - dedalus = MockConnection("dedalus") - - github_secret = MockCredential(github, token="ghp_xxx") - - with pytest.raises( - ValueError, match="Missing credentials for connections.*dedalus" - ): - match_credentials_to_connections([github, dedalus], [github_secret]) - - def test_with_dict_inputs(self) -> None: - """Works with dict inputs too.""" - connections = [{"name": "api"}] - secrets = [{"connection_name": "api", "values": {"key": "xxx"}}] - - pairs = match_credentials_to_connections(connections, secrets) - - assert len(pairs) == 1 - assert pairs[0][0]["name"] == "api" - assert pairs[0][1]["values"] == {"key": "xxx"} - - def test_missing_multiple_secrets(self) -> None: - """Error message lists all missing secrets.""" - github = MockConnection("github") - dedalus = MockConnection("dedalus") - slack = MockConnection("slack") - - github_secret = MockCredential(github, token="ghp_xxx") - - with pytest.raises(ValueError) as exc: - match_credentials_to_connections([github, dedalus, slack], [github_secret]) - - assert "dedalus" in str(exc.value) - assert "slack" in str(exc.value) - - -# --- Mock server for multi-server tests --- - - class MockServer: - """Mock MCPServer for testing.""" - def __init__(self, name: str, connections: list[Any] | None = None) -> None: self.name = name self.connections = connections or [] -class TestCollectUniqueConnections: - """Test collect_unique_connections helper.""" +# --- serialize_connection --- - def test_single_server(self) -> None: - """Collect connections from single server.""" - github = MockConnection("github") - dedalus = MockConnection("dedalus") - server = MockServer("bot", connections=[github, dedalus]) - result = collect_unique_connections([server]) +def test_serialize_connection_object(): + conn = MockConnection("github", "https://api.github.com", 60000) + result = serialize_connection(conn) + assert result["name"] == "github" + assert result["base_url"] == "https://api.github.com" + assert result["timeout_ms"] == 60000 - assert len(result) == 2 - assert result[0].name == "github" - assert result[1].name == "dedalus" - def test_shared_connection_deduplicated(self) -> None: - """Shared Connection appears only once.""" - github = MockConnection("github") +def test_serialize_connection_dict(): + data = {"name": "dedalus", "base_url": "https://api.dedaluslabs.ai/v1"} + assert serialize_connection(data) == data - server_a = MockServer("issues", connections=[github]) - server_b = MockServer("prs", connections=[github]) - result = collect_unique_connections([server_a, server_b]) +def test_serialize_connection_duck_type(): + class BareConnection: + name = "bare" + base_url = "https://bare.api.com" + timeout_ms = 15000 - assert len(result) == 1 - assert result[0].name == "github" + result = serialize_connection(BareConnection()) + assert result["name"] == "bare" + assert result["base_url"] == "https://bare.api.com" + assert result["timeout_ms"] == 15000 - def test_same_name_different_objects(self) -> None: - """Connections with same name are deduplicated.""" - # Even if different objects, same name means same logical connection - github_a = MockConnection("github", base_url="https://api.github.com") - github_b = MockConnection("github", base_url="https://api.github.com") - server_a = MockServer("a", connections=[github_a]) - server_b = MockServer("b", connections=[github_b]) +# --- match_credentials_to_connections --- - result = collect_unique_connections([server_a, server_b]) - # Should only include first occurrence - assert len(result) == 1 - assert result[0] is github_a +def test_match_basic(): + github = MockConnection("github") + dedalus = MockConnection("dedalus") + pairs = match_credentials_to_connections( + [github, dedalus], + [MockCredential(dedalus, api_key="sk_xxx"), MockCredential(github, token="ghp_xxx")], + ) + assert len(pairs) == 2 + assert pairs[0][0].name == "github" + assert pairs[0][1].values == {"token": "ghp_xxx"} + assert pairs[1][0].name == "dedalus" + assert pairs[1][1].values == {"api_key": "sk_xxx"} - def test_multiple_servers_multiple_connections(self) -> None: - """Collect and deduplicate across multiple servers.""" - github = MockConnection("github") - dedalus = MockConnection("dedalus") - slack = MockConnection("slack") - server_a = MockServer("bot1", connections=[github, dedalus]) - server_b = MockServer("bot2", connections=[github, slack]) +def test_match_missing_raises(): + github = MockConnection("github") + dedalus = MockConnection("dedalus") + with pytest.raises(ValueError, match="Missing credentials for connections.*dedalus"): + match_credentials_to_connections([github, dedalus], [MockCredential(github, token="ghp_xxx")]) - result = collect_unique_connections([server_a, server_b]) - assert len(result) == 3 - names = [c.name for c in result] - assert names == ["github", "dedalus", "slack"] +def test_match_with_dicts(): + connections = [{"name": "api"}] + secrets = [{"connection_name": "api", "values": {"key": "xxx"}}] + pairs = match_credentials_to_connections(connections, secrets) + assert len(pairs) == 1 + assert pairs[0][0]["name"] == "api" + assert pairs[0][1]["values"] == {"key": "xxx"} - def test_server_without_connections(self) -> None: - """Handle servers with no connections.""" - server_a = MockServer("empty") - server_b = MockServer("has", connections=[MockConnection("api")]) - result = collect_unique_connections([server_a, server_b]) +def test_match_missing_multiple(): + github = MockConnection("github") + dedalus = MockConnection("dedalus") + slack = MockConnection("slack") + with pytest.raises(ValueError) as exc: + match_credentials_to_connections([github, dedalus, slack], [MockCredential(github, token="ghp_xxx")]) + assert "dedalus" in str(exc.value) + assert "slack" in str(exc.value) - assert len(result) == 1 +# --- collect_unique_connections --- -class TestValidateSecretsForServers: - """Test validate_credentials_for_servers (main SDK init validation).""" - def test_all_connections_have_secrets(self) -> None: - """Success when all connections have matching secrets.""" - github = MockConnection("github") - dedalus = MockConnection("dedalus") +def test_collect_single_server(): + github = MockConnection("github") + dedalus = MockConnection("dedalus") + result = collect_unique_connections([MockServer("bot", connections=[github, dedalus])]) + assert len(result) == 2 + assert result[0].name == "github" + assert result[1].name == "dedalus" - server = MockServer("bot", connections=[github, dedalus]) - github_secret = MockCredential(github, token="ghp_xxx") - dedalus_secret = MockCredential(dedalus, api_key="sk_xxx") +def test_collect_deduplicates_shared(): + github = MockConnection("github") + result = collect_unique_connections([MockServer("a", [github]), MockServer("b", [github])]) + assert len(result) == 1 + assert result[0].name == "github" - pairs = validate_credentials_for_servers( - [server], [github_secret, dedalus_secret] - ) - assert len(pairs) == 2 +def test_collect_deduplicates_by_name(): + github_a = MockConnection("github", base_url="https://api.github.com") + github_b = MockConnection("github", base_url="https://api.github.com") + result = collect_unique_connections([MockServer("a", [github_a]), MockServer("b", [github_b])]) + assert len(result) == 1 + assert result[0] is github_a - def test_shared_connection_one_secret(self) -> None: - """One Secret covers shared Connection across servers.""" - github = MockConnection("github") - server_a = MockServer("issues", connections=[github]) - server_b = MockServer("prs", connections=[github]) +def test_collect_multiple_servers(): + github = MockConnection("github") + dedalus = MockConnection("dedalus") + slack = MockConnection("slack") + result = collect_unique_connections([ + MockServer("bot1", [github, dedalus]), + MockServer("bot2", [github, slack]), + ]) + assert [c.name for c in result] == ["github", "dedalus", "slack"] - github_secret = MockCredential(github, token="ghp_xxx") - pairs = validate_credentials_for_servers( - [server_a, server_b], - [github_secret], # Only one secret needed - ) +def test_collect_server_without_connections(): + result = collect_unique_connections([MockServer("empty"), MockServer("has", [MockConnection("api")])]) + assert len(result) == 1 - assert len(pairs) == 1 - assert pairs[0][0].name == "github" - def test_missing_secret_fails_fast(self) -> None: - """Raise immediately if any connection lacks a secret.""" - github = MockConnection("github") - dedalus = MockConnection("dedalus") +# --- validate_credentials_for_servers --- - server = MockServer("bot", connections=[github, dedalus]) - github_secret = MockCredential(github, token="ghp_xxx") - with pytest.raises(ValueError) as exc: - validate_credentials_for_servers([server], [github_secret]) +def test_validate_all_present(): + github = MockConnection("github") + dedalus = MockConnection("dedalus") + server = MockServer("bot", connections=[github, dedalus]) + pairs = validate_credentials_for_servers( + [server], + [MockCredential(github, token="ghp_xxx"), MockCredential(dedalus, api_key="sk_xxx")], + ) + assert len(pairs) == 2 + - assert "dedalus" in str(exc.value) - assert "Missing credentials" in str(exc.value) +def test_validate_shared_connection(): + github = MockConnection("github") + pairs = validate_credentials_for_servers( + [MockServer("a", [github]), MockServer("b", [github])], + [MockCredential(github, token="ghp_xxx")], + ) + assert len(pairs) == 1 + assert pairs[0][0].name == "github" + + +def test_validate_missing_fails_fast(): + github = MockConnection("github") + dedalus = MockConnection("dedalus") + server = MockServer("bot", connections=[github, dedalus]) + with pytest.raises(ValueError) as exc: + validate_credentials_for_servers([server], [MockCredential(github, token="ghp_xxx")]) + assert "dedalus" in str(exc.value) + assert "Missing credentials" in str(exc.value)