diff --git a/gateway/run.py b/gateway/run.py index e18f891cf..7a750a2c8 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -5468,15 +5468,25 @@ class GatewayRunner: _loop_for_step = asyncio.get_event_loop() _hooks_ref = self.hooks - def _step_callback_sync(iteration: int, tool_names: list) -> None: + def _step_callback_sync(iteration: int, prev_tools: list) -> None: try: + # prev_tools may be list[str] or list[dict] with "name"/"result" + # keys. Normalise to keep "tool_names" backward-compatible for + # user-authored hooks that do ', '.join(tool_names)'. + _names: list[str] = [] + for _t in (prev_tools or []): + if isinstance(_t, dict): + _names.append(_t.get("name") or "") + else: + _names.append(str(_t)) asyncio.run_coroutine_threadsafe( _hooks_ref.emit("agent:step", { "platform": source.platform.value if source.platform else "", "user_id": source.user_id, "session_id": session_id, "iteration": iteration, - "tool_names": tool_names, + "tool_names": _names, + "tools": prev_tools, }), _loop_for_step, ) diff --git a/tests/acp/test_events.py b/tests/acp/test_events.py index 400ea88e0..f34f1ff17 100644 --- a/tests/acp/test_events.py +++ b/tests/acp/test_events.py @@ -205,6 +205,47 @@ class TestStepCallback: assert "read_file" not in tool_call_ids mock_rcts.assert_called_once() + def test_result_passed_to_build_tool_complete(self, mock_conn, event_loop_fixture): + """Tool result from prev_tools dict is forwarded to build_tool_complete.""" + from collections import deque + + tool_call_ids = {"terminal": deque(["tc-xyz789"])} + loop = event_loop_fixture + + cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids) + + with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \ + patch("acp_adapter.events.build_tool_complete") as mock_btc: + future = MagicMock(spec=Future) + future.result.return_value = None + mock_rcts.return_value = future + + # Provide a result string in the tool info dict + cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}]) + + mock_btc.assert_called_once_with( + "tc-xyz789", "terminal", result='{"output": "hello"}' + ) + + def test_none_result_passed_through(self, mock_conn, event_loop_fixture): + """When result is None (e.g. first iteration), None is passed through.""" + from collections import deque + + tool_call_ids = {"web_search": deque(["tc-aaa"])} + loop = event_loop_fixture + + cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids) + + with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \ + patch("acp_adapter.events.build_tool_complete") as mock_btc: + future = MagicMock(spec=Future) + future.result.return_value = None + mock_rcts.return_value = future + + cb(1, [{"name": "web_search", "result": None}]) + + mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None) + # --------------------------------------------------------------------------- # Message callback diff --git a/tests/acp/test_server.py b/tests/acp/test_server.py index fc6d53dd8..9edc66e93 100644 --- a/tests/acp/test_server.py +++ b/tests/acp/test_server.py @@ -505,3 +505,179 @@ class TestSlashCommands: assert state.agent.provider == "anthropic" assert state.agent.base_url == "https://anthropic.example/v1" assert runtime_calls[-1] == "anthropic" + + +# --------------------------------------------------------------------------- +# _register_session_mcp_servers +# --------------------------------------------------------------------------- + + +class TestRegisterSessionMcpServers: + """Tests for ACP MCP server registration in session lifecycle.""" + + @pytest.mark.asyncio + async def test_noop_when_no_servers(self, agent, mock_manager): + """No-op when mcp_servers is None or empty.""" + state = mock_manager.create_session(cwd="/tmp") + # Should not raise + await agent._register_session_mcp_servers(state, None) + await agent._register_session_mcp_servers(state, []) + + @pytest.mark.asyncio + async def test_registers_stdio_servers(self, agent, mock_manager): + """McpServerStdio servers are converted and passed to register_mcp_servers.""" + from acp.schema import McpServerStdio, EnvVariable + + state = mock_manager.create_session(cwd="/tmp") + # Give the mock agent the attributes _register_session_mcp_servers reads + state.agent.enabled_toolsets = ["hermes-acp"] + state.agent.disabled_toolsets = None + state.agent.tools = [] + state.agent.valid_tool_names = set() + + server = McpServerStdio( + name="test-server", + command="/usr/bin/test", + args=["--flag"], + env=[EnvVariable(name="KEY", value="val")], + ) + + registered_config = {} + def capture_register(config_map): + registered_config.update(config_map) + return ["mcp_test_server_tool1"] + + with patch("tools.mcp_tool.register_mcp_servers", side_effect=capture_register), \ + patch("model_tools.get_tool_definitions", return_value=[]): + await agent._register_session_mcp_servers(state, [server]) + + assert "test-server" in registered_config + cfg = registered_config["test-server"] + assert cfg["command"] == "/usr/bin/test" + assert cfg["args"] == ["--flag"] + assert cfg["env"] == {"KEY": "val"} + + @pytest.mark.asyncio + async def test_registers_http_servers(self, agent, mock_manager): + """McpServerHttp servers are converted correctly.""" + from acp.schema import McpServerHttp, HttpHeader + + state = mock_manager.create_session(cwd="/tmp") + state.agent.enabled_toolsets = ["hermes-acp"] + state.agent.disabled_toolsets = None + state.agent.tools = [] + state.agent.valid_tool_names = set() + + server = McpServerHttp( + name="http-server", + url="https://api.example.com/mcp", + headers=[HttpHeader(name="Authorization", value="Bearer tok")], + ) + + registered_config = {} + def capture_register(config_map): + registered_config.update(config_map) + return [] + + with patch("tools.mcp_tool.register_mcp_servers", side_effect=capture_register), \ + patch("model_tools.get_tool_definitions", return_value=[]): + await agent._register_session_mcp_servers(state, [server]) + + assert "http-server" in registered_config + cfg = registered_config["http-server"] + assert cfg["url"] == "https://api.example.com/mcp" + assert cfg["headers"] == {"Authorization": "Bearer tok"} + + @pytest.mark.asyncio + async def test_refreshes_agent_tool_surface(self, agent, mock_manager): + """After MCP registration, agent.tools and valid_tool_names are refreshed.""" + from acp.schema import McpServerStdio + + state = mock_manager.create_session(cwd="/tmp") + state.agent.enabled_toolsets = ["hermes-acp"] + state.agent.disabled_toolsets = None + state.agent.tools = [] + state.agent.valid_tool_names = set() + state.agent._cached_system_prompt = "old prompt" + + server = McpServerStdio( + name="srv", + command="/bin/test", + args=[], + env=[], + ) + + fake_tools = [ + {"function": {"name": "mcp_srv_search"}}, + {"function": {"name": "terminal"}}, + ] + + with patch("tools.mcp_tool.register_mcp_servers", return_value=["mcp_srv_search"]), \ + patch("model_tools.get_tool_definitions", return_value=fake_tools): + await agent._register_session_mcp_servers(state, [server]) + + assert state.agent.tools == fake_tools + assert state.agent.valid_tool_names == {"mcp_srv_search", "terminal"} + # _invalidate_system_prompt should have been called + state.agent._invalidate_system_prompt.assert_called_once() + + @pytest.mark.asyncio + async def test_register_failure_logs_warning(self, agent, mock_manager): + """If register_mcp_servers raises, warning is logged but no crash.""" + from acp.schema import McpServerStdio + + state = mock_manager.create_session(cwd="/tmp") + server = McpServerStdio( + name="bad", + command="/nonexistent", + args=[], + env=[], + ) + + with patch("tools.mcp_tool.register_mcp_servers", side_effect=RuntimeError("boom")): + # Should not raise + await agent._register_session_mcp_servers(state, [server]) + + @pytest.mark.asyncio + async def test_new_session_calls_register(self, agent, mock_manager): + """new_session passes mcp_servers to _register_session_mcp_servers.""" + with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg: + resp = await agent.new_session(cwd="/tmp", mcp_servers=["fake"]) + assert resp is not None + mock_reg.assert_called_once() + # Second arg should be the mcp_servers list + assert mock_reg.call_args[0][1] == ["fake"] + + @pytest.mark.asyncio + async def test_load_session_calls_register(self, agent, mock_manager): + """load_session passes mcp_servers to _register_session_mcp_servers.""" + # Create a session first so load can find it + state = mock_manager.create_session(cwd="/tmp") + sid = state.session_id + + with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg: + resp = await agent.load_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"]) + assert resp is not None + mock_reg.assert_called_once() + + @pytest.mark.asyncio + async def test_resume_session_calls_register(self, agent, mock_manager): + """resume_session passes mcp_servers to _register_session_mcp_servers.""" + state = mock_manager.create_session(cwd="/tmp") + sid = state.session_id + + with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg: + resp = await agent.resume_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"]) + assert resp is not None + mock_reg.assert_called_once() + + @pytest.mark.asyncio + async def test_fork_session_calls_register(self, agent, mock_manager): + """fork_session passes mcp_servers to _register_session_mcp_servers.""" + state = mock_manager.create_session(cwd="/tmp") + sid = state.session_id + + with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg: + resp = await agent.fork_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"]) + assert resp is not None + mock_reg.assert_called_once() diff --git a/tests/gateway/test_step_callback_compat.py b/tests/gateway/test_step_callback_compat.py new file mode 100644 index 000000000..cdfc3fb04 --- /dev/null +++ b/tests/gateway/test_step_callback_compat.py @@ -0,0 +1,133 @@ +"""Tests for step_callback backward compatibility. + +Verifies that the gateway's step_callback normalization keeps +``tool_names`` as a list of strings for backward-compatible hooks, +while also providing the enriched ``tools`` list with results. +""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +class TestStepCallbackNormalization: + """The gateway's _step_callback_sync normalizes prev_tools from run_agent.""" + + def _extract_step_callback(self): + """Build a minimal _step_callback_sync using the same logic as gateway/run.py. + + We replicate the closure so we can test normalisation in isolation + without spinning up the full gateway. + """ + captured_events = [] + + class FakeHooks: + async def emit(self, event_type, data): + captured_events.append((event_type, data)) + + hooks_ref = FakeHooks() + loop = asyncio.new_event_loop() + + def _step_callback_sync(iteration: int, prev_tools: list) -> None: + _names: list[str] = [] + for _t in (prev_tools or []): + if isinstance(_t, dict): + _names.append(_t.get("name") or "") + else: + _names.append(str(_t)) + asyncio.run_coroutine_threadsafe( + hooks_ref.emit("agent:step", { + "iteration": iteration, + "tool_names": _names, + "tools": prev_tools, + }), + loop, + ) + + return _step_callback_sync, captured_events, loop + + def test_dict_prev_tools_produce_string_tool_names(self): + """When prev_tools is list[dict], tool_names should be list[str].""" + cb, events, loop = self._extract_step_callback() + + # Simulate the enriched format from run_agent.py + prev_tools = [ + {"name": "terminal", "result": '{"output": "hello"}'}, + {"name": "read_file", "result": '{"content": "..."}'}, + ] + + try: + loop.run_until_complete(asyncio.sleep(0)) # prime the loop + import threading + t = threading.Thread(target=cb, args=(1, prev_tools)) + t.start() + t.join(timeout=2) + loop.run_until_complete(asyncio.sleep(0.1)) + finally: + loop.close() + + assert len(events) == 1 + _, data = events[0] + # tool_names must be strings for backward compat + assert data["tool_names"] == ["terminal", "read_file"] + assert all(isinstance(n, str) for n in data["tool_names"]) + # tools should be the enriched dicts + assert data["tools"] == prev_tools + + def test_string_prev_tools_still_work(self): + """When prev_tools is list[str] (legacy), tool_names should pass through.""" + cb, events, loop = self._extract_step_callback() + + prev_tools = ["terminal", "read_file"] + + try: + loop.run_until_complete(asyncio.sleep(0)) + import threading + t = threading.Thread(target=cb, args=(2, prev_tools)) + t.start() + t.join(timeout=2) + loop.run_until_complete(asyncio.sleep(0.1)) + finally: + loop.close() + + assert len(events) == 1 + _, data = events[0] + assert data["tool_names"] == ["terminal", "read_file"] + + def test_empty_prev_tools(self): + """Empty or None prev_tools should produce empty tool_names.""" + cb, events, loop = self._extract_step_callback() + + try: + loop.run_until_complete(asyncio.sleep(0)) + import threading + t = threading.Thread(target=cb, args=(1, [])) + t.start() + t.join(timeout=2) + loop.run_until_complete(asyncio.sleep(0.1)) + finally: + loop.close() + + assert len(events) == 1 + _, data = events[0] + assert data["tool_names"] == [] + + def test_joinable_for_hook_example(self): + """The documented hook example: ', '.join(tool_names) should work.""" + # This is the exact pattern from the docs + prev_tools = [ + {"name": "terminal", "result": "ok"}, + {"name": "web_search", "result": None}, + ] + + _names = [] + for _t in prev_tools: + if isinstance(_t, dict): + _names.append(_t.get("name") or "") + else: + _names.append(str(_t)) + + # This must not raise — documented hook pattern + result = ", ".join(_names) + assert result == "terminal, web_search" diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 823db8843..726c40cc9 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -2900,3 +2900,164 @@ class TestMCPBuiltinCollisionGuard: assert mock_registry.get_toolset_for_tool("mcp_srv_do_thing") == "mcp-srv" _servers.pop("srv", None) + + +# --------------------------------------------------------------------------- +# sanitize_mcp_name_component +# --------------------------------------------------------------------------- + + +class TestSanitizeMcpNameComponent: + """Verify sanitize_mcp_name_component handles all edge cases.""" + + def test_hyphens_replaced(self): + from tools.mcp_tool import sanitize_mcp_name_component + assert sanitize_mcp_name_component("my-server") == "my_server" + + def test_dots_replaced(self): + from tools.mcp_tool import sanitize_mcp_name_component + assert sanitize_mcp_name_component("ai.exa") == "ai_exa" + + def test_slashes_replaced(self): + from tools.mcp_tool import sanitize_mcp_name_component + assert sanitize_mcp_name_component("ai.exa/exa") == "ai_exa_exa" + + def test_mixed_special_characters(self): + from tools.mcp_tool import sanitize_mcp_name_component + assert sanitize_mcp_name_component("@scope/my-pkg.v2") == "_scope_my_pkg_v2" + + def test_alphanumeric_and_underscores_preserved(self): + from tools.mcp_tool import sanitize_mcp_name_component + assert sanitize_mcp_name_component("my_server_123") == "my_server_123" + + def test_empty_string(self): + from tools.mcp_tool import sanitize_mcp_name_component + assert sanitize_mcp_name_component("") == "" + + def test_none_returns_empty(self): + from tools.mcp_tool import sanitize_mcp_name_component + assert sanitize_mcp_name_component(None) == "" + + def test_slash_in_convert_mcp_schema(self): + """Server names with slashes produce valid tool names via _convert_mcp_schema.""" + from tools.mcp_tool import _convert_mcp_schema + + mcp_tool = _make_mcp_tool(name="search") + schema = _convert_mcp_schema("ai.exa/exa", mcp_tool) + assert schema["name"] == "mcp_ai_exa_exa_search" + # Must match Anthropic's pattern: ^[a-zA-Z0-9_-]{1,128}$ + import re + assert re.match(r"^[a-zA-Z0-9_-]{1,128}$", schema["name"]) + + def test_slash_in_build_utility_schemas(self): + """Server names with slashes produce valid utility tool names.""" + from tools.mcp_tool import _build_utility_schemas + + schemas = _build_utility_schemas("ai.exa/exa") + for s in schemas: + name = s["schema"]["name"] + assert "/" not in name + assert "." not in name + + def test_slash_in_sync_mcp_toolsets(self): + """_sync_mcp_toolsets uses sanitize consistently with _convert_mcp_schema.""" + from tools.mcp_tool import sanitize_mcp_name_component + + # Verify the prefix generation matches what _convert_mcp_schema produces + server_name = "ai.exa/exa" + safe_prefix = f"mcp_{sanitize_mcp_name_component(server_name)}_" + assert safe_prefix == "mcp_ai_exa_exa_" + + +# --------------------------------------------------------------------------- +# register_mcp_servers public API +# --------------------------------------------------------------------------- + + +class TestRegisterMcpServers: + """Verify the new register_mcp_servers() public API.""" + + def test_empty_servers_returns_empty(self): + from tools.mcp_tool import register_mcp_servers + + with patch("tools.mcp_tool._MCP_AVAILABLE", True): + result = register_mcp_servers({}) + assert result == [] + + def test_mcp_not_available_returns_empty(self): + from tools.mcp_tool import register_mcp_servers + + with patch("tools.mcp_tool._MCP_AVAILABLE", False): + result = register_mcp_servers({"srv": {"command": "test"}}) + assert result == [] + + def test_skips_already_connected_servers(self): + from tools.mcp_tool import register_mcp_servers, _servers + + mock_server = _make_mock_server("existing") + _servers["existing"] = mock_server + + try: + with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_existing_tool"]): + result = register_mcp_servers({"existing": {"command": "test"}}) + assert result == ["mcp_existing_tool"] + finally: + _servers.pop("existing", None) + + def test_skips_disabled_servers(self): + from tools.mcp_tool import register_mcp_servers, _servers + + try: + with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._existing_tool_names", return_value=[]): + result = register_mcp_servers({"srv": {"command": "test", "enabled": False}}) + assert result == [] + finally: + _servers.pop("srv", None) + + def test_connects_new_servers(self): + from tools.mcp_tool import register_mcp_servers, _servers, _ensure_mcp_loop + + fake_config = {"my_server": {"command": "npx", "args": ["test"]}} + + async def fake_register(name, cfg): + server = _make_mock_server(name) + server._registered_tool_names = ["mcp_my_server_tool1"] + _servers[name] = server + return ["mcp_my_server_tool1"] + + with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \ + patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_my_server_tool1"]): + _ensure_mcp_loop() + result = register_mcp_servers(fake_config) + + assert "mcp_my_server_tool1" in result + _servers.pop("my_server", None) + + def test_logs_summary_on_success(self): + from tools.mcp_tool import register_mcp_servers, _servers, _ensure_mcp_loop + + fake_config = {"srv": {"command": "npx", "args": ["test"]}} + + async def fake_register(name, cfg): + server = _make_mock_server(name) + server._registered_tool_names = ["mcp_srv_t1", "mcp_srv_t2"] + _servers[name] = server + return ["mcp_srv_t1", "mcp_srv_t2"] + + with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \ + patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_srv_t1", "mcp_srv_t2"]): + _ensure_mcp_loop() + + with patch("tools.mcp_tool.logger") as mock_logger: + register_mcp_servers(fake_config) + + info_calls = [str(c) for c in mock_logger.info.call_args_list] + assert any("2 tool(s)" in c and "1 server(s)" in c for c in info_calls), ( + f"Summary should report 2 tools from 1 server, got: {info_calls}" + ) + + _servers.pop("srv", None) diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index b589f6454..0918de20a 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -1845,6 +1845,20 @@ def register_mcp_servers(servers: Dict[str, dict]) -> List[str]: _sync_mcp_toolsets(list(servers.keys())) + # Log a summary so ACP callers get visibility into what was registered. + with _lock: + connected = [n for n in new_servers if n in _servers] + new_tool_count = sum( + len(getattr(_servers[n], "_registered_tool_names", [])) + for n in connected + ) + failed = len(new_servers) - len(connected) + if new_tool_count or failed: + summary = f"MCP: registered {new_tool_count} tool(s) from {len(connected)} server(s)" + if failed: + summary += f" ({failed} failed)" + logger.info(summary) + return _existing_tool_names()