Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from urllib.parse import urlsplit

import anyio
import httpx
Expand Down Expand Up @@ -50,6 +51,15 @@
MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up


def _get_default_origin(url: str) -> str | None:
parsed_url = urlsplit(url)
if parsed_url.scheme not in {"http", "https"} or not parsed_url.netloc:
return None

authority = parsed_url.netloc.rsplit("@", 1)[-1]
return f"{parsed_url.scheme}://{authority}"


class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""

Expand All @@ -72,13 +82,16 @@ class RequestContext:
class StreamableHTTPTransport:
"""StreamableHTTP client transport implementation."""

def __init__(self, url: str) -> None:
def __init__(self, url: str, default_origin: str | None = None) -> None:
"""Initialize the StreamableHTTP transport.

Args:
url: The endpoint URL.
default_origin: Origin header to include when the caller has not
configured one on the HTTP client.
"""
self.url = url
self.default_origin = default_origin
self.session_id: str | None = None
self.protocol_version: str | None = None

Expand All @@ -92,6 +105,8 @@ def _prepare_headers(self) -> dict[str, str]:
"accept": "application/json, text/event-stream",
"content-type": "application/json",
}
if self.default_origin:
headers["Origin"] = self.default_origin
# Add session headers if available
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
Expand Down Expand Up @@ -547,7 +562,8 @@ async def streamable_http_client(
# Create default client with recommended MCP timeouts
client = create_mcp_http_client()

transport = StreamableHTTPTransport(url)
default_origin = None if "origin" in client.headers else _get_default_origin(url)
transport = StreamableHTTPTransport(url, default_origin=default_origin)

logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

Expand Down
54 changes: 53 additions & 1 deletion tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from mcp import MCPError, types
from mcp.client.session import ClientSession
from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client
from mcp.client.streamable_http import StreamableHTTPTransport, _get_default_origin, streamable_http_client
from mcp.server import Server, ServerRequestContext
from mcp.server.streamable_http import (
MCP_PROTOCOL_VERSION_HEADER,
Expand Down Expand Up @@ -767,6 +767,58 @@ def test_streamable_http_transport_init_validation():
StreamableHTTPServerTransport(mcp_session_id="test\n")


def test_get_default_origin_derives_origin_from_url():
assert _get_default_origin("https://example.com:8443/mcp?token=abc") == "https://example.com:8443"
assert _get_default_origin("http://user:pass@[::1]:8080/mcp") == "http://[::1]:8080"


@pytest.mark.anyio
async def test_streamable_http_client_sets_default_origin_on_http_client():
recorded_headers: list[httpx.Headers] = []

def handler(request: httpx.Request) -> httpx.Response:
recorded_headers.append(request.headers)
return httpx.Response(202, request=request)

async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
async with streamable_http_client("https://mcp.example.com:8443/mcp", http_client=client) as (
_read_stream,
write_stream,
):
await write_stream.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")))
with anyio.fail_after(1):
while not recorded_headers:
await anyio.sleep(0.01)

assert recorded_headers[0]["origin"] == "https://mcp.example.com:8443"
assert "origin" not in client.headers


@pytest.mark.anyio
async def test_streamable_http_client_preserves_custom_origin_header():
recorded_headers: list[httpx.Headers] = []

def handler(request: httpx.Request) -> httpx.Response:
recorded_headers.append(request.headers)
return httpx.Response(202, request=request)

async with httpx.AsyncClient(
headers={"Origin": "https://proxy.example"},
transport=httpx.MockTransport(handler),
) as client:
async with streamable_http_client("https://mcp.example.com/mcp", http_client=client) as (
_read_stream,
write_stream,
):
await write_stream.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")))
with anyio.fail_after(1):
while not recorded_headers:
await anyio.sleep(0.01)

assert recorded_headers[0]["origin"] == "https://proxy.example"
assert client.headers["origin"] == "https://proxy.example"


def test_session_termination(basic_server: None, basic_server_url: str):
"""Test session termination via DELETE and subsequent request handling."""
response = requests.post(
Expand Down
Loading