diff --git a/.gitignore b/.gitignore index 5717c73..3ae101e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ .env +node_modules/ +package-lock.json +package.json \ No newline at end of file diff --git a/README.md b/README.md index 3920e4f..a9c0994 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,71 @@ This project uses HashiCorp Cloud Platform (HCP) Vault for secure secrets manage #### WebSocket - `VITE_WEBSOCKET_URL=ws://localhost:8081/ws` +--- +### Configuring External MCP Servers (OPTIONAL) + +CortexON supports integration with external MCP (Model Context Protocol) servers for extended capabilities. Configure these in the `cortex_on/config/external_mcp_servers.json` file. + +#### 1. GitHub Personal Access Token + +1. **Create a GitHub Account** if you don't already have one at [github.com](https://github.com) + +2. **Generate a Personal Access Token (PAT)**: + - Follow the steps as listed here: [Personal Access Token Setup](https://github.com/modelcontextprotocol/servers/tree/main/src/github#setup) + +3. **Add the Token to Your Configuration**: + - Open `cortex_on/config/external_mcp_servers.json` + - Find the GitHub section and replace the empty token: + ```json + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_YourTokenHere" + } + ``` +4. **Sample Queries** + - Update the README.md file in the repository by on branch main. Insert the line "Changed by CortexOn" in the end. Provide the updated file content as the content parameter and set branch as main. + - List the latest commit and in which repo the commit was made by + +#### 2. Google Maps API Key + +1. **Create a Google Cloud Account**: + - Go to [Google Cloud Console](https://console.cloud.google.com/) + - Create an account or sign in with your Google account + +2. **Create a New Project**: + - In the cloud console, click on the project dropdown at the top + - Click "New Project" + - Name it (e.g., "CortexON Maps") + - Click "Create" + +3. 
**Enable the Required APIs**: + - In your project, go to "APIs & Services" → "Library" + - Search for and enable these APIs: + * Maps JavaScript API + * Geocoding API + * Directions API + * Places API + * Distance Matrix API + - You can enable more APIs as per your requirements + +4. **Create an API Key**: + - Go to "APIs & Services" → "Credentials" + - Click "Create Credentials" → "API Key" + - Your new API key will be displayed + +5. **Add the API Key to Your Configuration**: + - Open `cortex_on/config/external_mcp_servers.json` + - Find the Google Maps section and replace the empty key: + ```json + "env": { + "GOOGLE_MAPS_API_KEY": "" + } + ``` +6. **Sample Queries** + - Find the closest pizza shops to \[address] within a 5-mile radius + - Find the shortest driving route that includes the following stops: \[address 1], \[address 2], and \[address 3] + +--- + ### Docker Setup 1. Clone the CortexON repository: diff --git a/cortex_on/Dockerfile b/cortex_on/Dockerfile index 8d7373e..9cbcc4a 100644 --- a/cortex_on/Dockerfile +++ b/cortex_on/Dockerfile @@ -8,6 +8,12 @@ RUN apt-get update && apt-get install -y \ build-essential \ cmake \ g++ \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Install Node.js and npm +RUN curl -fsSL https://deb.nodesource.com/setup_18.x | bash - \ + && apt-get install -y nodejs \ && rm -rf /var/lib/apt/lists/* RUN export PYTHONPATH=/app @@ -18,6 +24,13 @@ RUN uv pip install --system --no-cache-dir -r requirements.txt COPY . . 
+# Set environment variables +ENV PYTHONPATH=/app +ENV ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY} +ENV ANTHROPIC_MODEL_NAME=${ANTHROPIC_MODEL_NAME:-claude-3-sonnet-20240229} + EXPOSE 8081 +EXPOSE 3001 +# Run only the main API - MCP server will be started programmatically CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8081"] \ No newline at end of file diff --git a/cortex_on/README.md b/cortex_on/README.md index 99ae575..b730373 100644 --- a/cortex_on/README.md +++ b/cortex_on/README.md @@ -3,3 +3,6 @@ - configure `.env` (using example `.env.copy`) - either run `python -m src.main` in root folder - or run `uvicorn --reload --access-log --host 0.0.0.0 --port 8001 src.main:app` to use with frontend + + + diff --git a/cortex_on/agents/code_agent.py b/cortex_on/agents/code_agent.py index 43fa047..7ec59e7 100644 --- a/cortex_on/agents/code_agent.py +++ b/cortex_on/agents/code_agent.py @@ -13,6 +13,7 @@ from pydantic import BaseModel, Field from pydantic_ai import Agent, RunContext from pydantic_ai.models.anthropic import AnthropicModel +from pydantic_ai.providers.anthropic import AnthropicProvider # Local application imports from utils.ant_client import get_client @@ -236,10 +237,12 @@ async def send_stream_update(ctx: RunContext[CoderAgentDeps], message: str) -> N stream_output_json = json.dumps(asdict(ctx.deps.stream_output)) logfire.debug("WebSocket message sent: {stream_output_json}", stream_output_json=stream_output_json) -# Initialize the model +# Initialize Anthropic provider with API key +provider = AnthropicProvider(api_key=os.environ.get("ANTHROPIC_API_KEY")) + model = AnthropicModel( model_name=os.environ.get("ANTHROPIC_MODEL_NAME"), - anthropic_client=get_client() + provider=provider ) # Initialize the agent diff --git a/cortex_on/agents/mcp_server.py b/cortex_on/agents/mcp_server.py new file mode 100644 index 0000000..37f8d82 --- /dev/null +++ b/cortex_on/agents/mcp_server.py @@ -0,0 +1,423 @@ +#Standard library imports +import uuid +import 
threading +import os +from typing import List, Optional, Dict, Any, Union, Tuple +import json +from dataclasses import asdict + +#Third party imports +from mcp.server.fastmcp import FastMCP +from fastapi import WebSocket +import logfire + +#Local imports +from utils.stream_response_format import StreamResponse +from agents.planner_agent import planner_agent +from agents.code_agent import coder_agent +from agents.code_agent import coder_agent, CoderAgentDeps +from agents.web_surfer import WebSurfer + +# Server manager to handle multiple MCP servers +class ServerManager: + def __init__(self): + self.servers = {} # Dictionary to track running servers by port + self.default_port = 8002 # default port for main MCP server with agents as a tool + + def start_server(self, port=None, name=None): + """Start an MCP server on the specified port""" + if port is None: + port = self.default_port + + if name is None: + name = f"mcp_server_{port}" + + # Check if server is already running on this port + if port in self.servers and self.servers[port]['running']: + logfire.info(f"MCP server already running on port {port}") + return + + # Configure server for this port + server_instance = FastMCP(name=name, host="0.0.0.0", port=port) + + # Track server in our registry + self.servers[port] = { + 'running': True, + 'name': name, + 'instance': server_instance, + 'thread': None + } + + def run_server(): + logfire.info(f"Starting MCP server '{name}' on port {port}...") + # Configure the server to use the specified port + server_instance.run(transport="sse") + + # Start in a separate thread + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + self.servers[port]['thread'] = thread + logfire.info(f"MCP server thread started for '{name}' on port {port}") + + def get_server(self, port=None): + """Get the server instance for the specified port""" + if port is None: + port = self.default_port + + if port in self.servers: + return self.servers[port]['instance'] + return 
None + +# Initialize the server manager +server_manager = ServerManager() + +def start_mcp_server(port=None, name=None): + """Start an MCP server on the specified port""" + server_manager.start_server(port=port, name=name) + #we can add multiple servers here + +# For backwards compatibility +def start_mcp_server_in_thread(): + """Start the MCP server in a separate thread (legacy function)""" + start_mcp_server() + +def register_tools_for_main_mcp_server(websocket: WebSocket, port=None) -> None: + """ + Dynamically register MCP server tools with the provided WebSocket. + This ensures all tools have access to the active WebSocket connection. + + Args: + websocket: The active WebSocket connection + port: Optional port number to target a specific MCP server + """ + # Get the appropriate server instance + server_instance = server_manager.get_server(port) + if server_instance is None: + logfire.error(f"No MCP server found on port {port or server_manager.default_port}") + return + + # First, unregister existing tools if they exist + tool_names = ["plan_task", "code_task", "web_surf_task", "ask_human", "planner_agent_update"] + for tool_name in tool_names: + if tool_name in server_instance._tool_manager._tools: + del server_instance._tool_manager._tools[tool_name] + + logfire.info("Registering MCP tools with WebSocket connection") + + async def plan_task(task: str) -> str: + """Plans the task and assigns it to the appropriate agents""" + try: + logfire.info(f"Planning task: {task}") + planner_stream_output = StreamResponse( + agent_name="Planner Agent", + instructions=task, + steps=[], + output="", + status_code=0, + message_id=str(uuid.uuid4()) + ) + + await _safe_websocket_send(websocket, planner_stream_output) + + # Update planner stream + planner_stream_output.steps.append("Planning task...") + await _safe_websocket_send(websocket, planner_stream_output) + + # Run planner agent + planner_response = await planner_agent.run(user_prompt=task) + + # Update planner stream 
with results + plan_text = planner_response.data.plan + planner_stream_output.steps.append("Task planned successfully") + planner_stream_output.output = plan_text + planner_stream_output.status_code = 200 + await _safe_websocket_send(websocket, planner_stream_output) + + return f"Task planned successfully\nTask: {plan_text}" + except Exception as e: + error_msg = f"Error planning task: {str(e)}" + logfire.error(error_msg, exc_info=True) + + # Update planner stream with error + if 'planner_stream_output' in locals(): + planner_stream_output.steps.append(f"Planning failed: {str(e)}") + planner_stream_output.status_code = 500 + await _safe_websocket_send(websocket, planner_stream_output) + + return f"Failed to plan task: {error_msg}" + + async def code_task(task: str) -> str: + """Assigns coding tasks to the coder agent""" + try: + logfire.info(f"Assigning coding task: {task}") + # Create a new StreamResponse for Coder Agent + coder_stream_output = StreamResponse( + agent_name="Coder Agent", + instructions=task, + steps=[], + output="", + status_code=0, + message_id=str(uuid.uuid4()) + ) + + await _safe_websocket_send(websocket, coder_stream_output) + + # Create deps with the new stream_output + deps_for_coder_agent = CoderAgentDeps( + websocket=websocket, + stream_output=coder_stream_output + ) + + # Run coder agent + coder_response = await coder_agent.run( + user_prompt=task, + deps=deps_for_coder_agent + ) + + # Extract response data + response_data = coder_response.data.content + + # Update coder_stream_output with coding results + coder_stream_output.output = response_data + coder_stream_output.status_code = 200 + coder_stream_output.steps.append("Coding task completed successfully") + await _safe_websocket_send(websocket, coder_stream_output) + + # Add a reminder in the result message to update the plan using planner_agent_update + response_with_reminder = f"{response_data}\n\nReminder: You must now call planner_agent_update with the completed task description: 
\"{task} (coder_agent)\"" + + return response_with_reminder + except Exception as e: + error_msg = f"Error assigning coding task: {str(e)}" + logfire.error(error_msg, exc_info=True) + + # Update coder_stream_output with error + if 'coder_stream_output' in locals(): + coder_stream_output.steps.append(f"Coding task failed: {str(e)}") + coder_stream_output.status_code = 500 + await _safe_websocket_send(websocket, coder_stream_output) + + return f"Failed to assign coding task: {error_msg}" + + async def web_surf_task(task: str) -> str: + """Assigns web surfing tasks to the web surfer agent""" + try: + logfire.info(f"Assigning web surfing task: {task}") + + # Create a new StreamResponse for WebSurfer + web_surfer_stream_output = StreamResponse( + agent_name="Web Surfer", + instructions=task, + steps=[], + output="", + status_code=0, + live_url=None, + message_id=str(uuid.uuid4()) + ) + + await _safe_websocket_send(websocket, web_surfer_stream_output) + + # Initialize WebSurfer agent + web_surfer_agent = WebSurfer(api_url="http://localhost:8000/api/v1/web/stream") + + # Run WebSurfer with its own stream_output + success, message, messages = await web_surfer_agent.generate_reply( + instruction=task, + websocket=websocket, + stream_output=web_surfer_stream_output + ) + + # Update WebSurfer's stream_output with final result + if success: + web_surfer_stream_output.steps.append("Web search completed successfully") + web_surfer_stream_output.output = message + web_surfer_stream_output.status_code = 200 + + # Add a reminder to update the plan + message_with_reminder = f"{message}\n\nReminder: You must now call planner_agent_update with the completed task description: \"{task} (web_surfer_agent)\"" + else: + web_surfer_stream_output.steps.append(f"Web search completed with issues: {message[:100]}") + web_surfer_stream_output.status_code = 500 + message_with_reminder = message + + await _safe_websocket_send(websocket, web_surfer_stream_output) + + 
web_surfer_stream_output.steps.append(f"WebSurfer completed: {'Success' if success else 'Failed'}") + await _safe_websocket_send(websocket, web_surfer_stream_output) + + return message_with_reminder + except Exception as e: + error_msg = f"Error assigning web surfing task: {str(e)}" + logfire.error(error_msg, exc_info=True) + + # Update WebSurfer's stream_output with error + if 'web_surfer_stream_output' in locals(): + web_surfer_stream_output.steps.append(f"Web search failed: {str(e)}") + web_surfer_stream_output.status_code = 500 + await _safe_websocket_send(websocket, web_surfer_stream_output) + return f"Failed to assign web surfing task: {error_msg}" + + async def planner_agent_update(completed_task: str) -> str: + """ + Updates the todo.md file to mark a task as completed and returns the full updated plan. + """ + try: + logfire.info(f"Updating plan with completed task: {completed_task}") + # Create a new StreamResponse for Planner Agent update + planner_stream_output = StreamResponse( + agent_name="Planner Agent", + instructions=f"Update todo.md to mark as completed: {completed_task}", + steps=[], + output="", + status_code=0, + message_id=str(uuid.uuid4()) + ) + + # Send initial update + await _safe_websocket_send(websocket, planner_stream_output) + + # Directly read and update the todo.md file + base_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + planner_dir = os.path.join(base_dir, "agents", "planner") + todo_path = os.path.join(planner_dir, "todo.md") + + planner_stream_output.steps.append("Reading current todo.md...") + await _safe_websocket_send(websocket, planner_stream_output) + + # Make sure the directory exists + os.makedirs(planner_dir, exist_ok=True) + + try: + # Check if todo.md exists + if not os.path.exists(todo_path): + planner_stream_output.steps.append("No todo.md file found. 
Will create new one after task completion.") + await _safe_websocket_send(websocket, planner_stream_output) + + # We'll directly call planner_agent.run() to create a new plan first + plan_prompt = f"Create a simple task plan based on this completed task: {completed_task}" + plan_response = await planner_agent.run(user_prompt=plan_prompt) + current_content = plan_response.data.plan + else: + # Read existing todo.md + with open(todo_path, "r") as file: + current_content = file.read() + planner_stream_output.steps.append(f"Found existing todo.md ({len(current_content)} bytes)") + await _safe_websocket_send(websocket, planner_stream_output) + + # Now call planner_agent.run() with specific instructions to update the plan + update_prompt = f""" + Here is the current todo.md content: + + {current_content} + + Please update this plan to mark the following task as completed: {completed_task} + Return ONLY the fully updated plan with appropriate tasks marked as [x] instead of [ ]. + """ + + planner_stream_output.steps.append("Asking planner to update the plan...") + await _safe_websocket_send(websocket, planner_stream_output) + + updated_plan_response = await planner_agent.run(user_prompt=update_prompt) + updated_plan = updated_plan_response.data.plan + + # Write the updated plan back to todo.md + with open(todo_path, "w") as file: + file.write(updated_plan) + + planner_stream_output.steps.append("Plan updated successfully") + planner_stream_output.output = updated_plan + planner_stream_output.status_code = 200 + await _safe_websocket_send(websocket, planner_stream_output) + + return updated_plan + + except Exception as e: + error_msg = f"Error during plan update operations: {str(e)}" + logfire.error(error_msg, exc_info=True) + + planner_stream_output.steps.append(f"Plan update failed: {str(e)}") + planner_stream_output.status_code = 500 + await _safe_websocket_send(websocket, planner_stream_output) + + return f"Failed to update the plan: {error_msg}" + + except Exception as 
e: + error_msg = f"Error updating plan: {str(e)}" + logfire.error(error_msg, exc_info=True) + + return f"Failed to update plan: {error_msg}" + + # Helper function for websocket communication + async def _safe_websocket_send(socket: WebSocket, message: Any) -> bool: + """Safely send message through websocket with error handling""" + try: + if socket and socket.client_state.CONNECTED: + await socket.send_text(json.dumps(asdict(message))) + logfire.debug("WebSocket message sent (_safe_websocket_send): {message}", message=message) + return True + return False + except Exception as e: + logfire.error(f"WebSocket send failed: {str(e)}") + return False + + # Now register all the generated tools with the MCP server + tool_definitions = { + "plan_task": (plan_task, "Plans the task and assigns it to the appropriate agents"), + "code_task": (code_task, "Assigns coding tasks to the coder agent"), + "web_surf_task": (web_surf_task, "Assigns web surfing tasks to the web surfer agent"), + "planner_agent_update": (planner_agent_update, "Updates the todo.md file to mark a task as completed") + } + + # Register each tool with the specified server instance + for name, (fn, desc) in tool_definitions.items(): + server_instance._tool_manager.add_tool(fn, name=name, description=desc) + + logfire.info(f"Successfully registered {len(tool_definitions)} tools with the MCP server on port {port or server_manager.default_port}") + + +def get_unique_tool_name(tool_name: str, registered_names: set) -> str: + """Ensure a tool name is unique by adding a suffix if necessary""" + if tool_name not in registered_names: + return tool_name + + # Add numeric suffix to make the name unique + base_name = tool_name + suffix = 1 + while f"{base_name}_{suffix}" in registered_names: + suffix += 1 + return f"{base_name}_{suffix}" + +def check_mcp_server_tools(server, registered_tools: set) -> None: + """Check and fix duplicate tool names in an MCP server""" + try: + # This relies on implementation details of MCP 
Server + if hasattr(server, '_mcp_api') and server._mcp_api: + api = server._mcp_api + + # Check if API has a tool manager + if hasattr(api, '_tool_manager'): + tool_manager = api._tool_manager + + # Check if the tool manager has tools + if hasattr(tool_manager, '_tools') and tool_manager._tools: + # Get a copy of original tool names + original_tools = list(tool_manager._tools.keys()) + for tool_name in original_tools: + # If this tool name conflicts with existing ones + if tool_name in registered_tools: + # Create a unique name + unique_name = get_unique_tool_name(tool_name, registered_tools) + # Get the tool + tool = tool_manager._tools[tool_name] + # Add it with the new name + tool_manager._tools[unique_name] = tool + # Remove the old one + del tool_manager._tools[tool_name] + # Add the new name to the registry + registered_tools.add(unique_name) + logfire.info(f"Renamed tool {tool_name} to {unique_name} to avoid duplicate") + else: + # Add the name to the registry + registered_tools.add(tool_name) + except Exception as e: + logfire.error(f"Error checking MCP server tools: {str(e)}") \ No newline at end of file diff --git a/cortex_on/agents/orchestrator_agent.py b/cortex_on/agents/orchestrator_agent.py index 6b001ba..bd16d1b 100644 --- a/cortex_on/agents/orchestrator_agent.py +++ b/cortex_on/agents/orchestrator_agent.py @@ -1,20 +1,22 @@ +#Standard library imports import os import json -import traceback from typing import List, Optional, Dict, Any, Union, Tuple -from datetime import datetime -from pydantic import BaseModel +import uuid from dataclasses import asdict, dataclass + +#Third party imports import logfire from fastapi import WebSocket from dotenv import load_dotenv from pydantic_ai.models.anthropic import AnthropicModel +from pydantic_ai.providers.anthropic import AnthropicProvider from pydantic_ai import Agent, RunContext -from agents.web_surfer import WebSurfer +from pydantic_ai.mcp import MCPServerHTTP + +#Local imports from 
utils.stream_response_format import StreamResponse -from agents.planner_agent import planner_agent, update_todo_status -from agents.code_agent import coder_agent, CoderAgentDeps -from utils.ant_client import get_client +load_dotenv() @dataclass class orchestrator_deps: @@ -23,6 +25,7 @@ class orchestrator_deps: # Add a collection to track agent-specific streams agent_responses: Optional[List[StreamResponse]] = None + orchestrator_system_prompt = """You are an AI orchestrator that manages a team of agents to solve tasks. You have access to tools for coordinating the agents and managing the task flow. [AGENT CAPABILITIES] @@ -34,6 +37,19 @@ class orchestrator_deps: - Implements technical solutions - Executes code operations +3. External MCP servers: + - Specialized servers for specific tasks like GitHub operations, Google Maps, etc. + - Each server provides its own set of tools that can be accessed with the server name prefix + - For example: github.search_repositories, google-maps.geocode + +[SERVER SELECTION GUIDELINES] +When deciding which service or agent to use: +1. For general code-related tasks: Use coder_agent +2. For general web browsing tasks: Use web_surfer_agent +3. For GitHub operations: Use github.* tools (search repos, manage issues, etc.) +4. For location and maps tasks: Use google-maps.* tools (geocoding, directions, places) +5. You can use multiple services in sequence for complex tasks + [AVAILABLE TOOLS] 1. plan_task(task: str) -> str: - Plans the given task and assigns it to appropriate agents @@ -75,6 +91,27 @@ class orchestrator_deps: - Returns the updated plan with completed tasks marked - Must be called after each agent completes a task +6. server_status_update(server_name: str, status_message: str, progress: float = 0, details: Dict[str, Any] = None) -> str: + - Sends live updates about external server access to the UI + - Use when accessing external APIs or MCP servers (like Google Maps, GitHub, etc.) 
+ - Parameters: + * server_name: Name of the server (e.g., 'google_maps', 'github') + * status_message: Short, descriptive status message + * progress: Progress percentage (0-100) + * details: Optional detailed information + - Send frequent updates during lengthy operations + - Updates the UI in real-time with server interaction progress + - Call this when: + * Starting to access a server + * Making requests to external APIs + * Receiving responses from external systems + * Completing server interactions + - Examples: + * "Connecting to Google Maps API..." + * "Fetching location data for New York..." + * "Processing route information..." + * "Retrieved map data successfully" + [MANDATORY WORKFLOW] 1. On receiving task: IF task involves login/credentials/authentication: @@ -158,194 +195,32 @@ class orchestrator_deps: - Format: "Task description (agent_name)" """ +# Initialize MCP Server +#we can add multiple servers here +#example: +# server1 = MCPServerHTTP(url='http://localhost:8004/sse') +# server2 = MCPServerHTTP(url='http://localhost:8003/sse') +server = MCPServerHTTP(url='http://localhost:8002/sse') + +# Initialize Anthropic provider with API key +provider = AnthropicProvider(api_key=os.environ.get("ANTHROPIC_API_KEY")) + model = AnthropicModel( model_name=os.environ.get("ANTHROPIC_MODEL_NAME"), - anthropic_client=get_client() + provider=provider ) +# Initialize the agent with just the main MCP server for now +# External servers will be added dynamically at runtime orchestrator_agent = Agent( model=model, name="Orchestrator Agent", system_prompt=orchestrator_system_prompt, - deps_type=orchestrator_deps + deps_type=orchestrator_deps, + mcp_servers=[server], # Start with just the main server ) -@orchestrator_agent.tool -async def plan_task(ctx: RunContext[orchestrator_deps], task: str) -> str: - """Plans the task and assigns it to the appropriate agents""" - try: - logfire.info(f"Planning task: {task}") - - # Create a new StreamResponse for Planner Agent - 
planner_stream_output = StreamResponse( - agent_name="Planner Agent", - instructions=task, - steps=[], - output="", - status_code=0 - ) - - # Add to orchestrator's response collection if available - if ctx.deps.agent_responses is not None: - ctx.deps.agent_responses.append(planner_stream_output) - - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - # Update planner stream - planner_stream_output.steps.append("Planning task...") - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - # Run planner agent - planner_response = await planner_agent.run(user_prompt=task) - - # Update planner stream with results - plan_text = planner_response.data.plan - planner_stream_output.steps.append("Task planned successfully") - planner_stream_output.output = plan_text - planner_stream_output.status_code = 200 - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - # Also update orchestrator stream - ctx.deps.stream_output.steps.append("Task planned successfully") - await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output) - - return f"Task planned successfully\nTask: {plan_text}" - except Exception as e: - error_msg = f"Error planning task: {str(e)}" - logfire.error(error_msg, exc_info=True) - - # Update planner stream with error - if planner_stream_output: - planner_stream_output.steps.append(f"Planning failed: {str(e)}") - planner_stream_output.status_code = 500 - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - # Also update orchestrator stream - if ctx.deps.stream_output: - ctx.deps.stream_output.steps.append(f"Planning failed: {str(e)}") - await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output) - - return f"Failed to plan task: {error_msg}" - -@orchestrator_agent.tool -async def coder_task(ctx: RunContext[orchestrator_deps], task: str) -> str: - """Assigns coding tasks to the coder agent""" - try: - logfire.info(f"Assigning coding task: {task}") - - # Create a 
new StreamResponse for Coder Agent - coder_stream_output = StreamResponse( - agent_name="Coder Agent", - instructions=task, - steps=[], - output="", - status_code=0 - ) - - # Add to orchestrator's response collection if available - if ctx.deps.agent_responses is not None: - ctx.deps.agent_responses.append(coder_stream_output) - - # Send initial update for Coder Agent - await _safe_websocket_send(ctx.deps.websocket, coder_stream_output) - - # Create deps with the new stream_output - deps_for_coder_agent = CoderAgentDeps( - websocket=ctx.deps.websocket, - stream_output=coder_stream_output - ) - - # Run coder agent - coder_response = await coder_agent.run( - user_prompt=task, - deps=deps_for_coder_agent - ) - - # Extract response data - response_data = coder_response.data.content - - # Update coder_stream_output with coding results - coder_stream_output.output = response_data - coder_stream_output.status_code = 200 - coder_stream_output.steps.append("Coding task completed successfully") - await _safe_websocket_send(ctx.deps.websocket, coder_stream_output) - - # Add a reminder in the result message to update the plan using planner_agent_update - response_with_reminder = f"{response_data}\n\nReminder: You must now call planner_agent_update with the completed task description: \"{task} (coder_agent)\"" - - return response_with_reminder - except Exception as e: - error_msg = f"Error assigning coding task: {str(e)}" - logfire.error(error_msg, exc_info=True) - - # Update coder_stream_output with error - coder_stream_output.steps.append(f"Coding task failed: {str(e)}") - coder_stream_output.status_code = 500 - await _safe_websocket_send(ctx.deps.websocket, coder_stream_output) - - return f"Failed to assign coding task: {error_msg}" - -@orchestrator_agent.tool -async def web_surfer_task(ctx: RunContext[orchestrator_deps], task: str) -> str: - """Assigns web surfing tasks to the web surfer agent""" - try: - logfire.info(f"Assigning web surfing task: {task}") - - # Create a new 
StreamResponse for WebSurfer - web_surfer_stream_output = StreamResponse( - agent_name="Web Surfer", - instructions=task, - steps=[], - output="", - status_code=0, - live_url=None - ) - - # Add to orchestrator's response collection if available - if ctx.deps.agent_responses is not None: - ctx.deps.agent_responses.append(web_surfer_stream_output) - - await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output) - - # Initialize WebSurfer agent - web_surfer_agent = WebSurfer(api_url="http://localhost:8000/api/v1/web/stream") - - # Run WebSurfer with its own stream_output - success, message, messages = await web_surfer_agent.generate_reply( - instruction=task, - websocket=ctx.deps.websocket, - stream_output=web_surfer_stream_output - ) - - # Update WebSurfer's stream_output with final result - if success: - web_surfer_stream_output.steps.append("Web search completed successfully") - web_surfer_stream_output.output = message - web_surfer_stream_output.status_code = 200 - - # Add a reminder to update the plan - message_with_reminder = f"{message}\n\nReminder: You must now call planner_agent_update with the completed task description: \"{task} (web_surfer_agent)\"" - else: - web_surfer_stream_output.steps.append(f"Web search completed with issues: {message[:100]}") - web_surfer_stream_output.status_code = 500 - message_with_reminder = message - - await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output) - - web_surfer_stream_output.steps.append(f"WebSurfer completed: {'Success' if success else 'Failed'}") - await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output) - - return message_with_reminder - except Exception as e: - error_msg = f"Error assigning web surfing task: {str(e)}" - logfire.error(error_msg, exc_info=True) - - # Update WebSurfer's stream_output with error - web_surfer_stream_output.steps.append(f"Web search failed: {str(e)}") - web_surfer_stream_output.status_code = 500 - await 
_safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output) - return f"Failed to assign web surfing task: {error_msg}" - +# Human Input Tool attached to the orchestrator agent as a tool @orchestrator_agent.tool async def ask_human(ctx: RunContext[orchestrator_deps], question: str) -> str: """Sends a question to the frontend and waits for human input""" @@ -358,7 +233,8 @@ async def ask_human(ctx: RunContext[orchestrator_deps], question: str) -> str: instructions=question, steps=[], output="", - status_code=0 + status_code=0, + message_id=str(uuid.uuid4()) ) # Add to orchestrator's response collection if available @@ -394,121 +270,64 @@ async def ask_human(ctx: RunContext[orchestrator_deps], question: str) -> str: return f"Failed to get human input: {error_msg}" @orchestrator_agent.tool -async def planner_agent_update(ctx: RunContext[orchestrator_deps], completed_task: str) -> str: - """ - Updates the todo.md file to mark a task as completed and returns the full updated plan. +async def server_status_update(ctx: RunContext[orchestrator_deps], server_name: str, status_message: str, progress: float = 0, details: Dict[str, Any] = None) -> str: + """Send status update about an external server to the UI Args: - completed_task: Description of the completed task including which agent performed it - - Returns: - The complete updated todo.md content with tasks marked as completed + server_name: Name of the server being accessed (e.g., 'google_maps', 'github') + status_message: Short status message to display + progress: Progress percentage (0-100) + details: Optional detailed information about the server status """ try: - logfire.info(f"Updating plan with completed task: {completed_task}") - - # Create a new StreamResponse for Planner Agent update - planner_stream_output = StreamResponse( - agent_name="Planner Agent", - instructions=f"Update todo.md to mark as completed: {completed_task}", - steps=[], - output="", - status_code=0 - ) - - # Send initial update - await 
_safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - # Directly read and update the todo.md file - base_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) - planner_dir = os.path.join(base_dir, "agents", "planner") - todo_path = os.path.join(planner_dir, "todo.md") - - planner_stream_output.steps.append("Reading current todo.md...") - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - # Make sure the directory exists - os.makedirs(planner_dir, exist_ok=True) - - try: - # Check if todo.md exists - if not os.path.exists(todo_path): - planner_stream_output.steps.append("No todo.md file found. Will create new one after task completion.") - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - # We'll directly call planner_agent.run() to create a new plan first - plan_prompt = f"Create a simple task plan based on this completed task: {completed_task}" - plan_response = await planner_agent.run(user_prompt=plan_prompt) - current_content = plan_response.data.plan - else: - # Read existing todo.md - with open(todo_path, "r") as file: - current_content = file.read() - planner_stream_output.steps.append(f"Found existing todo.md ({len(current_content)} bytes)") - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - # Now call planner_agent.run() with specific instructions to update the plan - update_prompt = f""" - Here is the current todo.md content: - - {current_content} - - Please update this plan to mark the following task as completed: {completed_task} - Return ONLY the fully updated plan with appropriate tasks marked as [x] instead of [ ]. 
- """ - - planner_stream_output.steps.append("Asking planner to update the plan...") - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - updated_plan_response = await planner_agent.run(user_prompt=update_prompt) - updated_plan = updated_plan_response.data.plan - - # Write the updated plan back to todo.md - with open(todo_path, "w") as file: - file.write(updated_plan) - - planner_stream_output.steps.append("Plan updated successfully") - planner_stream_output.output = updated_plan - planner_stream_output.status_code = 200 - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - # Update orchestrator stream - if ctx.deps.stream_output: - ctx.deps.stream_output.steps.append(f"Plan updated to mark task as completed: {completed_task}") - await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output) - - return updated_plan + if server_name == 'npx': + logfire.info(f"Server Initialisation with npx. No requirement of sending update to UI") + return f"Server Initialisation with npx. 
No requirement of sending update to UI" + + logfire.info(f"Server status update for {server_name}: {status_message}") + if ctx.deps.stream_output is None: + return f"Could not send status update: No stream output available" - except Exception as e: - error_msg = f"Error during plan update operations: {str(e)}" - logfire.error(error_msg, exc_info=True) + # Initialize server_status if needed + if ctx.deps.stream_output.server_status is None: + ctx.deps.stream_output.server_status = {} - planner_stream_output.steps.append(f"Plan update failed: {str(e)}") - planner_stream_output.status_code = a500 - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + # Create status update + status_update = { + "status": status_message, + "progress": progress, + "timestamp": str(uuid.uuid4()) # Generate unique ID for this update + } + + # Add optional details + if details: + status_update["details"] = details - return f"Failed to update the plan: {error_msg}" + # Update stream_output + ctx.deps.stream_output.server_status[server_name] = status_update + ctx.deps.stream_output.steps.append(f"Server update from {server_name}: {status_message}") + + # Send update to WebSocket + success = await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output) + if success: + return f"Successfully sent status update for {server_name}" + else: + return f"Failed to send status update for {server_name}: WebSocket error" + except Exception as e: - error_msg = f"Error updating plan: {str(e)}" + error_msg = f"Error sending server status update: {str(e)}" logfire.error(error_msg, exc_info=True) - - # Update stream output with error - if ctx.deps.stream_output: - ctx.deps.stream_output.steps.append(f"Failed to update plan: {str(e)}") - await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output) - - return f"Failed to update plan: {error_msg}" + return f"Failed to send server status update: {error_msg}" -# Helper function for sending WebSocket messages -async def 
_safe_websocket_send(websocket: Optional[WebSocket], message: Any) -> bool: - """Safely send message through websocket with error handling""" - try: - if websocket and websocket.client_state.CONNECTED: - await websocket.send_text(json.dumps(asdict(message))) - logfire.debug("WebSocket message sent (_safe_websocket_send): {message}", message=message) - return True - return False - except Exception as e: - logfire.error(f"WebSocket send failed: {str(e)}") - return False \ No newline at end of file +async def _safe_websocket_send(socket: WebSocket, message: Any) -> bool: + """Safely send message through websocket with error handling""" + try: + if socket and socket.client_state.CONNECTED: + await socket.send_text(json.dumps(asdict(message))) + logfire.debug("WebSocket message sent (_safe_websocket_send): {message}", message=message) + return True + return False + except Exception as e: + logfire.error(f"WebSocket send failed: {str(e)}") + return False \ No newline at end of file diff --git a/cortex_on/agents/planner_agent.py b/cortex_on/agents/planner_agent.py index 897c22f..27db51e 100644 --- a/cortex_on/agents/planner_agent.py +++ b/cortex_on/agents/planner_agent.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field from pydantic_ai import Agent from pydantic_ai.models.anthropic import AnthropicModel +from pydantic_ai.providers.anthropic import AnthropicProvider # Local application imports from utils.ant_client import get_client @@ -176,9 +177,11 @@ class PlannerResult(BaseModel): plan: str = Field(description="The generated or updated plan in string format - this should be the complete plan text") +provider = AnthropicProvider(api_key=os.environ.get("ANTHROPIC_API_KEY")) + model = AnthropicModel( model_name=os.environ.get("ANTHROPIC_MODEL_NAME"), - anthropic_client=get_client() + provider = provider ) planner_agent = Agent( diff --git a/cortex_on/agents/web_surfer.py b/cortex_on/agents/web_surfer.py index 34e2cfd..38b1e6a 100644 --- 
a/cortex_on/agents/web_surfer.py +++ b/cortex_on/agents/web_surfer.py @@ -11,15 +11,7 @@ from dotenv import load_dotenv from fastapi import WebSocket import logfire -from pydantic_ai.messages import ( - ArgsJson, - ModelRequest, - ModelResponse, - ToolCallPart, - ToolReturnPart, - UserPromptPart, -) - +from pydantic_ai.messages import ModelResponse, ModelRequest, ToolReturnPart # Local application imports from utils.stream_response_format import StreamResponse diff --git a/cortex_on/config/external_mcp_servers.json b/cortex_on/config/external_mcp_servers.json new file mode 100644 index 0000000..f28d952 --- /dev/null +++ b/cortex_on/config/external_mcp_servers.json @@ -0,0 +1,24 @@ +{ + "github": { + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-github" + ], + "env": {}, + "description": "GitHub MCP server for repository operations", + "status": "disabled", + "secret_key": "GITHUB_PERSONAL_ACCESS_TOKEN" + }, + "google-maps": { + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-google-maps" + ], + "env": {}, + "description": "Google Maps MCP server for geocoding, directions, and place search", + "status": "disabled", + "secret_key": "GOOGLE_MAPS_API_KEY" + } +} \ No newline at end of file diff --git a/cortex_on/connect_to_external_server.py b/cortex_on/connect_to_external_server.py new file mode 100644 index 0000000..6c2146b --- /dev/null +++ b/cortex_on/connect_to_external_server.py @@ -0,0 +1,442 @@ +import asyncio +import json +import os +from contextlib import AsyncExitStack +from typing import Any, Dict, List, Optional, Tuple, Union +from datetime import datetime + +import nest_asyncio +from colorama import Fore, Style, init +import logfire +from mcp import ClientSession, StdioServerParameters +from pydantic_ai.models.anthropic import AnthropicModel +from pydantic_ai.providers.anthropic import AnthropicProvider + +init(autoreset=True) # Initialize colorama with autoreset=True + +from dotenv import load_dotenv +from 
pydantic_ai.mcp import MCPServerStdio + +load_dotenv() + +class StdioServerProvider: + """Class for creating and managing MCPServerStdio instances from a JSON configuration file""" + + def __init__(self, config_path: str = 'config/external_mcp_servers.json'): + self.config_path = config_path + self.servers: Dict[str, MCPServerStdio] = {} + self.server_configs: Dict[str, Dict[str, Any]] = {} + self.server_tools: Dict[str, List[Dict[str, Any]]] = {} + self.server_status: Dict[str, Dict[str, Any]] = {} + self.registered_tool_names: List[str] = [] # Track all registered tool names to ensure uniqueness + self.active_servers: List[MCPServerStdio] = [] # Track currently active servers + + async def shutdown_servers(self): + """Properly shut down all active servers""" + for server in self.active_servers: + try: + if hasattr(server, 'close') and callable(server.close): + await server.close() + elif hasattr(server, '__aexit__') and callable(server.__aexit__): + await server.__aexit__(None, None, None) + logfire.info(f"Shut down MCP server: {server}") + except Exception as e: + print(f"Error shutting down server: {str(e)}") + + # Clear the active servers list + self.active_servers = [] + self.servers = {} + self.server_tools = {} + print("All servers have been shut down") + + async def load_servers(self) -> Tuple[List[MCPServerStdio], str]: + """Load server configurations from JSON and create MCPServerStdio instances""" + # First shut down any existing servers + await self.shutdown_servers() + + # Clear registered tool names before loading new servers + self.registered_tool_names = [] + + # Check if config file exists in the specified path or try to find it + if not os.path.exists(self.config_path): + # Try to find it relative to the current file + alt_path = os.path.join(os.path.dirname(__file__), self.config_path) + if os.path.exists(alt_path): + self.config_path = alt_path + else: + # Try to find it in the parent directory + parent_dir = 
os.path.dirname(os.path.dirname(__file__)) + alt_path = os.path.join(parent_dir, self.config_path) + if os.path.exists(alt_path): + self.config_path = alt_path + else: + raise FileNotFoundError(f"Could not find config file: {self.config_path}") + + print(f"Loading stdio server configuration from: {self.config_path}") + + # Load the configuration file + with open(self.config_path, 'r') as f: + self.server_configs = json.load(f) + + # Create MCPServerStdio instances for each server + stdio_servers = [] + server_names = [] + + for server_name, config in self.server_configs.items(): + # Skip servers without a command + if 'command' not in config or not config['command']: + print(f"Skipping {server_name} - no command specified") + continue + + if config['status'] == "disabled": + print(f"Skipping {server_name} - server is disabled") + continue + + command = config['command'] + args = config.get('args', []) + env = config.get('env', {}) + + # Create the MCPServerStdio instance + try: + # Use namespaced tool names by setting the namespace parameter + server = MCPServerStdio( + command, + args=args, + env=env + ) + + self.servers[server_name] = server + stdio_servers.append(server) + server_names.append(server_name) + self.active_servers.append(server) # Track in active servers list + + # Initialize server status + self.server_status[server_name] = { + "status": "initializing", + "last_update": datetime.now().isoformat() + } + + print(f"Created MCPServerStdio for {server_name} with command: {command} {' '.join(args)}") + except Exception as e: + print(f"Error creating MCPServerStdio for {server_name}: {str(e)}") + + # Wait for servers to initialize before attempting to discover tools + await asyncio.sleep(2) + + # Try to discover tools from all servers + for server_name in server_names: + print(f"Attempting to discover tools for {server_name}...") + await self.get_server_tools(server_name) + + # Generate a combined system prompt with server information + system_prompt = 
self._generate_system_prompt(server_names) + + return stdio_servers, system_prompt + + def _generate_system_prompt(self, server_names: List[str]) -> str: + """Generate a system prompt for the agent with information about available servers and their tools""" + if not server_names: + return "You are an AI assistant that can help with various tasks." + + servers_list = ", ".join([f"`{name}`" for name in server_names]) + + prompt = f"""[EXTERNAL SERVER CAPABILITIES] + """ + + # Add details about each server and its tools if available + for server_name in server_names: + config = self.server_configs.get(server_name, {}) + description = config.get('description', f"MCP server for {server_name}") + + prompt += f"""- {server_name}: + Description: {description} + Usage scenarios: + """ + + # Add general usage scenarios based on server description + keywords = server_name.lower().split('-') + if "github" in keywords: + prompt += """ - Repository operations + - Code browsing and analysis + - Issue and PR management + """ + elif "google" in keywords and "maps" in keywords: + prompt += """ - Geocoding and location services + - Directions and routing + - Place search and information + """ + else: + # Generic description based on server name + prompt += f""" - {server_name.replace('-', ' ').title()} operations + """ + + # Add tool information + prompt += " Available tools:\n" + + # Use any tools we've discovered + tools = self.server_tools.get(server_name, []) + if tools: + for tool in tools: + tool_name = tool.get("name", f"{server_name}.unknown_tool") + tool_description = tool.get("description", "No description available") + prompt += f" - {tool_name}: {tool_description}\n" + else: + # If no tools are discovered, provide generic information + prompt += f" - Various tools prefixed with '{server_name}.'\n" + + prompt += """ + [HOW TO USE EXTERNAL SERVERS] + 1. 
When a user's task requires capabilities from an external server: + - Identify which server is most appropriate based on the task description + - Use the server's tools directly with the server name prefix (e.g., {server}.tool_name) + - Include all required parameters for the tool + + 2. Server selection guidelines: + """ + + # Dynamically generate server selection guidelines based on available servers + for server_name in server_names: + server_title = server_name.replace('-', ' ').title() + prompt += f" - For {server_title} operations: Choose the {server_name} server\n" + + prompt += """ + 3. Important notes: + - Always include the server name prefix with the tool name + - Multiple servers can be used in the same task when needed + - Provide detailed parameters based on the specific tool requirements + """ + + return prompt + + async def get_server_tools(self, server_name: str) -> List[Dict[str, Any]]: + """Fetch tool information from a running server by introspecting its capabilities""" + if server_name not in self.servers: + return [] + + try: + server = self.servers[server_name] + tools = [] + + # Access the private _mcp_api property to get tool information + # Note: This is implementation-specific and might need adjustment + # based on the actual MCPServerStdio implementation + if hasattr(server, '_mcp_api') and server._mcp_api: + api = server._mcp_api + + # Check if the API has a tool manager + if hasattr(api, '_tool_manager'): + tool_manager = api._tool_manager + + # Get tools from the tool manager + if hasattr(tool_manager, '_tools') and tool_manager._tools: + for tool_name, tool_info in tool_manager._tools.items(): + # Ensure tool name has server namespace prefix + if not tool_name.startswith(f"{server_name}."): + prefixed_name = f"{server_name}.{tool_name}" + else: + prefixed_name = tool_name + + # Check if this tool name is already registered + if prefixed_name in self.registered_tool_names: + # Make the name unique by adding a suffix + base_name = 
prefixed_name + suffix = 1 + while f"{base_name}_{suffix}" in self.registered_tool_names: + suffix += 1 + prefixed_name = f"{base_name}_{suffix}" + + # Add to the registered names list + self.registered_tool_names.append(prefixed_name) + + # Extract description if available + description = "No description available" + if hasattr(tool_info, 'description') and tool_info.description: + description = tool_info.description + elif hasattr(tool_info, '__doc__') and tool_info.__doc__: + description = tool_info.__doc__.strip() + + tools.append({ + "name": prefixed_name, + "description": description + }) + + # If we couldn't get tools through introspection, try to get schema + if not tools and hasattr(server, 'get_schema'): + try: + # Some MCP servers might have a get_schema method + schema = await server.get_schema() + if schema and 'tools' in schema: + for tool in schema['tools']: + name = tool.get('name', '') + # Ensure tool name has server namespace prefix + if not name.startswith(f"{server_name}."): + prefixed_name = f"{server_name}.{name}" + else: + prefixed_name = name + + # Check if this tool name is already registered + if prefixed_name in self.registered_tool_names: + # Make the name unique by adding a suffix + base_name = prefixed_name + suffix = 1 + while f"{base_name}_{suffix}" in self.registered_tool_names: + suffix += 1 + prefixed_name = f"{base_name}_{suffix}" + + # Add to the registered names list + self.registered_tool_names.append(prefixed_name) + + tools.append({ + "name": prefixed_name, + "description": tool.get('description', f"Tool from {server_name}") + }) + except Exception as schema_err: + print(f"Error getting schema from {server_name}: {str(schema_err)}") + + # Fallback for when we can't directly access the tool information: + # We'll use some typical tools based on server name to provide useful information + if not tools: + fallback_tools = [] + + if server_name == "github": + fallback_tools = [ + {"base_name": "search_repositories", 
"description": "Search for GitHub repositories"}, + {"base_name": "get_repository", "description": "Get details about a specific repository"}, + {"base_name": "list_issues", "description": "List issues for a repository"}, + {"base_name": "create_issue", "description": "Create a new issue in a repository"}, + {"base_name": "search_code", "description": "Search for code within repositories"} + ] + elif server_name == "google-maps": + fallback_tools = [ + {"base_name": "geocode", "description": "Convert addresses to geographic coordinates"}, + {"base_name": "directions", "description": "Get directions between locations"}, + {"base_name": "places", "description": "Search for places near a location"}, + {"base_name": "distance_matrix", "description": "Calculate distance and travel time"} + ] + else: + fallback_tools = [ + {"base_name": "use", "description": f"Use the {server_name} service"} + ] + + # Process the fallback tools with unique naming + for tool_info in fallback_tools: + base_name = tool_info["base_name"] + prefixed_name = f"{server_name}.{base_name}" + + # Check if this tool name is already registered + if prefixed_name in self.registered_tool_names: + # Make the name unique by adding a suffix + suffix = 1 + while f"{prefixed_name}_{suffix}" in self.registered_tool_names: + suffix += 1 + prefixed_name = f"{prefixed_name}_{suffix}" + + # Add to the registered names list + self.registered_tool_names.append(prefixed_name) + + tools.append({ + "name": prefixed_name, + "description": tool_info["description"] + }) + + print(f"Using fallback tool definitions for {server_name} - actual tools couldn't be discovered") + + # Store and return the tools + self.server_tools[server_name] = tools + print(f"Discovered {len(tools)} tools for {server_name}") + return tools + except Exception as e: + print(f"Error discovering tools from {server_name}: {str(e)}") + return [] + + async def monitor_server_status(self, server_name: str, callback: callable) -> None: + """ + Set up 
monitoring for a server's status and call the callback with updates + + Args: + server_name: The name of the server to monitor + callback: An async function to call with status updates (takes server_name and status dict) + """ + if server_name not in self.servers: + return + + try: + # Set initial status + status = { + "status": "monitoring", + "progress": 75, + "last_update": datetime.now().isoformat() + } + + # Call the callback with initial status + try: + await callback(server_name, status) + except Exception as cb_err: + logfire.error(f"Error calling status callback: {str(cb_err)}") + + # Check server health periodically + check_count = 0 + while server_name in self.servers: + check_count += 1 + + # Get the server + server = self.servers[server_name] + is_healthy = False + + # Try to determine if the server is healthy + try: + if hasattr(server, '_mcp_api') and server._mcp_api: + # We'll consider it healthy if it has an API + is_healthy = True + + # Get more detailed health info if available + if hasattr(server._mcp_api, 'health') and callable(server._mcp_api.health): + health_info = await server._mcp_api.health() + if isinstance(health_info, dict): + status.update(health_info) + except Exception: + is_healthy = False + + # Update the status based on health check + if is_healthy: + status = { + "status": "running", + "progress": 100, + "health": "ok", + "last_update": datetime.now().isoformat(), + "check_count": check_count + } + else: + status = { + "status": "degraded", + "progress": 80, + "health": "degraded", + "last_update": datetime.now().isoformat(), + "check_count": check_count + } + + # Call the callback with the status + try: + await callback(server_name, status) + except Exception as cb_err: + logfire.error(f"Error calling status callback: {str(cb_err)}") + + # Wait before the next check + await asyncio.sleep(5) # Check every 5 seconds + + except Exception as e: + logfire.error(f"Error monitoring server {server_name}: {str(e)}") + + # Try to send a 
final error status + try: + status = { + "status": "error", + "progress": 0, + "error": str(e), + "last_update": datetime.now().isoformat() + } + await callback(server_name, status) + except Exception: + pass + +server_provider = StdioServerProvider() \ No newline at end of file diff --git a/cortex_on/instructor.py b/cortex_on/instructor.py index b4f0efb..39ba608 100644 --- a/cortex_on/instructor.py +++ b/cortex_on/instructor.py @@ -2,46 +2,50 @@ import json import os import traceback +import yaml +import subprocess +import asyncio from dataclasses import asdict from datetime import datetime from typing import Any, Dict, List, Optional, Tuple, Union +import uuid # Third-party imports from dotenv import load_dotenv from fastapi import WebSocket import logfire from pydantic import BaseModel -from pydantic_ai import Agent -from pydantic_ai.messages import ModelMessage -from pydantic_ai.models.anthropic import AnthropicModel # Local application imports -from agents.code_agent import coder_agent -from agents.orchestrator_agent import orchestrator_agent, orchestrator_deps -from agents.planner_agent import planner_agent -from agents.web_surfer import WebSurfer -from utils.ant_client import get_client +from agents.orchestrator_agent import orchestrator_agent, orchestrator_deps, orchestrator_system_prompt from utils.stream_response_format import StreamResponse - +from agents.mcp_server import start_mcp_server, register_tools_for_main_mcp_server, server_manager, check_mcp_server_tools +from connect_to_external_server import server_provider +from agents.orchestrator_agent import server as main_server load_dotenv() - - - class DateTimeEncoder(json.JSONEncoder): - """Custom JSON encoder that can handle datetime objects""" + """Custom JSON encoder that can handle datetime objects and Pydantic models""" def default(self, obj): if isinstance(obj, datetime): return obj.isoformat() + if isinstance(obj, BaseModel): + # Handle both Pydantic v1 and v2 + if hasattr(obj, 'model_dump'): + 
return obj.model_dump() + elif hasattr(obj, 'dict'): + return obj.dict() + # Fallback for any other Pydantic structure + return {k: v for k, v in obj.__dict__.items() if not k.startswith('_')} return super().default(obj) - # Main Orchestrator Class class SystemInstructor: def __init__(self): self.websocket: Optional[WebSocket] = None self.stream_output: Optional[StreamResponse] = None self.orchestrator_response: List[StreamResponse] = [] + self.external_servers: Dict[str, Dict[str, Any]] = {} self._setup_logging() def _setup_logging(self) -> None: @@ -63,45 +67,291 @@ async def _safe_websocket_send(self, message: Any) -> bool: except Exception as e: logfire.error(f"WebSocket send failed: {str(e)}") return False + async def send_server_status_update(self, stream_output: StreamResponse, server_name: str, status: Dict[str, Any]) -> bool: + """Send server status update via WebSocket + + Args: + stream_output: The StreamResponse object to update + server_name: Name of the server being accessed + status: Status information to stream + """ + try: + # Ensure we have a server_status dictionary + if not hasattr(stream_output, 'server_status') or stream_output.server_status is None: + stream_output.server_status = {} + + # Add a timestamp to the status update + status_with_timestamp = {**status, "timestamp": datetime.now().isoformat()} + + # Update the status in the stream_output + stream_output.server_status[server_name] = status_with_timestamp + + # Add a step message for non-npx servers or if it's an important status + important_statuses = ["ready", "error", "failed", "connected"] + if server_name != 'npx' or status.get('status', '') in important_statuses: + stream_output.steps.append(f"Server update from {server_name}: {status.get('status', 'processing')}") + + # Make sure the WebSocket is still connected + if self.websocket and self.websocket.client_state.CONNECTED: + # Send the update and retry if needed + max_retries = 3 + for attempt in range(max_retries): + try: + # 
Try to send the message + await self.websocket.send_text(json.dumps(asdict(stream_output))) + logfire.debug(f"Server status update sent for {server_name}: {status.get('status')}") + return True + except Exception as send_err: + if attempt < max_retries - 1: + # Brief wait before retry + await asyncio.sleep(0.1 * (attempt + 1)) + logfire.warning(f"Retrying server status update ({attempt+1}/{max_retries})") + else: + # Last attempt failed + logfire.error(f"Failed to send server status update after {max_retries} attempts: {str(send_err)}") + return False + else: + logfire.warning(f"WebSocket disconnected, couldn't send status update for {server_name}") + return False + + except Exception as e: + logfire.error(f"Failed to send server status update: {str(e)}") + return False + + def _reset_orchestrator_agent(self): + """Reset the orchestrator agent for a new chat session""" + try: + # Keep only the main server (first one) and remove all external servers + if len(orchestrator_agent._mcp_servers) > 1: + main_server = orchestrator_agent._mcp_servers[0] + + # Log all servers that are being removed + for i, server in enumerate(orchestrator_agent._mcp_servers[1:], 1): + server_command = getattr(server, 'command', f'server_{i}') + logfire.info(f"Removing external MCP server: {server.__class__.__name__} with command: {server_command}") + + orchestrator_agent._mcp_servers = [main_server] + logfire.info("Reset orchestrator_agent MCP servers to just the main server") + + # Reset the system prompt to its original state + orchestrator_agent.system_prompt = orchestrator_system_prompt + logfire.info("Reset orchestrator_agent system prompt to default") + + # If there's a tools manager, clear any cache it might have + for server in orchestrator_agent._mcp_servers: + if hasattr(server, '_mcp_api') and server._mcp_api: + api = server._mcp_api + if hasattr(api, '_tool_manager'): + tool_manager = api._tool_manager + if hasattr(tool_manager, '_cached_tool_schemas'): + 
tool_manager._cached_tool_schemas = None + logfire.info(f"Cleared tool schema cache for server {server}") + + # Also clear any cached tools to ensure fresh registration + if hasattr(tool_manager, '_tools'): + # Don't clear main server tools, just log them + tool_count = len(tool_manager._tools) if tool_manager._tools else 0 + logfire.info(f"Server has {tool_count} registered tools") + + except Exception as e: + logfire.error(f"Error resetting orchestrator agent: {str(e)}") + # If reset fails, try more aggressive cleanup + try: + # Force reset to just the main server + if hasattr(orchestrator_agent, '_mcp_servers') and orchestrator_agent._mcp_servers: + orchestrator_agent._mcp_servers = orchestrator_agent._mcp_servers[:1] + logfire.info("Performed aggressive reset - kept only first server") + except Exception as cleanup_err: + logfire.error(f"Aggressive reset also failed: {str(cleanup_err)}") - async def run(self, task: str, websocket: WebSocket) -> List[Dict[str, Any]]: - """Main orchestration loop with comprehensive error handling""" + async def run(self, task: str, websocket: WebSocket, server_config: Optional[Dict[str, int]] = None) -> List[Dict[str, Any]]: + """ + Main orchestration loop with comprehensive error handling + + Args: + task: The task instructions + websocket: The active WebSocket connection + server_config: Optional configuration for MCP servers {name: port} + """ + # Only reset if we have external servers to reset (i.e., this is not the first run) + if len(orchestrator_agent._mcp_servers) > 1: + logfire.info("Resetting orchestrator agent for new chat session (external servers detected)") + self._reset_orchestrator_agent() + else: + logfire.info("First run detected - skipping reset to allow initial server registration") + self.websocket = websocket stream_output = StreamResponse( agent_name="Orchestrator", instructions=task, steps=[], output="", - status_code=0 + status_code=0, + message_id=str(uuid.uuid4()) ) - 
self.orchestrator_response.append(stream_output) - + self.orchestrator_response = [stream_output] # Reset the response list for new chat + # Create dependencies with list to track agent responses deps_for_orchestrator = orchestrator_deps( websocket=self.websocket, stream_output=stream_output, - agent_responses=self.orchestrator_response # Pass reference to collection + agent_responses=self.orchestrator_response ) try: # Initialize system await self._safe_websocket_send(stream_output) - stream_output.steps.append("Agents initialized successfully") - await self._safe_websocket_send(stream_output) - - orchestrator_response = await orchestrator_agent.run( - user_prompt=task, - deps=deps_for_orchestrator - ) - stream_output.output = orchestrator_response.data + + # Use the default port for main MCP server + main_port = server_manager.default_port # This is 8002 + + # Merge default and external server configurations + if server_config is None: + server_config = { + "main": main_port + } + + # Start the main MCP server - already handled by the framework + start_mcp_server(port=main_port, name="main") + register_tools_for_main_mcp_server(websocket=self.websocket, port=main_port) + + # Start each configured external MCP server + servers, system_prompt = await server_provider.load_servers() + + logfire.info(f"Loaded {len(servers)} external servers from server_provider") + for i, server in enumerate(servers): + server_command = getattr(server, 'command', f'unknown_server_{i}') + logfire.info(f" Server {i}: {server.__class__.__name__} with command: {server_command}") + + # Verify we still have the main server after any operations + if not orchestrator_agent._mcp_servers: + logfire.error("No MCP servers found after reset - this should not happen!") + # Re-add the main server if somehow lost + orchestrator_agent._mcp_servers = [main_server] + logfire.info("Re-added main MCP server after unexpected loss") + + logfire.info(f"Starting server registration process. 
Current servers: {len(orchestrator_agent._mcp_servers)}, New external servers to process: {len(servers)}") + + # Send status update for each server being loaded + for i, server in enumerate(servers): + server_name = server.command.split('/')[-1] if hasattr(server, 'command') else f"server_{i}" + await self.send_server_status_update( + stream_output, + server_name, + {"status": "initializing", "progress": i/len(servers)*100} + ) + + # We need to make sure each MCP server has unique tool names + # First, check the main MCP server's tools + registered_tools = set() + main_server = orchestrator_agent._mcp_servers[0] + check_mcp_server_tools(main_server, registered_tools) + + # Check if we already have external servers registered to avoid duplicates + existing_server_commands = set() + existing_server_ids = set() + + for existing_server in orchestrator_agent._mcp_servers[1:]: # Skip main server + if hasattr(existing_server, 'command'): + existing_server_commands.add(existing_server.command) + # Also track server object IDs to prevent adding the exact same object + existing_server_ids.add(id(existing_server)) + + logfire.info(f"Existing server commands: {existing_server_commands}") + logfire.info(f"Servers to register: {[getattr(s, 'command', str(s)) for s in servers]}") + + # Now add each external server only if it's not already registered + servers_added = 0 + for server in servers: + server_command = getattr(server, 'command', str(server)) + server_id = id(server) + + logfire.info(f"Processing server with command: {server_command}, ID: {server_id}") + + # For the first run or when we have new servers, be more permissive + # Only skip if we find an exact command match AND it's the same object ID + should_skip = (server_command in existing_server_commands and + server_id in existing_server_ids) + + if not should_skip: + # Check and deduplicate tools before adding + check_mcp_server_tools(server, registered_tools) + # Adding one at a time after checking + 
orchestrator_agent._mcp_servers.append(server) + existing_server_commands.add(server_command) + existing_server_ids.add(server_id) + servers_added += 1 + logfire.info(f"✓ Added new MCP server: {server.__class__.__name__} with command: {server_command}") + else: + logfire.info(f"✗ Skipped duplicate MCP server with command: {server_command} (exact duplicate found)") + + logfire.info(f"Total MCP servers after registration: {len(orchestrator_agent._mcp_servers)} (added {servers_added} new servers)") + + # Properly integrate external server capabilities into the system prompt + updated_system_prompt = orchestrator_system_prompt + if system_prompt and system_prompt.strip(): + if "[AVAILABLE TOOLS]" in updated_system_prompt: + sections = updated_system_prompt.split("[AVAILABLE TOOLS]") + updated_system_prompt = sections[0] + system_prompt + "\n\n[AVAILABLE TOOLS]" + sections[1] + else: + # If we can't find the section, just append to the end (fallback) + updated_system_prompt = updated_system_prompt + "\n\n" + system_prompt + + orchestrator_agent.system_prompt = updated_system_prompt + logfire.info(f"Updated orchestrator agent with {len(servers)} MCP servers. 
Current MCP servers: {orchestrator_agent._mcp_servers}") + # Configure orchestrator_agent to use all configured MCP servers + logfire.info("Starting to register MCP server tools with Claude") + + # Send another status update before starting MCP servers + for i, server in enumerate(servers): + server_name = server.command.split('/')[-1] if hasattr(server, 'command') else f"server_{i}" + await self.send_server_status_update( + stream_output, + server_name, + {"status": "connecting", "progress": 50 + i/len(servers)*25} + ) + await asyncio.sleep(0.1) # Brief pause to allow updates to be sent + + async with orchestrator_agent.run_mcp_servers(): + # Send status update that servers are ready + for i, server in enumerate(servers): + server_name = server.command.split('/')[-1] if hasattr(server, 'command') else f"server_{i}" + await self.send_server_status_update( + stream_output, + server_name, + {"status": "ready", "progress": 100} + ) + await asyncio.sleep(0.1) # Brief pause to allow updates to be sent + + # Start monitoring this server's status in the background + asyncio.create_task( + server_provider.monitor_server_status( + server_name, + lambda s, status: self.send_server_status_update(stream_output, s, status) + ) + ) + + orchestrator_response = await orchestrator_agent.run( + user_prompt=task, + deps=deps_for_orchestrator + ) + stream_output.output = orchestrator_response.output stream_output.status_code = 200 - logfire.debug(f"Orchestrator response: {orchestrator_response.data}") + logfire.debug(f"Orchestrator response: {orchestrator_response.output}") await self._safe_websocket_send(stream_output) logfire.info("Task completed successfully") return [json.loads(json.dumps(asdict(i), cls=DateTimeEncoder)) for i in self.orchestrator_response] except Exception as e: + if "WebSocketDisconnect" in str(e): + logfire.info("WebSocket disconnected. 
Client likely closed the connection.") + try: + await self.shutdown() + except Exception as shutdown_err: + logfire.error(f"Error during cleanup after disconnect: {shutdown_err}") + return [json.loads(json.dumps(asdict(i), cls=DateTimeEncoder)) for i in self.orchestrator_response] + error_msg = f"Critical orchestration error: {str(e)}\n{traceback.format_exc()}" logfire.error(error_msg) @@ -111,15 +361,25 @@ async def run(self, task: str, websocket: WebSocket) -> List[Dict[str, Any]]: self.orchestrator_response.append(stream_output) await self._safe_websocket_send(stream_output) - # Even in case of critical error, return what we have - return [asdict(i) for i in self.orchestrator_response] + try: + return [json.loads(json.dumps(asdict(i), cls=DateTimeEncoder)) for i in self.orchestrator_response] + except Exception as serialize_error: + logfire.error(f"Failed to serialize response: {str(serialize_error)}") + # Last resort - return a simple error message + return [{"error": error_msg, "status_code": 500}] finally: logfire.info("Orchestration process complete") - # Clear any sensitive data + async def shutdown(self): """Clean shutdown of orchestrator""" try: + # Reset the orchestrator agent + self._reset_orchestrator_agent() + + # Shut down all external MCP servers + await server_provider.shutdown_servers() + # Close websocket if open if self.websocket: await self.websocket.close() diff --git a/cortex_on/main.py b/cortex_on/main.py index a8dd4de..8dc4a62 100644 --- a/cortex_on/main.py +++ b/cortex_on/main.py @@ -1,15 +1,27 @@ # Standard library imports from typing import List, Optional +import json # Third-party imports -from fastapi import FastAPI, WebSocket +from fastapi import FastAPI, WebSocket, HTTPException +from fastapi.middleware.cors import CORSMiddleware # Local application imports from instructor import SystemInstructor +from utils.models import MCPRequest app: FastAPI = FastAPI() +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + 
allow_origins=["http://localhost:3000"], # Frontend URL + allow_credentials=True, + allow_methods=["*"], # Allows all methods + allow_headers=["*"], # Allows all headers +) + async def generate_response(task: str, websocket: Optional[WebSocket] = None): orchestrator: SystemInstructor = SystemInstructor() return await orchestrator.run(task, websocket) @@ -25,3 +37,80 @@ async def websocket_endpoint(websocket: WebSocket): while True: data = await websocket.receive_text() await generate_response(data, websocket) + +@app.get("/agent/mcp/servers") +async def get_mcp_servers(): + with open("config/external_mcp_servers.json", "r") as f: + servers = json.load(f) + + servers_list = [] + for server in servers: + servers_list.append({ + "name": server, + "description": servers[server]["description"], + "status": servers[server]["status"] + }) + return servers_list + +@app.get("/agent/mcp/servers/{server_name}") +async def get_mcp_server(server_name: str): + with open("config/external_mcp_servers.json", "r") as f: + servers = json.load(f) + + if server_name not in servers: + raise HTTPException(status_code=404, detail="Server not found") + + config = { + 'command': servers[server_name]['command'], + 'args': servers[server_name]['args'] + # 'env': servers[server_name]['env'] + } if servers[server_name]['status'] == 'enabled' else {} + + return { + 'name': server_name, + 'status': servers[server_name]['status'], + 'description': servers[server_name]['description'], + 'config': config + } + +@app.post("/agent/mcp/servers") +async def configure_mcp_server(mcp_request: MCPRequest): + with open("config/external_mcp_servers.json", "r") as f: + servers = json.load(f) + + if mcp_request.server_name not in servers: + raise HTTPException(status_code=404, detail="Server not found") + + if not mcp_request.server_secret: + raise HTTPException(status_code=400, detail=f"Server secret is required to enable {mcp_request.server_name}") + + if mcp_request.action == 'enable': + if 
servers[mcp_request.server_name]['status'] == 'enabled': + raise HTTPException(status_code=400, detail=f"{mcp_request.server_name} is already enabled") + servers[mcp_request.server_name]['status'] = 'enabled' + server_secret_key = servers[mcp_request.server_name]['secret_key'] + servers[mcp_request.server_name]['env'][server_secret_key] = mcp_request.server_secret + + elif mcp_request.action == 'disable': + if servers[mcp_request.server_name]['status'] == 'disabled': + raise HTTPException(status_code=400, detail=f"{mcp_request.server_name} is already disabled") + servers[mcp_request.server_name]['status'] = 'disabled' + servers[mcp_request.server_name]['env'] = {} + + + with open("config/external_mcp_servers.json", "w") as f: + json.dump(servers, f, indent=4) + + config = { + 'command': servers[mcp_request.server_name]['command'], + 'args': servers[mcp_request.server_name]['args'] + # 'env': servers[server_name]['env'] + } if servers[mcp_request.server_name]['status'] == 'enabled' else {} + + return { + 'name': mcp_request.server_name, + 'status': servers[mcp_request.server_name]['status'], + 'description': servers[mcp_request.server_name]['description'], + 'config': config + } + diff --git a/cortex_on/requirements.txt b/cortex_on/requirements.txt index a635c7e..5bc13a4 100644 --- a/cortex_on/requirements.txt +++ b/cortex_on/requirements.txt @@ -2,7 +2,7 @@ aiohappyeyeballs==2.4.4 aiohttp==3.11.11 aiosignal==1.3.2 annotated-types==0.7.0 -anthropic==0.42.0 +anthropic==0.49.0 anyio==4.7.0 asyncio-atexit==1.0.1 attrs==24.3.0 @@ -25,7 +25,7 @@ frozenlist==1.5.0 google-auth==2.37.0 googleapis-common-protos==1.66.0 griffe==1.5.4 -groq==0.13.1 +groq==0.15.0 h11==0.14.0 httpcore==1.0.7 httpx==0.27.2 @@ -44,7 +44,7 @@ mistralai==1.2.5 multidict==6.1.0 mypy-extensions==1.0.0 numpy==2.2.1 -openai==1.58.1 +openai==1.74.0 opentelemetry-api==1.29.0 opentelemetry-exporter-otlp-proto-common==1.29.0 opentelemetry-exporter-otlp-proto-http==1.29.0 @@ -65,9 +65,10 @@
pyasn1_modules==0.4.1 pycparser==2.22 pycryptodome==3.21.0 pydantic==2.10.4 -pydantic-ai==0.0.17 -pydantic-ai-slim==0.0.17 +pydantic-ai==0.1.2 +pydantic-ai-slim==0.1.2 pydantic_core==2.27.2 +mcp==1.6.0 Pygments==2.18.0 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 @@ -92,4 +93,5 @@ XlsxWriter==3.2.0 yarl==1.18.3 zipp==3.21.0 fast-graphrag==0.0.4 -llama_parse==0.5.19 \ No newline at end of file +llama_parse==0.5.19 +mcp-server-time \ No newline at end of file diff --git a/cortex_on/utils/models.py b/cortex_on/utils/models.py index 2666c64..e8e55c3 100644 --- a/cortex_on/utils/models.py +++ b/cortex_on/utils/models.py @@ -1,5 +1,6 @@ from pydantic import BaseModel from typing import Dict, Optional +from enum import Enum class FactModel(BaseModel): facts: str @@ -21,3 +22,12 @@ class LedgerModel(BaseModel): is_progress_being_made: LedgerAnswer next_speaker: LedgerAnswer instruction_or_question: LedgerAnswer + +class Action(str, Enum): + enable = "enable" + disable = "disable" + +class MCPRequest(BaseModel): + server_name: str + action: Action + server_secret: str diff --git a/cortex_on/utils/stream_response_format.py b/cortex_on/utils/stream_response_format.py index d99ac9a..6d02040 100644 --- a/cortex_on/utils/stream_response_format.py +++ b/cortex_on/utils/stream_response_format.py @@ -1,5 +1,6 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Dict, Any +import uuid @dataclass class StreamResponse: @@ -9,3 +10,5 @@ class StreamResponse: status_code: int output: str live_url: Optional[str] = None + message_id: str = "" # Unique identifier for each message + server_status: Optional[Dict[str, Any]] = None # Status updates from external servers diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index ed4c83e..8589064 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -2,6 +2,7 @@ import {BrowserRouter, Route, Routes} from "react-router-dom";
import Home from "./pages/Home"; import {Layout} from "./pages/Layout"; import {Vault} from "./pages/Vault"; +import {MCP} from "./pages/MCP"; function App() { return ( @@ -10,6 +11,7 @@ function App() { }> } /> } /> + } /> diff --git a/frontend/src/components/home/ChatList.tsx b/frontend/src/components/home/ChatList.tsx index 9d8cb21..0ab5b18 100644 --- a/frontend/src/components/home/ChatList.tsx +++ b/frontend/src/components/home/ChatList.tsx @@ -130,7 +130,7 @@ const ChatList = ({isLoading, setIsLoading}: ChatListPageProps) => { setIsLoading(true); const lastMessageData = lastMessage.data || []; - const {agent_name, instructions, steps, output, status_code, live_url} = + const {agent_name, instructions, steps, output, status_code, live_url, message_id} = lastJsonMessage as SystemMessage; console.log(lastJsonMessage); @@ -146,12 +146,14 @@ const ChatList = ({isLoading, setIsLoading}: ChatListPageProps) => { setLiveUrl(""); } - const agentIndex = lastMessageData.findIndex( - (agent: SystemMessage) => agent.agent_name === agent_name + // Check if we already have a message with this message_id + const existingMessageIndex = lastMessageData.findIndex( + (msg: SystemMessage) => msg.message_id === message_id ); let updatedLastMessageData; - if (agentIndex !== -1) { + if (existingMessageIndex !== -1 && message_id) { + // If we have a message_id and it already exists, update that specific message let filteredSteps = steps; if (agent_name === "Web Surfer") { const plannerStep = steps.find((step) => step.startsWith("Plan")); @@ -163,15 +165,17 @@ const ChatList = ({isLoading, setIsLoading}: ChatListPageProps) => { : steps.filter((step) => step.startsWith("Current")); } updatedLastMessageData = [...lastMessageData]; - updatedLastMessageData[agentIndex] = { + updatedLastMessageData[existingMessageIndex] = { agent_name, instructions, steps: filteredSteps, output, status_code, live_url, + message_id }; } else { + // If message_id doesn't exist or we don't have a message_id, 
add a new entry updatedLastMessageData = [ ...lastMessageData, { @@ -181,6 +185,7 @@ const ChatList = ({isLoading, setIsLoading}: ChatListPageProps) => { output, status_code, live_url, + message_id: message_id || `${agent_name}-${Date.now()}` // Fallback unique ID if message_id isn't provided }, ]; } @@ -197,19 +202,31 @@ const ChatList = ({isLoading, setIsLoading}: ChatListPageProps) => { if (status_code === 200) { setOutputsList((prevList) => { - const existingIndex = prevList.findIndex( - (item) => item.agent === agent_name - ); - + // If we have a message_id, look for an exact match + // If not found, create a new entry rather than updating by agent name + const indexByMessageId = message_id + ? prevList.findIndex(item => item.id === message_id) + : -1; + let newList; let newOutputIndex; - if (existingIndex >= 0) { + if (indexByMessageId >= 0) { + // Update the existing entry with matching message_id newList = [...prevList]; - newList[existingIndex] = {agent: agent_name, output}; - newOutputIndex = existingIndex; + newList[indexByMessageId] = { + agent: agent_name, + output, + id: message_id + }; + newOutputIndex = indexByMessageId; } else { - newList = [...prevList, {agent: agent_name, output}]; + // Create a new entry with this output + newList = [...prevList, { + agent: agent_name, + output, + id: message_id || `${agent_name}-${Date.now()}` + }]; newOutputIndex = newList.length - 1; } @@ -580,7 +597,7 @@ const ChatList = ({isLoading, setIsLoading}: ChatListPageProps) => { systemMessage.agent_name === "Orchestrator" ? (
{systemMessage.steps && @@ -609,7 +626,7 @@ const ChatList = ({isLoading, setIsLoading}: ChatListPageProps) => { ) : systemMessage.agent_name === "Human Input" ? (
@@ -761,7 +778,7 @@ const ChatList = ({isLoading, setIsLoading}: ChatListPageProps) => {
) : ( @@ -814,15 +831,16 @@ const ChatList = ({isLoading, setIsLoading}: ChatListPageProps) => { (systemMessage.agent_name !== "Web Surfer" && systemMessage.agent_name !== "Human Input" ? (
- handleOutputSelection( - outputsList.findIndex( - (item) => - item.agent === - systemMessage.agent_name - ) - ) - } + onClick={() => { + // First try to find by message_id, then fall back to agent name + const outputIndex = systemMessage.message_id + ? outputsList.findIndex(item => item.id === systemMessage.message_id) + : outputsList.findIndex(item => item.agent === systemMessage.agent_name); + + if (outputIndex >= 0) { + handleOutputSelection(outputIndex); + } + }} className="rounded-md w- py-2 px-4 bg-secondary text-secondary-foreground flex items-center justify-between cursor-pointer transition-all hover:shadow-md hover:scale-102 duration-300 animate-pulse-once" > {getAgentOutputCard( @@ -852,43 +870,43 @@ const ChatList = ({isLoading, setIsLoading}: ChatListPageProps) => { systemMessage?.output ) && (
- {message.data.find( - (systemMessage) => - systemMessage.agent_name === "Orchestrator" - )?.status_code === 200 ? ( -
- handleOutputSelection( - outputsList.findIndex( - (item) => item.agent === "Orchestrator" - ) - ) - } - className="rounded-md py-2 bg-[#F7E8FA] text-[#BD24CA] cursor-pointer transition-all hover:shadow-md hover:scale-102 duration-300 animate-pulse-once" - > -
- -

- Task has been completed. Click here to - view results. -

- + {(() => { + const orchestratorMessage = message.data.find( + (systemMessage) => + systemMessage.agent_name === "Orchestrator" + ); + return orchestratorMessage?.status_code === 200 ? ( +
{ + // First try to find by message_id, then fall back to agent name + const outputIndex = orchestratorMessage?.message_id + ? outputsList.findIndex(item => item.id === orchestratorMessage.message_id) + : outputsList.findIndex(item => item.agent === "Orchestrator"); + + if (outputIndex >= 0) { + handleOutputSelection(outputIndex); + } + }} + className="rounded-md py-2 bg-[#F7E8FA] text-[#BD24CA] cursor-pointer transition-all hover:shadow-md hover:scale-102 duration-300 animate-pulse-once" + > +
+ +

+ Task has been completed. Click here to + view results. +

+ +
-
- ) : ( - - systemMessage.agent_name === - "Orchestrator" - )?.output - } - /> - )} + ) : ( + + ); + })()}
)}
diff --git a/frontend/src/components/home/Header.tsx b/frontend/src/components/home/Header.tsx index d6217fd..db548ab 100644 --- a/frontend/src/components/home/Header.tsx +++ b/frontend/src/components/home/Header.tsx @@ -21,7 +21,7 @@ const Header = () => { > Logo
-
+
nav("/vault")} className={`w-[10%] h-full flex justify-center items-center cursor-pointer border-b-2 hover:border-[#BD24CA] ${ @@ -32,6 +32,17 @@ const Header = () => { >

Vault

+ +
nav("/mcp")} + className={`w-[10%] h-full flex justify-center items-center cursor-pointer border-b-2 hover:border-[#BD24CA] ${ + location.includes("/mcp") + ? "border-[#BD24CA]" + : "border-background" + }`} + > +

MCP

+
+ ); + })} +
+ ); +}; + +export default Sidebar; \ No newline at end of file diff --git a/frontend/src/components/mcp/services/ServiceView.tsx b/frontend/src/components/mcp/services/ServiceView.tsx new file mode 100644 index 0000000..436e563 --- /dev/null +++ b/frontend/src/components/mcp/services/ServiceView.tsx @@ -0,0 +1,156 @@ +import React from 'react'; +import { Input } from "@/components/ui/input"; +import { Button } from "@/components/ui/button"; +import DefaultView from "@/components/mcp/DefaultView"; +import { useState, useEffect } from "react"; +import { Loader2 } from "lucide-react"; + +interface ServiceViewProps { + service: string | null; +} + +const ServiceView: React.FC = ({ service }) => { + const [token, setToken] = useState(""); + const [isLoading, setIsLoading] = useState(false); + const [feedback, setFeedback] = useState<{ + status: "success" | "error"; + message: string; + } | null>(null); + + // Clear feedback message after 2 seconds whenever it changes + useEffect(() => { + if (feedback) { + const timer = setTimeout(() => { + setFeedback(null); + }, 2000); + + // Cleanup timeout on component unmount or when feedback changes + return () => clearTimeout(timer); + } + }, [feedback]); + + const handleToggle = async (action: "enable" | "disable") => { + if (!token.trim()) { + setFeedback({ status: "error", message: "Please enter a token." }); + return; + } + setIsLoading(true); + setFeedback(null); + try { + const response = await fetch( + "http://localhost:8081/agent/mcp/servers", + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + server_name: service, + server_secret: token, + action, + }), + } + ); + if (response.ok) { + setFeedback({ + status: "success", + message: + action === "enable" + ? 
`${service.charAt(0).toUpperCase() + service.slice(1)} token enabled successfully!` + : `${service.charAt(0).toUpperCase() + service.slice(1)} token disabled successfully!`, + }); + } else { + setFeedback({ + status: "error", + message: `Failed to update ${service.charAt(0).toUpperCase() + service.slice(1)} token. Please try again.`, + }); + } + } catch { + setFeedback({ + status: "error", + message: "Network error. Please try again.", + }); + } finally { + setIsLoading(false); + } + }; + + if (!service) { + return ; + } + + return ( +
+
+

{service.charAt(0).toUpperCase() + service.slice(1)} MCP

+

+ Configure your {service.charAt(0).toUpperCase() + service.slice(1)} Personal Access Token to enable {service.charAt(0).toUpperCase() + service.slice(1)} mcp integration +

+
+ +
+ setToken(e.target.value)} + /> +
+ + +
+
+ + {/* Feedback message box */} + {feedback && feedback.message && ( +
+

+ {feedback.message} +

+
+ )} +
+ ); +}; + +export default ServiceView; \ No newline at end of file diff --git a/frontend/src/pages/MCP.tsx b/frontend/src/pages/MCP.tsx new file mode 100644 index 0000000..5ce5d75 --- /dev/null +++ b/frontend/src/pages/MCP.tsx @@ -0,0 +1,21 @@ +import { ScrollArea } from "@/components/ui/scroll-area"; +import Sidebar from "@/components/mcp/Sidebar"; +import ServiceView from "@/components/mcp/services/ServiceView"; +import { useState } from "react"; + +export const MCP = () => { + const [selectedService, setSelectedService] = useState(null); + + return ( +
+ +
+ + + +
+
+ ); +}; + +export default MCP; \ No newline at end of file diff --git a/frontend/src/types/chatTypes.ts b/frontend/src/types/chatTypes.ts index a314798..56acfe7 100644 --- a/frontend/src/types/chatTypes.ts +++ b/frontend/src/types/chatTypes.ts @@ -15,6 +15,7 @@ export interface SystemMessage { output: string; status_code: number; live_url: string; + message_id?: string; } export interface Message { @@ -27,4 +28,5 @@ export interface Message { export interface AgentOutput { agent: string; output: string; + id?: string; } diff --git a/ta-browser/core/orchestrator.py b/ta-browser/core/orchestrator.py index 9dbf16e..fc6a3b2 100644 --- a/ta-browser/core/orchestrator.py +++ b/ta-browser/core/orchestrator.py @@ -674,11 +674,11 @@ async def run(self, command): return await self.handle_context_limit_error() await self.handle_agent_error('planner', e) - self.log_token_usage( - agent_type='planner', - usage=planner_response._usage, - step=self.iteration_counter - ) + # self.log_token_usage( + # agent_type='planner', + # usage=planner_response.usage, + # step=self.iteration_counter + # ) browser_error = None tool_interactions_str = None @@ -717,11 +717,11 @@ async def run(self, command): logger.info(f"Browser Agent Response: {browser_response.data}") await self._update_current_url() - self.log_token_usage( - agent_type='browser', - usage=browser_response._usage, - step=self.iteration_counter - ) + # self.log_token_usage( + # agent_type='browser', + # usage=browser_response.usage, + # step=self.iteration_counter + # ) except BrowserNavigationError as e: # Immediately terminate the task with error details @@ -778,11 +778,11 @@ async def run(self, command): logger.info(f"Critique Response: {critique_data.final_response}") logger.info(f"Critique Terminate: {critique_data.terminate}") - self.log_token_usage( - agent_type='critique', - usage=critique_response._usage, - step=self.iteration_counter - ) + # self.log_token_usage( + # agent_type='critique', + # 
usage=critique_response.usage, + # step=self.iteration_counter + # ) if critique_data.terminate: # Generate final_response if missing diff --git a/ta-browser/core/skills/final_response.py b/ta-browser/core/skills/final_response.py index fe3e3a9..c658b8d 100644 --- a/ta-browser/core/skills/final_response.py +++ b/ta-browser/core/skills/final_response.py @@ -50,14 +50,14 @@ def get_final_response_provider(): from core.utils.anthropic_client import get_client as get_anthropic_client from pydantic_ai.models.anthropic import AnthropicModel client = get_anthropic_client() - model = AnthropicModel(model_name=model_name, anthropic_client=client) + model = AnthropicModel(model_name=model_name, provider = "anthropic") provider = "anthropic" else: # OpenAI provider (default) from core.utils.openai_client import get_client as get_openai_client from pydantic_ai.models.openai import OpenAIModel client = get_openai_client() - model = OpenAIModel(model_name=model_name, openai_client=client) + model = OpenAIModel(model_name=model_name, provider = "openai") provider = "openai" return provider, client, model diff --git a/ta-browser/core/utils/init_client.py b/ta-browser/core/utils/init_client.py index 7d170c6..d33fa37 100644 --- a/ta-browser/core/utils/init_client.py +++ b/ta-browser/core/utils/init_client.py @@ -34,7 +34,7 @@ async def initialize_client(): # Create model instance from pydantic_ai.models.anthropic import AnthropicModel - model_instance = AnthropicModel(model_name=model_name, anthropic_client=client_instance) + model_instance = AnthropicModel(model_name=model_name, provider = "anthropic") logger.info(f"Anthropic client initialized successfully with model: {model_name}") return client_instance, model_instance diff --git a/ta-browser/requirements.txt b/ta-browser/requirements.txt index af8c9b0..7d51ddb 100644 --- a/ta-browser/requirements.txt +++ b/ta-browser/requirements.txt @@ -6,7 +6,7 @@ aiosignal==1.3.2 aiosmtplib==3.0.2 alembic==1.14.1 annotated-types==0.7.0 
-anthropic==0.42.0 +anthropic==0.49.0 anyio==4.8.0 asgiref==3.8.1 asyncpg==0.30.0 @@ -41,7 +41,7 @@ google-auth==2.37.0 googleapis-common-protos==1.66.0 greenlet==3.0.3 griffe==1.5.4 -groq==0.13.1 +groq==0.15.0 grpcio==1.67.0 grpcio-status==1.62.3 h11==0.14.0 @@ -69,7 +69,7 @@ mypy-extensions==1.0.0 nest-asyncio==1.6.0 nltk==3.8.1 numpy==1.26.4 -openai==1.59.3 +openai==1.74.0 opentelemetry-api==1.29.0 opentelemetry-exporter-otlp-proto-common==1.29.0 opentelemetry-exporter-otlp-proto-http==1.29.0 @@ -96,8 +96,8 @@ pyautogen==0.2.27 pycparser==2.22 pycryptodome==3.20.0 pydantic==2.10.4 -pydantic-ai==0.0.17 -pydantic-ai-slim==0.0.17 +pydantic-ai==0.1.0 +pydantic-ai-slim==0.1.0 pydantic-core==2.27.2 pyee==11.1.0 pygments==2.18.0 @@ -137,7 +137,6 @@ typing-inspect==0.9.0 uritemplate==4.1.1 urllib3==2.3.0 uvicorn==0.30.3 -uvloop==0.21.0 watchfiles==0.24.0 websockets==13.1 wrapt==1.17.0