Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/agent-common/agent_common/a2a/structured_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_response_format(
"""Get the appropriate response_format strategy for a model.

Encapsulates the model-specific logic for structured output:
- OpenAI/Azure: ToolStrategy (avoids .parse() API that requires strict tools)
- OpenAI (Azure or direct ChatOpenAI): ToolStrategy (avoids .parse() API that requires strict tools)
- Bedrock without thinking: AutoStrategy
- Bedrock with thinking: None + SubAgentResponseSchema added as a tool
- Others (Gemini, etc.): AutoStrategy
Expand All @@ -98,7 +98,7 @@ def get_response_format(
"""
model_class = model.__class__.__name__

if model_class == "AzureChatOpenAI":
if model_class == "AzureChatOpenAI" or model_class == "ChatOpenAI":
return ToolStrategy(schema=SubAgentResponseSchema)
elif model_class == "ChatBedrockConverse":
if thinking_enabled:
Expand Down
10 changes: 9 additions & 1 deletion packages/agent-common/agent_common/core/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware, HumanInTheLoopMiddleware, ToolRetryMiddleware
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_aws import ChatBedrockConverse
from langchain_aws.middleware.prompt_caching import BedrockPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import ToolMessage
Expand Down Expand Up @@ -241,7 +242,14 @@ def build_common_middleware_stack(
truncate_args_settings=summarization_defaults["truncate_args_settings"],
),
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
BedrockPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
# BedrockPromptCachingMiddleware injects Bedrock-specific cache point hints
# into requests. Only attach it for actual Bedrock models — on OpenAI,
# Gemini, or local models it is at best a no-op and at worst confuses
# the provider with unknown fields.
if isinstance(model, ChatBedrockConverse):
middleware.append(BedrockPromptCachingMiddleware(unsupported_model_behavior="ignore"))
middleware += [
PatchToolCallsMiddleware(),
ToolRetryMiddleware(
max_retries=5,
Expand Down
6 changes: 6 additions & 0 deletions packages/agent-common/agent_common/core/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,12 @@ def create_model(
api_key=api_key,
temperature=0.7,
callbacks=callbacks,
# Force OpenAI's streaming API to include `usage` in the final
# chunk (stream_options={"include_usage": true}) so LangChain
# populates `usage_metadata` on the AIMessage. Without this,
# CostTrackingCallback sees no usage data on plain ChatOpenAI
# endpoints and OpenAI-backed conversations report zero cost.
stream_usage=True,
)
else:
# Lazy import for Azure OpenAI provider
Expand Down
84 changes: 84 additions & 0 deletions packages/agent-common/tests/test_structured_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Tests for get_response_format strategy selection."""

from unittest.mock import MagicMock

from agent_common.a2a.structured_response import (
SubAgentResponseSchema,
get_response_format,
)
from langchain.agents.structured_output import AutoStrategy, ToolStrategy


def _model_with_class(class_name: str) -> MagicMock:
model = MagicMock()
model.__class__ = type(class_name, (), {})
return model


def test_chat_openai_uses_tool_strategy():
    """A plain ChatOpenAI model (real OpenAI endpoint) must get ToolStrategy.

    With AutoStrategy the call would resolve to the Responses API .parse()
    path, which requires every bound tool to be strict; dynamic MCP tools
    are not strict.
    """
    bound_tools: list = []
    openai_model = _model_with_class("ChatOpenAI")

    strategy = get_response_format(openai_model, bound_tools)

    assert isinstance(strategy, ToolStrategy)
    assert strategy.schema is SubAgentResponseSchema
    # The tool list passed in must come back untouched.
    assert bound_tools == []


def test_azure_chat_openai_uses_tool_strategy():
    """AzureChatOpenAI must use ToolStrategy and leave the tool list untouched.

    Mirrors the plain ChatOpenAI test: the strategy wraps
    SubAgentResponseSchema, and no response tool is appended to ``tools``.
    """
    model = _model_with_class("AzureChatOpenAI")
    tools: list = []

    fmt = get_response_format(model, tools)

    assert isinstance(fmt, ToolStrategy)
    assert fmt.schema is SubAgentResponseSchema
    # Same check as the ChatOpenAI test: the caller's list is not mutated.
    assert tools == []


def test_bedrock_without_thinking_uses_auto_strategy():
    """ChatBedrockConverse with thinking disabled gets AutoStrategy; tools stay empty."""
    bound_tools: list = []
    bedrock_model = _model_with_class("ChatBedrockConverse")

    strategy = get_response_format(bedrock_model, bound_tools, thinking_enabled=False)

    assert isinstance(strategy, AutoStrategy)
    assert strategy.schema is SubAgentResponseSchema
    assert bound_tools == []


def test_bedrock_with_thinking_returns_none_and_appends_tool():
    """ChatBedrockConverse with thinking enabled yields no response_format.

    Exactly one tool named ``SubAgentResponseSchema`` is expected to be
    appended to the tool list passed in.
    """
    bound_tools: list = []
    bedrock_model = _model_with_class("ChatBedrockConverse")

    strategy = get_response_format(bedrock_model, bound_tools, thinking_enabled=True)

    assert strategy is None
    assert len(bound_tools) == 1
    assert bound_tools[0].name == "SubAgentResponseSchema"


def test_gemini_returns_none_and_appends_tool():
    """ChatGoogleGenerativeAI yields no response_format.

    Exactly one tool named ``SubAgentResponseSchema`` is expected to be
    appended to the tool list passed in.
    """
    bound_tools: list = []
    gemini_model = _model_with_class("ChatGoogleGenerativeAI")

    strategy = get_response_format(gemini_model, bound_tools)

    assert strategy is None
    assert len(bound_tools) == 1
    assert bound_tools[0].name == "SubAgentResponseSchema"


def test_unknown_model_falls_back_to_auto_strategy():
    """An unrecognised model class name gets AutoStrategy over SubAgentResponseSchema."""
    bound_tools: list = []
    unknown_model = _model_with_class("SomeOtherModel")

    strategy = get_response_format(unknown_model, bound_tools)

    assert isinstance(strategy, AutoStrategy)
    assert strategy.schema is SubAgentResponseSchema
17 changes: 13 additions & 4 deletions packages/orchestrator-agent/app/core/graph_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from langchain.agents import create_agent
from langchain.agents.middleware import HumanInTheLoopMiddleware, ToolRetryMiddleware
from langchain.agents.structured_output import AutoStrategy, ToolStrategy
from langchain_aws import ChatBedrockConverse
from langchain_aws.middleware.prompt_caching import BedrockPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool, StructuredTool
Expand Down Expand Up @@ -406,7 +407,7 @@ def _get_or_create_model(self, model_type: ModelType, thinking_level: Optional[T
self._models[cache_key] = self._create_model(model_type, thinking_level)
return self._models[cache_key]

def _create_middleware_stack(self) -> list[Any]:
def _create_middleware_stack(self, model: BaseChatModel | None = None) -> list[Any]:
"""Create the complete middleware stack for a graph.

Middleware Execution Order (LangChain convention):
Expand Down Expand Up @@ -511,10 +512,17 @@ async def _forward_to_active_subagents(context_id: str, messages: list) -> None:
# confirmation before executing guarded tools (e.g. console_create_bug_report).
hitl_middleware = _create_hitl_middleware()

return [
middleware_stack: list[Any] = [
dynamic_tool_middleware,
storage_paths_middleware,
BedrockPromptCachingMiddleware(),
]
# BedrockPromptCachingMiddleware injects Bedrock-specific cache point
# hints. Only attach it for actual Bedrock models — on OpenAI, Gemini
# or local models it is at best a no-op and at worst confuses the
# provider with unknown request fields.
if isinstance(model, ChatBedrockConverse):
middleware_stack.append(BedrockPromptCachingMiddleware())
middleware_stack += [
steering_middleware,
user_preferences_middleware,
playbook_middleware,
Expand All @@ -527,6 +535,7 @@ async def _forward_to_active_subagents(context_id: str, messages: list) -> None:
self._a2a_middleware,
self._todo_middleware,
]
return middleware_stack

def get_static_tools(self, with_response_tool: bool = False) -> list[BaseTool]:
"""Get static tools for the given model type.
Expand Down Expand Up @@ -620,7 +629,7 @@ def _create_graph(self, model_type: ModelType, thinking_level: Optional[Thinking
requires_response_tool = True
else:
response_format = AutoStrategy(schema=FinalResponseSchema)
middleware = self._create_middleware_stack()
middleware = self._create_middleware_stack(model=model)
static_tools_list = self.get_static_tools(with_response_tool=requires_response_tool)

# Add Google built-in tools for Gemini models
Expand Down
38 changes: 35 additions & 3 deletions packages/orchestrator-agent/tests/test_graph_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
Graph creation with actual models should be tested in integration tests.
"""

from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest
from agent_common.middleware.storage_paths_middleware import StoragePathsInstructionMiddleware
from langchain.agents.middleware import ToolRetryMiddleware
from langchain_aws import ChatBedrockConverse
from langchain_aws.middleware.prompt_caching import BedrockPromptCachingMiddleware

from app.core.graph_factory import GraphFactory
Expand Down Expand Up @@ -111,10 +112,11 @@ class TestMiddlewareStack:
@patch("langgraph.store.postgres.aio.AsyncPostgresStore")
@patch("langgraph_checkpoint_aws.DynamoDBSaver")
def test_middleware_stack_order(self, mock_dynamodb, mock_pg_store, _mock_creds, mock_config):
"""Test that middleware stack is assembled in the correct order."""
"""Test that middleware stack is assembled in the correct order for a Bedrock model."""
factory = GraphFactory(config=mock_config)

stack = factory._create_middleware_stack()
bedrock_model = MagicMock(spec=ChatBedrockConverse)
stack = factory._create_middleware_stack(model=bedrock_model)

# Verify correct order (DynamicTool first, static content before cache point,
# steering after cache, user prefs after steering, playbook after prefs)
Expand All @@ -138,6 +140,36 @@ def test_middleware_stack_order(self, mock_dynamodb, mock_pg_store, _mock_creds,
assert isinstance(stack[12], A2ATaskTrackingMiddleware)
assert isinstance(stack[13], TodoStatusMiddleware)

@patch("app.core.graph_factory._has_aws_credentials", return_value=True)
@patch("langgraph.store.postgres.aio.AsyncPostgresStore")
@patch("langgraph_checkpoint_aws.DynamoDBSaver")
def test_middleware_stack_excludes_bedrock_caching_for_non_bedrock_models(
    self, mock_dynamodb, mock_pg_store, _mock_creds, mock_config
):
    """A model that is not a ChatBedrockConverse must not get BedrockPromptCachingMiddleware."""
    graph_factory = GraphFactory(config=mock_config)

    # A bare Mock stands in for any non-Bedrock model (e.g. OpenAI / Gemini):
    # it does not pass an isinstance check against ChatBedrockConverse.
    generic_model = Mock()
    assembled = graph_factory._create_middleware_stack(model=generic_model)

    assert all(not isinstance(entry, BedrockPromptCachingMiddleware) for entry in assembled)
    # With the cache middleware skipped, the stack is one entry shorter
    # than in the Bedrock case.
    assert len(assembled) == 13

@patch("app.core.graph_factory._has_aws_credentials", return_value=True)
@patch("langgraph.store.postgres.aio.AsyncPostgresStore")
@patch("langgraph_checkpoint_aws.DynamoDBSaver")
def test_middleware_stack_excludes_bedrock_caching_when_model_is_none(
    self, mock_dynamodb, mock_pg_store, _mock_creds, mock_config
):
    """Calling without a model (the default) must not add Bedrock caching either."""
    default_stack = GraphFactory(config=mock_config)._create_middleware_stack()

    bedrock_entries = [
        entry for entry in default_stack if isinstance(entry, BedrockPromptCachingMiddleware)
    ]
    assert not bedrock_entries

@patch("app.core.graph_factory._has_aws_credentials", return_value=True)
@patch("langgraph.store.postgres.aio.AsyncPostgresStore")
@patch("langgraph_checkpoint_aws.DynamoDBSaver")
Expand Down