fix: keep mcp dynamic refresh tasks tracked

This commit is contained in:
Pony.Ma 2026-04-28 10:41:28 +08:00 committed by Teknium
parent 02ae152222
commit 1350d12b0b
2 changed files with 177 additions and 20 deletions

View file

@ -706,6 +706,106 @@ class TestMCPServerTask:
asyncio.run(_test())
def test_refresh_tools_deregisters_removed_tools(self):
"""Dynamic refresh removes stale registry entries for deleted tools."""
from tools.registry import ToolRegistry
from tools.mcp_tool import MCPServerTask
mock_registry = ToolRegistry()
server = MCPServerTask("srv")
server._config = {"command": "test"}
server._tools = [_make_mcp_tool("old"), _make_mcp_tool("keep")]
server._registered_tool_names = ["mcp_srv_old", "mcp_srv_keep"]
server.session = MagicMock()
server.session.list_tools = AsyncMock(
return_value=SimpleNamespace(tools=[_make_mcp_tool("keep"), _make_mcp_tool("new")])
)
with patch("tools.registry.registry", mock_registry):
mock_registry.register(
name="mcp_srv_old",
toolset="mcp-srv",
schema={"name": "mcp_srv_old", "description": "Old"},
handler=lambda *_args, **_kwargs: "{}",
)
mock_registry.register(
name="mcp_srv_keep",
toolset="mcp-srv",
schema={"name": "mcp_srv_keep", "description": "Keep"},
handler=lambda *_args, **_kwargs: "{}",
)
asyncio.run(server._refresh_tools())
names = mock_registry.get_all_tool_names()
assert "mcp_srv_old" not in names
assert "mcp_srv_keep" in names
assert "mcp_srv_new" in names
assert set(server._registered_tool_names) == {
"mcp_srv_keep",
"mcp_srv_new",
"mcp_srv_list_resources",
"mcp_srv_read_resource",
"mcp_srv_list_prompts",
"mcp_srv_get_prompt",
}
def test_schedule_tools_refresh_keeps_task_until_done(self):
"""Background refresh tasks are strongly referenced and then discarded."""
from tools.mcp_tool import MCPServerTask
async def _test():
started = asyncio.Event()
finish = asyncio.Event()
server = MCPServerTask("srv")
async def fake_refresh(_server):
started.set()
await finish.wait()
with patch.object(MCPServerTask, "_refresh_tools", new=fake_refresh):
server._schedule_tools_refresh()
await started.wait()
assert len(server._pending_refresh_tasks) == 1
task = next(iter(server._pending_refresh_tasks))
assert not task.done()
finish.set()
await task
await asyncio.sleep(0)
assert server._pending_refresh_tasks == set()
asyncio.run(_test())
def test_shutdown_cancels_pending_refresh_tasks(self):
"""shutdown() cancels in-flight background refresh tasks."""
from tools.mcp_tool import MCPServerTask
async def _test():
started = asyncio.Event()
cancelled = asyncio.Event()
server = MCPServerTask("srv")
async def fake_refresh(_server):
started.set()
try:
await asyncio.sleep(3600)
except asyncio.CancelledError:
cancelled.set()
raise
with patch.object(MCPServerTask, "_refresh_tools", new=fake_refresh):
server._schedule_tools_refresh()
await started.wait()
await server.shutdown()
assert cancelled.is_set()
assert server._pending_refresh_tasks == set()
asyncio.run(_test())
def test_empty_env_gets_safe_defaults(self):
"""Empty env dict gets safe default env vars (PATH, HOME, etc.)."""
from tools.mcp_tool import MCPServerTask
@ -1993,7 +2093,13 @@ try:
except ImportError:
ToolUseContent = _CompatType
from tools.mcp_tool import SamplingHandler, _safe_numeric
from tools.mcp_tool import (
CreateMessageResultWithTools,
SamplingHandler,
SamplingToolsCapability,
ToolUseContent,
_safe_numeric,
)
# ---------------------------------------------------------------------------