Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions MCPForUnity/Editor/Services/ToolDiscoveryService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,6 @@ private void EnsurePreferenceInitialized(ToolMetadata metadata)
{
bool defaultValue = metadata.AutoRegister || metadata.IsBuiltIn;
EditorPrefs.SetBool(key, defaultValue);
return;
}

if (metadata.IsBuiltIn && !metadata.AutoRegister)
{
bool currentValue = EditorPrefs.GetBool(key, metadata.AutoRegister);
if (currentValue == metadata.AutoRegister)
{
EditorPrefs.SetBool(key, true);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ public interface IMcpTransportClient
Task<bool> StartAsync();
Task StopAsync();
Task<bool> VerifyAsync();
Task ReregisterToolsAsync();
}
}
24 changes: 14 additions & 10 deletions MCPForUnity/Editor/Services/Transport/TransportManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,6 @@ private IMcpTransportClient GetOrCreateClient(TransportMode mode)
};
}

private IMcpTransportClient GetClient(TransportMode mode)
{
return mode switch
{
TransportMode.Http => _httpClient,
TransportMode.Stdio => _stdioClient,
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
};
}

public async Task<bool> StartAsync(TransportMode mode)
{
IMcpTransportClient client = GetOrCreateClient(mode);
Expand Down Expand Up @@ -163,6 +153,20 @@ public void ForceStop(TransportMode mode)
}
}

/// <summary>
/// Gets the active transport client for the specified mode.
/// Returns null if the client hasn't been created yet.
/// </summary>
public IMcpTransportClient GetClient(TransportMode mode)
{
return mode switch
{
TransportMode.Http => _httpClient,
TransportMode.Stdio => _stdioClient,
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
};
}

private void UpdateState(TransportMode mode, TransportState state)
{
switch (mode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,12 @@ public Task<bool> VerifyAsync()
return Task.FromResult(running);
}

public Task ReregisterToolsAsync()
{
// Stdio transport doesn't support dynamic tool reregistration
// Tools are registered at server startup
return Task.CompletedTask;
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,29 @@ private async Task SendRegisterToolsAsync(CancellationToken token)
McpLog.Info($"[WebSocket] Sent {tools.Count} tools registration", false);
}

public async Task ReregisterToolsAsync()
{
if (!IsConnected || _lifecycleCts == null)
{
McpLog.Warn("[WebSocket] Cannot reregister tools: not connected");
return;
}

try
{
await SendRegisterToolsAsync(_lifecycleCts.Token).ConfigureAwait(false);
McpLog.Info("[WebSocket] Tool reregistration completed", false);
}
catch (System.OperationCanceledException)
{
McpLog.Warn("[WebSocket] Tool reregistration cancelled");
}
catch (System.Exception ex)
{
McpLog.Error($"[WebSocket] Tool reregistration failed: {ex.Message}");
}
}

private async Task HandleExecuteAsync(JObject payload, CancellationToken token)
{
string commandId = payload.Value<string>("id");
Expand Down
56 changes: 53 additions & 3 deletions MCPForUnity/Editor/Windows/Components/Tools/McpToolsSection.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using MCPForUnity.Editor.Constants;
using MCPForUnity.Editor.Helpers;
using MCPForUnity.Editor.Services;
using MCPForUnity.Editor.Services.Transport;
using MCPForUnity.Editor.Tools;
using UnityEditor;
using UnityEditor.UIElements;
using UnityEngine.UIElements;

namespace MCPForUnity.Editor.Windows.Components.Tools
Expand Down Expand Up @@ -228,23 +231,63 @@ private VisualElement CreateToolRow(ToolMetadata tool)
return row;
}

private void HandleToggleChange(ToolMetadata tool, bool enabled, bool updateSummary = true)
private void HandleToggleChange(
ToolMetadata tool,
bool enabled,
bool updateSummary = true,
bool reregisterTools = true)
{
MCPServiceLocator.ToolDiscovery.SetToolEnabled(tool.Name, enabled);

if (updateSummary)
{
UpdateSummary();
}

if (reregisterTools)
{
// Trigger tool reregistration with connected MCP server
ReregisterToolsAsync();
}
}

private void ReregisterToolsAsync()
{
// Fire and forget - don't block UI thread
var transportManager = MCPServiceLocator.TransportManager;
var client = transportManager.GetClient(TransportMode.Http);
if (client == null || !client.IsConnected)
{
return;
}

_ = Task.Run(async () =>
{
try
{
await client.ReregisterToolsAsync().ConfigureAwait(false);
}
catch (Exception ex)
{
McpLog.Warn($"Failed to reregister tools: {ex}");
}
});
}

private void SetAllToolsState(bool enabled)
{
bool hasChanges = false;

foreach (var tool in allTools)
{
if (!toolToggleMap.TryGetValue(tool.Name, out var toggle))
{
MCPServiceLocator.ToolDiscovery.SetToolEnabled(tool.Name, enabled);
bool currentEnabled = MCPServiceLocator.ToolDiscovery.IsToolEnabled(tool.Name);
if (currentEnabled != enabled)
{
MCPServiceLocator.ToolDiscovery.SetToolEnabled(tool.Name, enabled);
hasChanges = true;
}
continue;
}

Expand All @@ -254,10 +297,17 @@ private void SetAllToolsState(bool enabled)
}

toggle.SetValueWithoutNotify(enabled);
HandleToggleChange(tool, enabled, updateSummary: false);
HandleToggleChange(tool, enabled, updateSummary: false, reregisterTools: false);
hasChanges = true;
}

UpdateSummary();

if (hasChanges)
{
// Trigger a single reregistration after bulk change
ReregisterToolsAsync();
}
}

private void UpdateSummary()
Expand Down
56 changes: 49 additions & 7 deletions Server/src/services/custom_tool_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from starlette.requests import Request
from starlette.responses import JSONResponse

from core.config import config
from models.models import MCPResponse, ToolDefinitionModel, ToolParameterModel
from core.logging_decorator import log_execution
from core.telemetry_decorator import telemetry_tool
Expand All @@ -28,6 +29,23 @@
_MAX_POLL_SECONDS = 600


def get_user_id_from_context(ctx: Context) -> str | None:
"""Read user_id from request-scoped context in remote-hosted mode."""
if not config.http_remote_hosted:
return None

get_state = getattr(ctx, "get_state", None)
if not callable(get_state):
return None

try:
Comment on lines +32 to +41
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Broad exception handling around ctx.get_state may hide programming/configuration issues

This helper currently catches all Exception from ctx.get_state("user_id") and returns None, which can mask real context/configuration bugs (e.g., attribute/type errors). Please narrow the except to the specific failure modes you expect from get_state, or at least log the exception (e.g., at debug level), so genuine integration issues remain visible while still tolerating missing state in normal operation.

Suggested implementation:

import logging

from starlette.requests import Request
_MAX_POLL_SECONDS = 600

logger = logging.getLogger(__name__)
def get_user_id_from_context(ctx: Context) -> str | None:
    """Read user_id from request-scoped context in remote-hosted mode."""
    if not config.http_remote_hosted:
        return None

    get_state = getattr(ctx, "get_state", None)
    if not callable(get_state):
        return None

    try:
        user_id = get_state("user_id")
    except KeyError:
        # Expected case: user_id state is not present
        return None
    except Exception as exc:
        # Log unexpected integration/configuration issues while remaining tolerant
        logger.debug("Failed to read 'user_id' from context state", exc_info=exc)
        return None

    return user_id if isinstance(user_id, str) and user_id else None

If ctx.get_state can raise more specific, known exceptions (e.g. a custom context error), you should replace or augment the generic Exception handler with those specific types so that truly unexpected errors can still propagate.

user_id = get_state("user_id")
except Exception:
return None

return user_id if isinstance(user_id, str) and user_id else None


class RegisterToolsPayload(BaseModel):
project_id: str
project_hash: str | None = None
Expand Down Expand Up @@ -84,30 +102,40 @@ async def register_tools(request: Request) -> JSONResponse:
return JSONResponse(response.model_dump())

# --- Public API for MCP tools ---------------------------------------
async def list_registered_tools(self, project_id: str) -> list[ToolDefinitionModel]:
async def list_registered_tools(
self,
project_id: str,
user_id: str | None = None,
) -> list[ToolDefinitionModel]:
legacy = list(self._project_tools.get(project_id, {}).values())
hub_tools = await PluginHub.get_tools_for_project(project_id)
hub_tools = await PluginHub.get_tools_for_project(project_id, user_id=user_id)
return legacy + hub_tools

async def get_tool_definition(self, project_id: str, tool_name: str) -> ToolDefinitionModel | None:
async def get_tool_definition(
self,
project_id: str,
tool_name: str,
user_id: str | None = None,
) -> ToolDefinitionModel | None:
tool = self._project_tools.get(project_id, {}).get(tool_name)
if tool:
return tool
return await PluginHub.get_tool_definition(project_id, tool_name)
return await PluginHub.get_tool_definition(project_id, tool_name, user_id=user_id)

async def execute_tool(
self,
project_id: str,
tool_name: str,
unity_instance: str | None,
params: dict[str, object] | None = None,
user_id: str | None = None,
) -> MCPResponse:
params = params or {}
logger.info(
f"Executing tool '{tool_name}' for project '{project_id}' (instance={unity_instance}) with params: {params}"
)

definition = await self.get_tool_definition(project_id, tool_name)
definition = await self.get_tool_definition(project_id, tool_name, user_id=user_id)
if definition is None:
return MCPResponse(
success=False,
Expand All @@ -119,6 +147,7 @@ async def execute_tool(
unity_instance,
tool_name,
params,
user_id=user_id,
)

if not definition.requires_polling:
Expand All @@ -132,6 +161,7 @@ async def execute_tool(
params,
response,
definition.poll_action or "status",
user_id=user_id,
)
logger.info(f"Tool '{tool_name}' polled response: {result}")
return result
Expand All @@ -156,6 +186,7 @@ async def _poll_until_complete(
initial_params: dict[str, object],
initial_response,
poll_action: str,
user_id: str | None = None,
) -> MCPResponse:
poll_params = dict(initial_params)
poll_params["action"] = poll_action or "status"
Expand All @@ -180,7 +211,11 @@ async def _poll_until_complete(

try:
response = await send_with_unity_instance(
async_send_command_with_retry, unity_instance, tool_name, poll_params
async_send_command_with_retry,
unity_instance,
tool_name,
poll_params,
user_id=user_id,
)
except Exception as exc: # pragma: no cover - network/domain reload variability
logger.debug(f"Polling {tool_name} failed, will retry: {exc}")
Expand Down Expand Up @@ -347,8 +382,15 @@ async def _handler(ctx: Context, **kwargs) -> MCPResponse:
)

params = {k: v for k, v in kwargs.items() if v is not None}
user_id = get_user_id_from_context(ctx)
service = CustomToolService.get_instance()
return await service.execute_tool(project_id, definition.name, unity_instance, params)
return await service.execute_tool(
project_id,
definition.name,
unity_instance,
params,
user_id=user_id,
)

_handler.__name__ = f"custom_tool_{definition.name}"
_handler.__doc__ = definition.description or ""
Expand Down
25 changes: 24 additions & 1 deletion Server/src/services/registry/tool_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
def mcp_for_unity_tool(
name: str | None = None,
description: str | None = None,
unity_target: str | None = "self",
**kwargs
) -> Callable:
"""
Expand All @@ -20,6 +21,10 @@ def mcp_for_unity_tool(
Args:
name: Tool name (defaults to function name)
description: Tool description
unity_target: Visibility target used by middleware filtering.
- "self" (default): tool follows its own enabled state.
- None: server-only tool, always visible in tool listing.
- "<tool_name>": alias tool that follows another Unity tool state.
**kwargs: Additional arguments passed to @mcp.tool()

Example:
Expand All @@ -29,11 +34,29 @@ async def my_custom_tool(ctx: Context, ...):
"""
def decorator(func: Callable) -> Callable:
tool_name = name if name is not None else func.__name__
# Safety guard: unity_target is internal metadata and must never leak into mcp.tool kwargs.
tool_kwargs = dict(kwargs) # Create a copy to avoid side effects
if "unity_target" in tool_kwargs:
del tool_kwargs["unity_target"]

if unity_target is None:
normalized_unity_target: str | None = None
elif isinstance(unity_target, str) and unity_target.strip():
normalized_unity_target = (
tool_name if unity_target == "self" else unity_target.strip()
)
else:
raise ValueError(
f"Invalid unity_target for tool '{tool_name}': {unity_target!r}. "
"Expected None or a non-empty string."
)

_tool_registry.append({
'func': func,
'name': tool_name,
'description': description,
'kwargs': kwargs
'unity_target': normalized_unity_target,
'kwargs': tool_kwargs,
})

return func
Expand Down
Loading