NOTE(review): this patch was recovered from a whitespace-collapsed paste.
Line breaks and indentation were re-inferred from Python syntax and checked
against every hunk's line counts. Nesting depth of context lines, the extent
of the `with` blocks in the new tests, and the line splits of the __slots__
tuple are assumptions -- confirm against the target tree, or apply with
`git apply --ignore-whitespace` / `patch -l`.

diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py
index 6e19a901218..5fddf8ec092 100644
--- a/tests/tools/test_mcp_tool.py
+++ b/tests/tools/test_mcp_tool.py
@@ -706,6 +706,106 @@ class TestMCPServerTask:
 
         asyncio.run(_test())
 
+    def test_refresh_tools_deregisters_removed_tools(self):
+        """Dynamic refresh removes stale registry entries for deleted tools."""
+        from tools.registry import ToolRegistry
+        from tools.mcp_tool import MCPServerTask
+
+        mock_registry = ToolRegistry()
+        server = MCPServerTask("srv")
+        server._config = {"command": "test"}
+        server._tools = [_make_mcp_tool("old"), _make_mcp_tool("keep")]
+        server._registered_tool_names = ["mcp_srv_old", "mcp_srv_keep"]
+        server.session = MagicMock()
+        server.session.list_tools = AsyncMock(
+            return_value=SimpleNamespace(tools=[_make_mcp_tool("keep"), _make_mcp_tool("new")])
+        )
+
+        with patch("tools.registry.registry", mock_registry):
+            mock_registry.register(
+                name="mcp_srv_old",
+                toolset="mcp-srv",
+                schema={"name": "mcp_srv_old", "description": "Old"},
+                handler=lambda *_args, **_kwargs: "{}",
+            )
+            mock_registry.register(
+                name="mcp_srv_keep",
+                toolset="mcp-srv",
+                schema={"name": "mcp_srv_keep", "description": "Keep"},
+                handler=lambda *_args, **_kwargs: "{}",
+            )
+
+            asyncio.run(server._refresh_tools())
+
+        names = mock_registry.get_all_tool_names()
+        assert "mcp_srv_old" not in names
+        assert "mcp_srv_keep" in names
+        assert "mcp_srv_new" in names
+        assert set(server._registered_tool_names) == {
+            "mcp_srv_keep",
+            "mcp_srv_new",
+            "mcp_srv_list_resources",
+            "mcp_srv_read_resource",
+            "mcp_srv_list_prompts",
+            "mcp_srv_get_prompt",
+        }
+
+    def test_schedule_tools_refresh_keeps_task_until_done(self):
+        """Background refresh tasks are strongly referenced and then discarded."""
+        from tools.mcp_tool import MCPServerTask
+
+        async def _test():
+            started = asyncio.Event()
+            finish = asyncio.Event()
+            server = MCPServerTask("srv")
+
+            async def fake_refresh(_server):
+                started.set()
+                await finish.wait()
+
+            with patch.object(MCPServerTask, "_refresh_tools", new=fake_refresh):
+                server._schedule_tools_refresh()
+
+                await started.wait()
+                assert len(server._pending_refresh_tasks) == 1
+                task = next(iter(server._pending_refresh_tasks))
+                assert not task.done()
+
+                finish.set()
+                await task
+                await asyncio.sleep(0)
+                assert server._pending_refresh_tasks == set()
+
+        asyncio.run(_test())
+
+    def test_shutdown_cancels_pending_refresh_tasks(self):
+        """shutdown() cancels in-flight background refresh tasks."""
+        from tools.mcp_tool import MCPServerTask
+
+        async def _test():
+            started = asyncio.Event()
+            cancelled = asyncio.Event()
+            server = MCPServerTask("srv")
+
+            async def fake_refresh(_server):
+                started.set()
+                try:
+                    await asyncio.sleep(3600)
+                except asyncio.CancelledError:
+                    cancelled.set()
+                    raise
+
+            with patch.object(MCPServerTask, "_refresh_tools", new=fake_refresh):
+                server._schedule_tools_refresh()
+                await started.wait()
+
+                await server.shutdown()
+
+                assert cancelled.is_set()
+                assert server._pending_refresh_tasks == set()
+
+        asyncio.run(_test())
+
     def test_empty_env_gets_safe_defaults(self):
         """Empty env dict gets safe default env vars (PATH, HOME, etc.)."""
         from tools.mcp_tool import MCPServerTask
@@ -1993,7 +2093,13 @@ try:
 except ImportError:
     ToolUseContent = _CompatType
 
-from tools.mcp_tool import SamplingHandler, _safe_numeric
+from tools.mcp_tool import (
+    CreateMessageResultWithTools,
+    SamplingHandler,
+    SamplingToolsCapability,
+    ToolUseContent,
+    _safe_numeric,
+)
 
 
 # ---------------------------------------------------------------------------
diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py
index 2e056eb91b3..d3868f7c20b 100644
--- a/tools/mcp_tool.py
+++ b/tools/mcp_tool.py
@@ -167,10 +167,22 @@ _MCP_HTTP_AVAILABLE = False
 _MCP_SAMPLING_TYPES = False
 _MCP_NOTIFICATION_TYPES = False
 _MCP_MESSAGE_HANDLER_SUPPORTED = False
+_MCP_NEW_HTTP = False
+streamablehttp_client = None
+streamable_http_client = None
 # Conservative fallback for SDK builds that don't export LATEST_PROTOCOL_VERSION.
 # Streamable HTTP was introduced by 2025-03-26, so this remains valid for the
 # HTTP transport path even on older-but-supported SDK versions.
 LATEST_PROTOCOL_VERSION = "2025-03-26"
+
+
+class _CompatType:
+    """Minimal attribute bag for MCP SDK types missing in older/newer builds."""
+
+    def __init__(self, **kwargs):
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
 try:
     from mcp import ClientSession, StdioServerParameters
     from mcp.client.stdio import stdio_client
@@ -191,20 +203,28 @@ try:
         from mcp.types import LATEST_PROTOCOL_VERSION
     except ImportError:
         logger.debug("mcp.types.LATEST_PROTOCOL_VERSION not available -- using fallback protocol version")
-    # Sampling types -- separated so older SDK versions don't break MCP support
+    # Sampling types -- import individually because SDK names changed across releases.
     try:
-        from mcp.types import (
-            CreateMessageResult,
-            CreateMessageResultWithTools,
-            ErrorData,
-            SamplingCapability,
-            SamplingToolsCapability,
-            TextContent,
-            ToolUseContent,
-        )
+        from mcp.types import CreateMessageResult, ErrorData, SamplingCapability, TextContent
+
+        try:
+            from mcp.types import CreateMessageResultWithTools
+        except ImportError:
+            CreateMessageResultWithTools = _CompatType
+
+        try:
+            from mcp.types import SamplingToolsCapability
+        except ImportError:
+            SamplingToolsCapability = _CompatType
+
+        try:
+            from mcp.types import ToolUseContent
+        except ImportError:
+            ToolUseContent = _CompatType
+
         _MCP_SAMPLING_TYPES = True
     except ImportError:
-        logger.debug("MCP sampling types not available -- sampling disabled")
+        logger.debug("MCP sampling base types not available -- sampling disabled")
     # Notification types for dynamic tool discovery (tools/list_changed)
     try:
         from mcp.types import (
@@ -868,7 +888,7 @@ class MCPServerTask:
         "_task", "_ready", "_shutdown_event", "_reconnect_event",
         "_tools", "_error", "_config",
         "_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
-        "_rpc_lock",
+        "_rpc_lock", "_pending_refresh_tasks",
     )
 
     def __init__(self, name: str):
@@ -895,8 +915,10 @@ class MCPServerTask:
         # list_changed notifications during startup; if the notification
         # handler calls list_tools while a normal tool call is in flight, the
         # stream can wedge and the user-visible tool call times out. Serialize
-        # client-initiated RPCs per server.
+        # client-initiated RPCs per server. The lock is also applied to HTTP
+        # transports for conservative per-server ordering.
         self._rpc_lock = asyncio.Lock()
+        self._pending_refresh_tasks: set[asyncio.Task] = set()
 
     def _is_http(self) -> bool:
         """Check if this server uses HTTP transport."""
@@ -904,6 +926,21 @@ class MCPServerTask:
 
     # ----- Dynamic tool discovery (notifications/tools/list_changed) -----
 
+    async def _refresh_tools_task(self):
+        """Run a dynamic tool refresh and log failures from background tasks."""
+        try:
+            await self._refresh_tools()
+        except asyncio.CancelledError:
+            raise
+        except Exception:
+            logger.exception("MCP server '%s': dynamic tool refresh failed", self.name)
+
+    def _schedule_tools_refresh(self) -> None:
+        """Schedule a background tool refresh and keep it strongly referenced."""
+        task = asyncio.create_task(self._refresh_tools_task())
+        self._pending_refresh_tasks.add(task)
+        task.add_done_callback(self._pending_refresh_tasks.discard)
+
     def _make_message_handler(self):
         """Build a ``message_handler`` callback for ``ClientSession``.
 
@@ -932,7 +969,7 @@ class MCPServerTask:
                             # subsequent tool calls time out. Do the refresh in
                             # a separate task and let the handler return
                             # promptly.
-                            asyncio.create_task(self._refresh_tools())
+                            self._schedule_tools_refresh()
                         case PromptListChangedNotification():
                             logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name)
                         case ResourceListChangedNotification():
@@ -962,11 +999,20 @@ class MCPServerTask:
             tools_result = await self.session.list_tools()
             new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else []
 
-            # 2. Re-register with fresh tool list. Avoid deregistering first:
-            # live agent turns already have tool-call IDs pointing at the
-            # existing handler functions. Replacing entries in-place is enough
-            # for unchanged names and avoids transient "tool not connected" /
-            # stale-handler races during startup notifications.
+            # 2. Deregister only tools absent from the fresh list. Avoid
+            # nuke-and-repave for all names: live agent turns may already
+            # have tool-call IDs pointing at existing handler functions.
+            # Replacing entries in-place (step 3) is enough for unchanged
+            # names and avoids transient "tool not connected" /
+            # stale-handler races during startup notifications. Stale tools
+            # are no longer callable, so only their entries are removed here.
+            stale_tool_names = old_tool_names - {
+                f"mcp_{sanitize_mcp_name_component(self.name)}_"
+                f"{sanitize_mcp_name_component(tool.name)}"
+                for tool in new_mcp_tools
+            }
+            for tool_name in stale_tool_names:
+                registry.deregister(tool_name)
 
             # 3. Re-register with fresh tool list
             self._tools = new_mcp_tools
@@ -1383,6 +1429,11 @@ class MCPServerTask:
                     await self._task
                 except asyncio.CancelledError:
                     pass
+        if self._pending_refresh_tasks:
+            for task in list(self._pending_refresh_tasks):
+                task.cancel()
+            await asyncio.gather(*self._pending_refresh_tasks, return_exceptions=True)
+            self._pending_refresh_tasks.clear()
         for tool_name in list(getattr(self, "_registered_tool_names", [])):
             registry.deregister(tool_name)
         self._registered_tool_names = []