diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py
index 9cdf717c7..ea536bfc0 100644
--- a/src/mcp/client/streamable_http.py
+++ b/src/mcp/client/streamable_http.py
@@ -50,6 +50,27 @@
 MAX_RECONNECTION_ATTEMPTS = 2  # Max retry attempts before giving up
 
 
+def _get_default_origin(url: str) -> str | None:
+    """Derive a same-origin ``Origin`` value for *url*.
+
+    Browsers always send an ``Origin`` on cross-origin-capable requests; a server-to-server
+    client sends none. Emitting a correct same-origin value matches browser behavior and
+    satisfies servers that gate state-changing requests on a present, same-origin ``Origin``
+    (defense-in-depth against DNS-rebinding/CSRF), without weakening any server's posture.
+
+    The value is built from ``httpx.URL`` so it uses the exact scheme/host/port normalization
+    httpx applies to the ``Host`` header (default ports dropped, IPv6 hosts bracketed, userinfo
+    stripped). That keeps ``Origin`` and ``Host`` byte-for-byte consistent even for inputs like
+    ``https://host:443/mcp``, where naive parsing keeps a redundant ``:443`` that would *not*
+    match the ``Host`` httpx sends. Returns ``None`` for non-HTTP(S) URLs or URLs without an
+    authority, where no meaningful web origin exists.
+    """
+    parsed = httpx.URL(url)
+    if parsed.scheme not in ("http", "https") or not parsed.netloc:
+        return None
+    return f"{parsed.scheme}://{parsed.netloc.decode('ascii')}"
+
+
 class StreamableHTTPError(Exception):
     """Base exception for StreamableHTTP transport errors."""
 
@@ -72,13 +93,16 @@ class RequestContext:
 class StreamableHTTPTransport:
     """StreamableHTTP client transport implementation."""
 
-    def __init__(self, url: str) -> None:
+    def __init__(self, url: str, default_origin: str | None = None) -> None:
         """Initialize the StreamableHTTP transport.
 
         Args:
             url: The endpoint URL.
+            default_origin: ``Origin`` header to send when the caller has not configured one
+                on the HTTP client. See ``_get_default_origin``.
         """
         self.url = url
+        self.default_origin = default_origin
         self.session_id: str | None = None
         self.protocol_version: str | None = None
 
@@ -92,6 +116,9 @@ def _prepare_headers(self) -> dict[str, str]:
             "accept": "application/json, text/event-stream",
             "content-type": "application/json",
         }
+        # Same-origin Origin for servers that gate on it; only when the caller set none.
+        if self.default_origin:
+            headers["origin"] = self.default_origin
         # Add session headers if available
         if self.session_id:
             headers[MCP_SESSION_ID] = self.session_id
@@ -547,7 +574,10 @@ async def streamable_http_client(
         # Create default client with recommended MCP timeouts
         client = create_mcp_http_client()
 
-    transport = StreamableHTTPTransport(url)
+    # Only supply a default Origin when the caller hasn't set one, so an explicit Origin
+    # (e.g. a multi-tenant proxy's) always wins. The client's own headers are left untouched.
+    default_origin = None if "origin" in client.headers else _get_default_origin(url)
+    transport = StreamableHTTPTransport(url, default_origin=default_origin)
 
     logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
 
diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py
index b43a3361c..d64f75a6a 100644
--- a/tests/shared/test_streamable_http.py
+++ b/tests/shared/test_streamable_http.py
@@ -21,11 +21,12 @@
 from httpx_sse import ServerSentEvent
 from starlette.applications import Starlette
 from starlette.requests import Request
-from starlette.routing import Mount
+from starlette.responses import Response
+from starlette.routing import Mount, Route
 
 from mcp import MCPError, types
 from mcp.client.session import ClientSession
-from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client
+from mcp.client.streamable_http import StreamableHTTPTransport, _get_default_origin, streamable_http_client
 from mcp.server import Server, ServerRequestContext
 from mcp.server.streamable_http import (
     MCP_PROTOCOL_VERSION_HEADER,
@@ -355,6 +356,65 @@ def make_client(app: Starlette, headers: dict[str, str] | None = None) -> httpx.
     )
 
 
+def test_get_default_origin_normalizes_authority() -> None:
+    """The default Origin matches the Host header httpx emits for the same URL."""
+    # Default ports are dropped, so Origin "https://h:443" can't mismatch the Host "h".
+    assert _get_default_origin("https://example.com:443/mcp?token=abc") == "https://example.com"
+    assert _get_default_origin("http://example.com:80/mcp") == "http://example.com"
+    # Non-default ports kept; IPv6 hosts bracketed; userinfo stripped.
+    assert _get_default_origin("https://example.com:8443/mcp") == "https://example.com:8443"
+    assert _get_default_origin("http://user:pass@[::1]:8080/mcp") == "http://[::1]:8080"
+
+
+def test_get_default_origin_returns_none_without_web_origin() -> None:
+    """URLs with no meaningful web origin yield no Origin header."""
+    assert _get_default_origin("ws://example.com/mcp") is None  # non-HTTP scheme
+    assert _get_default_origin("http:///mcp") is None  # no authority
+
+
+def _make_origin_recording_app(seen: anyio.Event, recorded: dict[str, str | None]) -> Starlette:
+    async def mcp_endpoint(request: Request) -> Response:
+        recorded["origin"] = request.headers.get("origin")
+        recorded["host"] = request.headers.get("host")
+        seen.set()
+        return Response(status_code=202)
+
+    return Starlette(routes=[Route("/mcp", endpoint=mcp_endpoint, methods=["POST"])])
+
+
+@pytest.mark.anyio
+async def test_streamable_http_client_sends_same_origin_by_default() -> None:
+    """The client sends a same-origin Origin derived from the URL, matching the Host it emits."""
+    seen = anyio.Event()
+    recorded: dict[str, str | None] = {}
+    async with make_client(_make_origin_recording_app(seen, recorded)) as client:
+        async with streamable_http_client(f"{BASE_URL}/mcp", http_client=client) as (_read_stream, write_stream):
+            await write_stream.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")))
+            with anyio.fail_after(5):
+                await seen.wait()
+
+    assert recorded["origin"] == BASE_URL
+    assert recorded["origin"] is not None
+    assert recorded["origin"].split("://", 1)[1] == recorded["host"]  # Origin host == Host header
+    assert "origin" not in client.headers  # caller's client is left untouched
+
+
+@pytest.mark.anyio
+async def test_streamable_http_client_preserves_custom_origin() -> None:
+    """A caller-configured Origin always wins over the derived default."""
+    seen = anyio.Event()
+    recorded: dict[str, str | None] = {}
+    app = _make_origin_recording_app(seen, recorded)
+    async with make_client(app, headers={"Origin": "https://proxy.example"}) as client:
+        async with streamable_http_client(f"{BASE_URL}/mcp", http_client=client) as (_read_stream, write_stream):
+            await write_stream.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")))
+            with anyio.fail_after(5):
+                await seen.wait()
+
+    assert recorded["origin"] == "https://proxy.example"
+    assert client.headers["origin"] == "https://proxy.example"
+
+
 # Test fixtures
 @pytest.fixture
 async def basic_app() -> AsyncIterator[Starlette]: