diff --git a/README.md b/README.md index ce5395b..33f28a8 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ - [Option 2: install via pip](#option-2-install-via-pip) - [Quickstart guide](#quickstart-guide) - [WebUI](#webui) -- [MCP tool wrapper](#mcp-tool-wrapper) +- [Model context protocol (MCP)](#model-context-protocol-mcp) ## Installation @@ -157,7 +157,9 @@ auto-scrolling back, copy everything under `examples/webui/` (including the hidd folder `.config`) into the working directory where `start_webui.py` is located. Now the JS scroller will be injected into the WebUI to enable auto-scrolling. -## MCP tool wrapper +## Model context protocol (MCP) + +### MCP tool wrapper EAA's MCP tool wrapper allows you to convert any tools that are subclasses of `BaseTool` into an MCP tool and launch an MCP server offering these tools. @@ -198,4 +200,37 @@ to activate the environment before launching the tool. Below is an example: } ``` Now the MCP client should be able to run and connect to the MCP server and use the -tool. \ No newline at end of file +tool. + +### Using MCP tools (experimental) + +EAA itself can also use MCP tools. While we still recommend using the built-in +`BaseTool` classes as function-calling tools if possible, using external MCP +tools allows you to extend the agent's capability beyond what's in the built-in tools. + +To use an external MCP tool, first create a config dictionary. This dictionary should +follow the [FastMCP format](https://gofastmcp.com/clients/client#configuration-format), +which is the same format as the `settings.json` files used by many MCP clients such as +Claude, Gemini CLI and Cursor. The dictionary should be wrapped in an `MCPTool` object. +The object should then be passed to the task manager in the same way as other `BaseTool` +objects. + +```python +from eaa.tools.mcp import MCPTool + +config = { + "mcpServers": { + "image_acquisition": { + "command": "python", + "args": ["./image_acquisition_mcp_server.py"] + } + } +} + +mcp_tool = MCPTool(config) +``` + +Known issue(s): +- EAA currently cannot tell if an MCP tool returns an image path, and as such, + routines in task managers that handle images will not work properly. + diff --git a/src/eaa/agents/base.py b/src/eaa/agents/base.py index 2246363..3b0a845 100644 --- a/src/eaa/agents/base.py +++ b/src/eaa/agents/base.py @@ -81,7 +81,6 @@ } """ -import typing from typing import ( Any, Callable, @@ -89,18 +88,17 @@ List, Tuple, Optional, - Literal, - get_type_hints, - get_args + Literal ) -import inspect import json import logging +import asyncio import numpy as np from openai.types.chat import ChatCompletionMessage -from eaa.tools.base import ToolReturnType +from eaa.tools.base import ToolReturnType, generate_openai_tool_schema +from eaa.tools.mcp import MCPTool from eaa.comms import get_api_key from eaa.util import encode_image_base64, get_image_path_from_text @@ -110,15 +108,24 @@ class ToolManager: def __init__(self): - self.tools: List[Dict[str, Any]] = [] + self.function_tools: List[Dict[str, Any]] = [] + self.mcp_tools: List[MCPTool] = [] def get_all_schema(self) -> Dict[str, Any]: - """Get the schema for the tool. + """Get the schema for the tool. MCP tools' schemas are generated + as if they are function tools, so that all tool schemas have the + same format. """ - return [generate_openai_tool_schema(tool["name"], tool["function"]) for tool in self.tools] + schemas = [] + schemas += [generate_openai_tool_schema(tool["name"], tool["function"]) for tool in self.function_tools] + mcp_schemas = [] + for tool in self.mcp_tools: + mcp_schemas += tool.get_all_schema() + schemas += mcp_schemas + return schemas - def add_tool(self, name: str, tool_function: Callable, return_type: ToolReturnType) -> None: - """Add a tool to the tool manager. + def add_function_tool(self, name: str, tool_function: Callable, return_type: ToolReturnType) -> None: + """Add a function tool to the tool manager. Parameters ---------- @@ -130,7 +137,7 @@ def add_tool(self, name: str, tool_function: Callable, return_type: ToolReturnTy return_type : ToolReturnType The type of the return value of the tool. """ - self.tools.append( + self.function_tools.append( { "name": name, "function": tool_function, @@ -138,6 +145,11 @@ def add_tool(self, name: str, tool_function: Callable, return_type: ToolReturnTy "schema": generate_openai_tool_schema(name, tool_function) } ) + + def add_mcp_tool(self, tool: MCPTool): + """Add an MCP tool to the tool manager. + """ + self.mcp_tools.append(tool) def execute_tool( self, @@ -153,30 +165,50 @@ def execute_tool( tool_kwargs : Dict[str, Any] The arguments to be passed to the tool. """ - return self.get_tool_callable(tool_name)(**tool_kwargs) + callable = self.get_tool_callable(tool_name) + if isinstance(callable, MCPTool): + loop = asyncio.get_event_loop() + return loop.run_until_complete(callable.call_tool(tool_name, tool_kwargs)) + else: + return callable(**tool_kwargs) - def get_tool_dict(self, tool_name: str) -> Dict[str, Any]: - """Get the tool dictionary for a given tool name. + def get_tool(self, tool_name: str) -> Dict[str, Any] | MCPTool: + """Get the tool dictionary or MCPTool object for a given tool name. """ - for tool in self.tools: + for tool in self.function_tools: if tool["name"] == tool_name: return tool + for tool in self.mcp_tools: + if tool_name in tool.get_all_tool_names(): + return tool raise ValueError(f"Tool {tool_name} not found.") def get_tool_return_type(self, tool_name: str) -> ToolReturnType: """Get the return type of a tool. """ - return self.get_tool_dict(tool_name)["return_type"] + tool = self.get_tool(tool_name) + if isinstance(tool, MCPTool): + return ToolReturnType.TEXT + else: + return tool["return_type"] - def get_tool_callable(self, tool_name: str) -> Callable: - """Get the callable function for a given tool name. + def get_tool_callable(self, tool_name: str) -> Callable | MCPTool: + """Get the callable function or MCPTool object for a given tool name. """ - return self.get_tool_dict(tool_name)["function"] + tool = self.get_tool(tool_name) + if isinstance(tool, MCPTool): + return tool + else: + return tool["function"] def get_tool_schema(self, tool_name: str) -> Dict[str, Any]: """Get the schema for a given tool name. """ - return self.get_tool_dict(tool_name)["schema"] + tool = self.get_tool(tool_name) + if isinstance(tool, MCPTool): + return tool.get_all_schema() + else: + return tool["schema"] class BaseAgent: @@ -243,7 +275,7 @@ def api_key(self) -> str: def create_client(self) -> Any: raise NotImplementedError - def register_tools(self, tools: List[Dict[str, Any]]) -> None: + def register_function_tools(self, tools: List[Dict[str, Any]]) -> None: """Register tools with the OpenAI-compatible API. Parameters @@ -259,11 +291,21 @@ def register_tools(self, tools: List[Dict[str, Any]]) -> None: ) for tool_dict in tools: - self.tool_manager.add_tool( + self.tool_manager.add_function_tool( name=tool_dict["name"], tool_function=tool_dict["function"], return_type=tool_dict["return_type"] ) + + def register_mcp_tools(self, tools: List[MCPTool]) -> None: + """Register MCP tools with the OpenAI-compatible API. + """ + if not isinstance(tools, List): + raise ValueError( + "tools must be a list of MCPTool objects." + ) + for tool in tools: + self.tool_manager.add_mcp_tool(tool) def receive( self, @@ -599,76 +641,6 @@ def generate_openai_message( return message -def generate_openai_tool_schema(tool_name: str, func: Callable) -> Dict[str, Any]: - """ - Generates an OpenAI-compatible tool schema from a Python function - with type annotations and a docstring. - - Parameters - ---------- - tool_name : str - The name of the tool. - func : Callable - The function to generate the tool schema from. - - Returns - ------- - dict - The OpenAI-compatible tool schema. - """ - sig = inspect.signature(func) - type_hints = get_type_hints(func) - doc = inspect.getdoc(func) or "" - - # JSON schema type mapping - python_type_to_json = { - str: "string", - int: "integer", - float: "number", - bool: "boolean", - list: "array", - tuple: "array", - dict: "object" - } - - def resolve_json_type(py_type): - origin = typing.get_origin(py_type) - args = typing.get_args(py_type) - if origin is list or origin is typing.List: - return { - "type": "array", - "items": {"type": python_type_to_json.get(args[0], "string")} - } - return {"type": python_type_to_json.get(py_type, "string")} - - properties = {} - required = [] - - for name, param in sig.parameters.items(): - if name not in type_hints: - continue - json_type = resolve_json_type(type_hints[name]) - description = f"{name} parameter" - if len(get_args(sig.parameters[name].annotation)) > 0: - description = get_args(sig.parameters[name].annotation)[1] - properties[name] = {**json_type, "description": description} - if param.default == inspect.Parameter.empty: - required.append(name) - - return { - "type": "function", - "function": { - "name": tool_name, - "description": doc, - "parameters": { - "type": "object", - "properties": properties, - "required": required - } - } - } - - def has_tool_call(message: dict | ChatCompletionMessage) -> bool: """Check if the message has a tool call. diff --git a/src/eaa/mcp/server.py b/src/eaa/mcp/server.py index 19b9c87..e552c52 100644 --- a/src/eaa/mcp/server.py +++ b/src/eaa/mcp/server.py @@ -16,8 +16,7 @@ "Install it with: pip install fastmcp" ) -from eaa.tools.base import BaseTool -from eaa.agents.base import generate_openai_tool_schema +from eaa.tools.base import BaseTool, generate_openai_tool_schema logger = logging.getLogger(__name__) diff --git a/src/eaa/task_managers/base.py b/src/eaa/task_managers/base.py index b8e18db..2f401cb 100644 --- a/src/eaa/task_managers/base.py +++ b/src/eaa/task_managers/base.py @@ -9,6 +9,7 @@ from eaa.util import get_timestamp from eaa.tools.base import ToolReturnType from eaa.api.llm_config import LLMConfig, OpenAIConfig, AskSageConfig +from eaa.tools.mcp import MCPTool try: from eaa.agents.asksage import AskSageAgent except ImportError: @@ -122,9 +123,11 @@ def register_tools( ) -> None: if not isinstance(tools, (list, tuple)): tools = [tools] - self.agent.register_tools( - self.create_tool_list(tools) - ) + for tool in tools: + if isinstance(tool, MCPTool): + self.agent.register_mcp_tools([tool]) + else: + self.agent.register_function_tools(self.create_tool_list([tool])) def create_tool_list(self, tools: list[BaseTool]) -> list[dict]: """Create a list of tool dictionaries by concatenating the exposed_tools diff --git a/src/eaa/tools/base.py b/src/eaa/tools/base.py index 2ab69db..8e7e9b7 100644 --- a/src/eaa/tools/base.py +++ b/src/eaa/tools/base.py @@ -1,8 +1,10 @@ -from typing import Optional, Dict, Callable, List, Any +import typing +from typing import Optional, Dict, Callable, List, Any, get_args, get_type_hints import base64 import os import io from enum import StrEnum, auto +import inspect import matplotlib.pyplot as plt import numpy as np @@ -167,3 +169,73 @@ def wrapper(self, *args, **kwargs): ) return return_value return wrapper + + +def generate_openai_tool_schema(tool_name: str, func: Callable) -> Dict[str, Any]: + """ + Generates an OpenAI-compatible tool schema from a Python function + with type annotations and a docstring. + + Parameters + ---------- + tool_name : str + The name of the tool. + func : Callable + The function to generate the tool schema from. + + Returns + ------- + dict + The OpenAI-compatible tool schema. + """ + sig = inspect.signature(func) + type_hints = get_type_hints(func) + doc = inspect.getdoc(func) or "" + + # JSON schema type mapping + python_type_to_json = { + str: "string", + int: "integer", + float: "number", + bool: "boolean", + list: "array", + tuple: "array", + dict: "object" + } + + def resolve_json_type(py_type): + origin = typing.get_origin(py_type) + args = typing.get_args(py_type) + if origin is list or origin is typing.List: + return { + "type": "array", + "items": {"type": python_type_to_json.get(args[0], "string")} + } + return {"type": python_type_to_json.get(py_type, "string")} + + properties = {} + required = [] + + for name, param in sig.parameters.items(): + if name not in type_hints: + continue + json_type = resolve_json_type(type_hints[name]) + description = f"{name} parameter" + if len(get_args(sig.parameters[name].annotation)) > 0: + description = get_args(sig.parameters[name].annotation)[1] + properties[name] = {**json_type, "description": description} + if param.default == inspect.Parameter.empty: + required.append(name) + + return { + "type": "function", + "function": { + "name": tool_name, + "description": doc, + "parameters": { + "type": "object", + "properties": properties, + "required": required + } + } + } diff --git a/src/eaa/tools/mcp.py b/src/eaa/tools/mcp.py new file mode 100644 index 0000000..28c8ae0 --- /dev/null +++ b/src/eaa/tools/mcp.py @@ -0,0 +1,147 @@ +import asyncio + +import fastmcp + +from eaa.tools.base import BaseTool + + +class MCPTool(BaseTool): + + def __init__( + self, + config: dict, + *args, **kwargs + ): + """Initialize an MCP tool. + + Parameters + ---------- + config : dict + A dictionary giving the configurations of one or multiple MCP + servers. The structure of the dictionary should follow the standard + of FastMCP (https://gofastmcp.com/clients/client): + ``` + config = { + "mcpServers": { + "server_name": { + # Remote HTTP/SSE server + "transport": "http", # or "sse" + "url": "https://api.example.com/mcp", + "headers": {"Authorization": "Bearer token"}, + "auth": "oauth" # or bearer token string + }, + "local_server": { + # Local stdio server + "transport": "stdio", + "command": "python", + "args": ["./server.py", "--verbose"], + "env": {"DEBUG": "true"}, + "cwd": "/path/to/server", + } + } + } + ``` + Below is a multi-server example from the FastMCP documentation: + ``` + config = { + "mcpServers": { + "weather": {"url": "https://weather-api.example.com/mcp"}, + "assistant": {"command": "python", "args": ["./assistant_server.py"]} + } + } + ``` + """ + super().__init__(*args, **kwargs) + self.config = config + self._client = None + self._connected = False + self._loop = asyncio.get_event_loop() + + async def _ensure_connected(self): + """Ensure the MCP client is connected.""" + if not self._connected or self._client is None: + await self.connect() + + async def connect(self): + """Connect to the MCP server.""" + if self._client is not None: + await self.disconnect() + + self._client = fastmcp.Client(self.config) + await self._client.__aenter__() + self._connected = True + + async def disconnect(self): + """Disconnect from the MCP server.""" + if self._client is not None and self._connected: + await self._client.__aexit__(None, None, None) + self._connected = False + self._client = None + + async def list_tools(self): + """List the tools available on the MCP server.""" + await self._ensure_connected() + return await self._client.list_tools() + + async def list_resources(self): + """List the resources available on the MCP server.""" + await self._ensure_connected() + return await self._client.list_resources() + + async def call_tool(self, tool_name: str, arguments: dict): + """Call a tool on the MCP server.""" + await self._ensure_connected() + result = await self._client.call_tool(tool_name, arguments) + return result.structured_content["result"] + + def get_all_schema(self): + """Get the function call-like schema for all the tools + available on the MCP server. + """ + tools = self._loop.run_until_complete(self.list_tools()) + schemas = [] + for tool in tools: + schema = { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": { + "type": "object", + "properties": tool.inputSchema["properties"], + "required": tool.inputSchema["required"] + } + } + } + schemas.append(schema) + return schemas + + def get_all_tool_names(self): + """Get the names of all the tools available on the MCP server.""" + tools = self._loop.run_until_complete(self.list_tools()) + return [tool.name for tool in tools] + + async def __aenter__(self): + """Async context manager entry.""" + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.disconnect() + + def __del__(self): + """Cleanup when the object is destroyed.""" + if self._connected and self._client is not None: + # Try to clean up the connection, but don't fail if event loop is closed + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # If loop is running, schedule cleanup + loop.create_task(self.disconnect()) + else: + # If loop is not running, run cleanup synchronously + loop.run_until_complete(self.disconnect()) + except RuntimeError: + # Event loop might be closed, ignore cleanup + pass diff --git a/tests/test_asksage_api.py b/tests/test_asksage_api.py index ee2eb0e..2ffa2d0 100644 --- a/tests/test_asksage_api.py +++ b/tests/test_asksage_api.py @@ -55,7 +55,7 @@ def list_sum(numbers: list[float]) -> float: system_message="You are a helpful assistant." ) - agent.register_tools( + agent.register_function_tools( [ { "name": "list_sum", @@ -104,7 +104,7 @@ def get_image() -> str: system_message="You are a helpful assistant." ) - agent.register_tools( + agent.register_function_tools( [ { "name": "get_image", diff --git a/tests/test_oai_api.py b/tests/test_oai_api.py index a9adc07..d665ee7 100644 --- a/tests/test_oai_api.py +++ b/tests/test_oai_api.py @@ -49,7 +49,7 @@ def list_sum(numbers: list[float]) -> float: system_message="You are a helpful assistant." ) - agent.register_tools( + agent.register_function_tools( [ { "name": "list_sum", @@ -95,7 +95,7 @@ def get_image() -> str: system_message="You are a helpful assistant." ) - agent.register_tools( + agent.register_function_tools( [ { "name": "get_image",