From a0c5abaac1f9a5ab819b54f97d458a855cffcc6f Mon Sep 17 00:00:00 2001 From: tarunag10 Date: Tue, 2 Jun 2026 11:28:30 +0100 Subject: [PATCH] Add default Origin for streamable HTTP client --- src/mcp/client/streamable_http.py | 20 +++++++++-- tests/shared/test_streamable_http.py | 54 +++++++++++++++++++++++++++- 2 files changed, 71 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index aa3e50e07e..65888d6da0 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -7,6 +7,7 @@ from collections.abc import AsyncGenerator, Awaitable, Callable from contextlib import asynccontextmanager from dataclasses import dataclass +from urllib.parse import urlsplit import anyio import httpx @@ -50,6 +51,15 @@ MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up +def _get_default_origin(url: str) -> str | None: + parsed_url = urlsplit(url) + if parsed_url.scheme not in {"http", "https"} or not parsed_url.netloc: + return None + + authority = parsed_url.netloc.rsplit("@", 1)[-1] + return f"{parsed_url.scheme}://{authority}" + + class StreamableHTTPError(Exception): """Base exception for StreamableHTTP transport errors.""" @@ -72,13 +82,16 @@ class RequestContext: class StreamableHTTPTransport: """StreamableHTTP client transport implementation.""" - def __init__(self, url: str) -> None: + def __init__(self, url: str, default_origin: str | None = None) -> None: """Initialize the StreamableHTTP transport. Args: url: The endpoint URL. + default_origin: Origin header to include when the caller has not + configured one on the HTTP client. """ self.url = url + self.default_origin = default_origin self.session_id: str | None = None self.protocol_version: str | None = None @@ -92,6 +105,8 @@ def _prepare_headers(self) -> dict[str, str]: "accept": "application/json, text/event-stream", "content-type": "application/json", } + if self.default_origin: + headers["Origin"] = self.default_origin # Add session headers if available if self.session_id: headers[MCP_SESSION_ID] = self.session_id @@ -547,7 +562,8 @@ async def streamable_http_client( # Create default client with recommended MCP timeouts client = create_mcp_http_client() - transport = StreamableHTTPTransport(url) + default_origin = None if "origin" in client.headers else _get_default_origin(url) + transport = StreamableHTTPTransport(url, default_origin=default_origin) logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb61..16890086f2 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -29,7 +29,7 @@ from mcp import MCPError, types from mcp.client.session import ClientSession -from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client +from mcp.client.streamable_http import StreamableHTTPTransport, _get_default_origin, streamable_http_client from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, @@ -767,6 +767,58 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") +def test_get_default_origin_derives_origin_from_url(): + assert _get_default_origin("https://example.com:8443/mcp?token=abc") == "https://example.com:8443" + assert _get_default_origin("http://user:pass@[::1]:8080/mcp") == "http://[::1]:8080" + + +@pytest.mark.anyio +async def test_streamable_http_client_sets_default_origin_on_http_client(): + recorded_headers: list[httpx.Headers] = [] + + def handler(request: httpx.Request) -> httpx.Response: + recorded_headers.append(request.headers) + return httpx.Response(202, request=request) + + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: + async with streamable_http_client("https://mcp.example.com:8443/mcp", http_client=client) as ( + _read_stream, + write_stream, + ): + await write_stream.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))) + with anyio.fail_after(1): + while not recorded_headers: + await anyio.sleep(0.01) + + assert recorded_headers[0]["origin"] == "https://mcp.example.com:8443" + assert "origin" not in client.headers + + +@pytest.mark.anyio +async def test_streamable_http_client_preserves_custom_origin_header(): + recorded_headers: list[httpx.Headers] = [] + + def handler(request: httpx.Request) -> httpx.Response: + recorded_headers.append(request.headers) + return httpx.Response(202, request=request) + + async with httpx.AsyncClient( + headers={"Origin": "https://proxy.example"}, + transport=httpx.MockTransport(handler), + ) as client: + async with streamable_http_client("https://mcp.example.com/mcp", http_client=client) as ( + _read_stream, + write_stream, + ): + await write_stream.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))) + with anyio.fail_after(1): + while not recorded_headers: + await anyio.sleep(0.01) + + assert recorded_headers[0]["origin"] == "https://proxy.example" + assert client.headers["origin"] == "https://proxy.example" + + def test_session_termination(basic_server: None, basic_server_url: str): """Test session termination via DELETE and subsequent request handling.""" response = requests.post(