fix(gateway): normalize step_callback prev_tools for backward compat

The PR changed prev_tools from list[str] to list[dict] with name/result
keys.  The gateway's _step_callback_sync passed this directly to hooks
as 'tool_names', breaking user-authored hooks that call
', '.join(tool_names).

Now:
- 'tool_names' always contains strings (backward-compatible)
- 'tools' carries the enriched dicts for hooks that want results

Also adds summary logging to register_mcp_servers() and comprehensive
tests for all three PR changes:
- sanitize_mcp_name_component edge cases
- register_mcp_servers public API
- _register_session_mcp_servers ACP integration
- step_callback result forwarding
- gateway normalization backward compat
This commit is contained in:
Teknium 2026-04-02 16:59:13 -07:00 committed by Teknium
parent f66b3fe76b
commit 21c2d32471
6 changed files with 537 additions and 2 deletions

View file

@ -5468,15 +5468,25 @@ class GatewayRunner:
_loop_for_step = asyncio.get_event_loop()
_hooks_ref = self.hooks
def _step_callback_sync(iteration: int, tool_names: list) -> None:
def _step_callback_sync(iteration: int, prev_tools: list) -> None:
try:
# prev_tools may be list[str] or list[dict] with "name"/"result"
# keys. Normalise to keep "tool_names" backward-compatible for
# user-authored hooks that do ', '.join(tool_names)'.
_names: list[str] = []
for _t in (prev_tools or []):
if isinstance(_t, dict):
_names.append(_t.get("name") or "")
else:
_names.append(str(_t))
asyncio.run_coroutine_threadsafe(
_hooks_ref.emit("agent:step", {
"platform": source.platform.value if source.platform else "",
"user_id": source.user_id,
"session_id": session_id,
"iteration": iteration,
"tool_names": tool_names,
"tool_names": _names,
"tools": prev_tools,
}),
_loop_for_step,
)

View file

@ -205,6 +205,47 @@ class TestStepCallback:
assert "read_file" not in tool_call_ids
mock_rcts.assert_called_once()
def test_result_passed_to_build_tool_complete(self, mock_conn, event_loop_fixture):
"""Tool result from prev_tools dict is forwarded to build_tool_complete."""
from collections import deque
tool_call_ids = {"terminal": deque(["tc-xyz789"])}
loop = event_loop_fixture
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
patch("acp_adapter.events.build_tool_complete") as mock_btc:
future = MagicMock(spec=Future)
future.result.return_value = None
mock_rcts.return_value = future
# Provide a result string in the tool info dict
cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}])
mock_btc.assert_called_once_with(
"tc-xyz789", "terminal", result='{"output": "hello"}'
)
def test_none_result_passed_through(self, mock_conn, event_loop_fixture):
"""When result is None (e.g. first iteration), None is passed through."""
from collections import deque
tool_call_ids = {"web_search": deque(["tc-aaa"])}
loop = event_loop_fixture
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
patch("acp_adapter.events.build_tool_complete") as mock_btc:
future = MagicMock(spec=Future)
future.result.return_value = None
mock_rcts.return_value = future
cb(1, [{"name": "web_search", "result": None}])
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None)
# ---------------------------------------------------------------------------
# Message callback

View file

@ -505,3 +505,179 @@ class TestSlashCommands:
assert state.agent.provider == "anthropic"
assert state.agent.base_url == "https://anthropic.example/v1"
assert runtime_calls[-1] == "anthropic"
# ---------------------------------------------------------------------------
# _register_session_mcp_servers
# ---------------------------------------------------------------------------
class TestRegisterSessionMcpServers:
"""Tests for ACP MCP server registration in session lifecycle."""
@pytest.mark.asyncio
async def test_noop_when_no_servers(self, agent, mock_manager):
"""No-op when mcp_servers is None or empty."""
state = mock_manager.create_session(cwd="/tmp")
# Should not raise
await agent._register_session_mcp_servers(state, None)
await agent._register_session_mcp_servers(state, [])
@pytest.mark.asyncio
async def test_registers_stdio_servers(self, agent, mock_manager):
"""McpServerStdio servers are converted and passed to register_mcp_servers."""
from acp.schema import McpServerStdio, EnvVariable
state = mock_manager.create_session(cwd="/tmp")
# Give the mock agent the attributes _register_session_mcp_servers reads
state.agent.enabled_toolsets = ["hermes-acp"]
state.agent.disabled_toolsets = None
state.agent.tools = []
state.agent.valid_tool_names = set()
server = McpServerStdio(
name="test-server",
command="/usr/bin/test",
args=["--flag"],
env=[EnvVariable(name="KEY", value="val")],
)
registered_config = {}
def capture_register(config_map):
registered_config.update(config_map)
return ["mcp_test_server_tool1"]
with patch("tools.mcp_tool.register_mcp_servers", side_effect=capture_register), \
patch("model_tools.get_tool_definitions", return_value=[]):
await agent._register_session_mcp_servers(state, [server])
assert "test-server" in registered_config
cfg = registered_config["test-server"]
assert cfg["command"] == "/usr/bin/test"
assert cfg["args"] == ["--flag"]
assert cfg["env"] == {"KEY": "val"}
@pytest.mark.asyncio
async def test_registers_http_servers(self, agent, mock_manager):
"""McpServerHttp servers are converted correctly."""
from acp.schema import McpServerHttp, HttpHeader
state = mock_manager.create_session(cwd="/tmp")
state.agent.enabled_toolsets = ["hermes-acp"]
state.agent.disabled_toolsets = None
state.agent.tools = []
state.agent.valid_tool_names = set()
server = McpServerHttp(
name="http-server",
url="https://api.example.com/mcp",
headers=[HttpHeader(name="Authorization", value="Bearer tok")],
)
registered_config = {}
def capture_register(config_map):
registered_config.update(config_map)
return []
with patch("tools.mcp_tool.register_mcp_servers", side_effect=capture_register), \
patch("model_tools.get_tool_definitions", return_value=[]):
await agent._register_session_mcp_servers(state, [server])
assert "http-server" in registered_config
cfg = registered_config["http-server"]
assert cfg["url"] == "https://api.example.com/mcp"
assert cfg["headers"] == {"Authorization": "Bearer tok"}
@pytest.mark.asyncio
async def test_refreshes_agent_tool_surface(self, agent, mock_manager):
"""After MCP registration, agent.tools and valid_tool_names are refreshed."""
from acp.schema import McpServerStdio
state = mock_manager.create_session(cwd="/tmp")
state.agent.enabled_toolsets = ["hermes-acp"]
state.agent.disabled_toolsets = None
state.agent.tools = []
state.agent.valid_tool_names = set()
state.agent._cached_system_prompt = "old prompt"
server = McpServerStdio(
name="srv",
command="/bin/test",
args=[],
env=[],
)
fake_tools = [
{"function": {"name": "mcp_srv_search"}},
{"function": {"name": "terminal"}},
]
with patch("tools.mcp_tool.register_mcp_servers", return_value=["mcp_srv_search"]), \
patch("model_tools.get_tool_definitions", return_value=fake_tools):
await agent._register_session_mcp_servers(state, [server])
assert state.agent.tools == fake_tools
assert state.agent.valid_tool_names == {"mcp_srv_search", "terminal"}
# _invalidate_system_prompt should have been called
state.agent._invalidate_system_prompt.assert_called_once()
@pytest.mark.asyncio
async def test_register_failure_logs_warning(self, agent, mock_manager):
"""If register_mcp_servers raises, warning is logged but no crash."""
from acp.schema import McpServerStdio
state = mock_manager.create_session(cwd="/tmp")
server = McpServerStdio(
name="bad",
command="/nonexistent",
args=[],
env=[],
)
with patch("tools.mcp_tool.register_mcp_servers", side_effect=RuntimeError("boom")):
# Should not raise
await agent._register_session_mcp_servers(state, [server])
@pytest.mark.asyncio
async def test_new_session_calls_register(self, agent, mock_manager):
"""new_session passes mcp_servers to _register_session_mcp_servers."""
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
resp = await agent.new_session(cwd="/tmp", mcp_servers=["fake"])
assert resp is not None
mock_reg.assert_called_once()
# Second arg should be the mcp_servers list
assert mock_reg.call_args[0][1] == ["fake"]
@pytest.mark.asyncio
async def test_load_session_calls_register(self, agent, mock_manager):
"""load_session passes mcp_servers to _register_session_mcp_servers."""
# Create a session first so load can find it
state = mock_manager.create_session(cwd="/tmp")
sid = state.session_id
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
resp = await agent.load_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"])
assert resp is not None
mock_reg.assert_called_once()
@pytest.mark.asyncio
async def test_resume_session_calls_register(self, agent, mock_manager):
"""resume_session passes mcp_servers to _register_session_mcp_servers."""
state = mock_manager.create_session(cwd="/tmp")
sid = state.session_id
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
resp = await agent.resume_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"])
assert resp is not None
mock_reg.assert_called_once()
@pytest.mark.asyncio
async def test_fork_session_calls_register(self, agent, mock_manager):
"""fork_session passes mcp_servers to _register_session_mcp_servers."""
state = mock_manager.create_session(cwd="/tmp")
sid = state.session_id
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
resp = await agent.fork_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"])
assert resp is not None
mock_reg.assert_called_once()

View file

@ -0,0 +1,133 @@
"""Tests for step_callback backward compatibility.
Verifies that the gateway's step_callback normalization keeps
``tool_names`` as a list of strings for backward-compatible hooks,
while also providing the enriched ``tools`` list with results.
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
class TestStepCallbackNormalization:
"""The gateway's _step_callback_sync normalizes prev_tools from run_agent."""
def _extract_step_callback(self):
"""Build a minimal _step_callback_sync using the same logic as gateway/run.py.
We replicate the closure so we can test normalisation in isolation
without spinning up the full gateway.
"""
captured_events = []
class FakeHooks:
async def emit(self, event_type, data):
captured_events.append((event_type, data))
hooks_ref = FakeHooks()
loop = asyncio.new_event_loop()
def _step_callback_sync(iteration: int, prev_tools: list) -> None:
_names: list[str] = []
for _t in (prev_tools or []):
if isinstance(_t, dict):
_names.append(_t.get("name") or "")
else:
_names.append(str(_t))
asyncio.run_coroutine_threadsafe(
hooks_ref.emit("agent:step", {
"iteration": iteration,
"tool_names": _names,
"tools": prev_tools,
}),
loop,
)
return _step_callback_sync, captured_events, loop
def test_dict_prev_tools_produce_string_tool_names(self):
"""When prev_tools is list[dict], tool_names should be list[str]."""
cb, events, loop = self._extract_step_callback()
# Simulate the enriched format from run_agent.py
prev_tools = [
{"name": "terminal", "result": '{"output": "hello"}'},
{"name": "read_file", "result": '{"content": "..."}'},
]
try:
loop.run_until_complete(asyncio.sleep(0)) # prime the loop
import threading
t = threading.Thread(target=cb, args=(1, prev_tools))
t.start()
t.join(timeout=2)
loop.run_until_complete(asyncio.sleep(0.1))
finally:
loop.close()
assert len(events) == 1
_, data = events[0]
# tool_names must be strings for backward compat
assert data["tool_names"] == ["terminal", "read_file"]
assert all(isinstance(n, str) for n in data["tool_names"])
# tools should be the enriched dicts
assert data["tools"] == prev_tools
def test_string_prev_tools_still_work(self):
"""When prev_tools is list[str] (legacy), tool_names should pass through."""
cb, events, loop = self._extract_step_callback()
prev_tools = ["terminal", "read_file"]
try:
loop.run_until_complete(asyncio.sleep(0))
import threading
t = threading.Thread(target=cb, args=(2, prev_tools))
t.start()
t.join(timeout=2)
loop.run_until_complete(asyncio.sleep(0.1))
finally:
loop.close()
assert len(events) == 1
_, data = events[0]
assert data["tool_names"] == ["terminal", "read_file"]
def test_empty_prev_tools(self):
"""Empty or None prev_tools should produce empty tool_names."""
cb, events, loop = self._extract_step_callback()
try:
loop.run_until_complete(asyncio.sleep(0))
import threading
t = threading.Thread(target=cb, args=(1, []))
t.start()
t.join(timeout=2)
loop.run_until_complete(asyncio.sleep(0.1))
finally:
loop.close()
assert len(events) == 1
_, data = events[0]
assert data["tool_names"] == []
def test_joinable_for_hook_example(self):
"""The documented hook example: ', '.join(tool_names) should work."""
# This is the exact pattern from the docs
prev_tools = [
{"name": "terminal", "result": "ok"},
{"name": "web_search", "result": None},
]
_names = []
for _t in prev_tools:
if isinstance(_t, dict):
_names.append(_t.get("name") or "")
else:
_names.append(str(_t))
# This must not raise — documented hook pattern
result = ", ".join(_names)
assert result == "terminal, web_search"

View file

@ -2900,3 +2900,164 @@ class TestMCPBuiltinCollisionGuard:
assert mock_registry.get_toolset_for_tool("mcp_srv_do_thing") == "mcp-srv"
_servers.pop("srv", None)
# ---------------------------------------------------------------------------
# sanitize_mcp_name_component
# ---------------------------------------------------------------------------
class TestSanitizeMcpNameComponent:
"""Verify sanitize_mcp_name_component handles all edge cases."""
def test_hyphens_replaced(self):
from tools.mcp_tool import sanitize_mcp_name_component
assert sanitize_mcp_name_component("my-server") == "my_server"
def test_dots_replaced(self):
from tools.mcp_tool import sanitize_mcp_name_component
assert sanitize_mcp_name_component("ai.exa") == "ai_exa"
def test_slashes_replaced(self):
from tools.mcp_tool import sanitize_mcp_name_component
assert sanitize_mcp_name_component("ai.exa/exa") == "ai_exa_exa"
def test_mixed_special_characters(self):
from tools.mcp_tool import sanitize_mcp_name_component
assert sanitize_mcp_name_component("@scope/my-pkg.v2") == "_scope_my_pkg_v2"
def test_alphanumeric_and_underscores_preserved(self):
from tools.mcp_tool import sanitize_mcp_name_component
assert sanitize_mcp_name_component("my_server_123") == "my_server_123"
def test_empty_string(self):
from tools.mcp_tool import sanitize_mcp_name_component
assert sanitize_mcp_name_component("") == ""
def test_none_returns_empty(self):
from tools.mcp_tool import sanitize_mcp_name_component
assert sanitize_mcp_name_component(None) == ""
def test_slash_in_convert_mcp_schema(self):
"""Server names with slashes produce valid tool names via _convert_mcp_schema."""
from tools.mcp_tool import _convert_mcp_schema
mcp_tool = _make_mcp_tool(name="search")
schema = _convert_mcp_schema("ai.exa/exa", mcp_tool)
assert schema["name"] == "mcp_ai_exa_exa_search"
# Must match Anthropic's pattern: ^[a-zA-Z0-9_-]{1,128}$
import re
assert re.match(r"^[a-zA-Z0-9_-]{1,128}$", schema["name"])
def test_slash_in_build_utility_schemas(self):
"""Server names with slashes produce valid utility tool names."""
from tools.mcp_tool import _build_utility_schemas
schemas = _build_utility_schemas("ai.exa/exa")
for s in schemas:
name = s["schema"]["name"]
assert "/" not in name
assert "." not in name
def test_slash_in_sync_mcp_toolsets(self):
"""_sync_mcp_toolsets uses sanitize consistently with _convert_mcp_schema."""
from tools.mcp_tool import sanitize_mcp_name_component
# Verify the prefix generation matches what _convert_mcp_schema produces
server_name = "ai.exa/exa"
safe_prefix = f"mcp_{sanitize_mcp_name_component(server_name)}_"
assert safe_prefix == "mcp_ai_exa_exa_"
# ---------------------------------------------------------------------------
# register_mcp_servers public API
# ---------------------------------------------------------------------------
class TestRegisterMcpServers:
"""Verify the new register_mcp_servers() public API."""
def test_empty_servers_returns_empty(self):
from tools.mcp_tool import register_mcp_servers
with patch("tools.mcp_tool._MCP_AVAILABLE", True):
result = register_mcp_servers({})
assert result == []
def test_mcp_not_available_returns_empty(self):
from tools.mcp_tool import register_mcp_servers
with patch("tools.mcp_tool._MCP_AVAILABLE", False):
result = register_mcp_servers({"srv": {"command": "test"}})
assert result == []
def test_skips_already_connected_servers(self):
from tools.mcp_tool import register_mcp_servers, _servers
mock_server = _make_mock_server("existing")
_servers["existing"] = mock_server
try:
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_existing_tool"]):
result = register_mcp_servers({"existing": {"command": "test"}})
assert result == ["mcp_existing_tool"]
finally:
_servers.pop("existing", None)
def test_skips_disabled_servers(self):
from tools.mcp_tool import register_mcp_servers, _servers
try:
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
patch("tools.mcp_tool._existing_tool_names", return_value=[]):
result = register_mcp_servers({"srv": {"command": "test", "enabled": False}})
assert result == []
finally:
_servers.pop("srv", None)
def test_connects_new_servers(self):
from tools.mcp_tool import register_mcp_servers, _servers, _ensure_mcp_loop
fake_config = {"my_server": {"command": "npx", "args": ["test"]}}
async def fake_register(name, cfg):
server = _make_mock_server(name)
server._registered_tool_names = ["mcp_my_server_tool1"]
_servers[name] = server
return ["mcp_my_server_tool1"]
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \
patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_my_server_tool1"]):
_ensure_mcp_loop()
result = register_mcp_servers(fake_config)
assert "mcp_my_server_tool1" in result
_servers.pop("my_server", None)
def test_logs_summary_on_success(self):
from tools.mcp_tool import register_mcp_servers, _servers, _ensure_mcp_loop
fake_config = {"srv": {"command": "npx", "args": ["test"]}}
async def fake_register(name, cfg):
server = _make_mock_server(name)
server._registered_tool_names = ["mcp_srv_t1", "mcp_srv_t2"]
_servers[name] = server
return ["mcp_srv_t1", "mcp_srv_t2"]
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \
patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_srv_t1", "mcp_srv_t2"]):
_ensure_mcp_loop()
with patch("tools.mcp_tool.logger") as mock_logger:
register_mcp_servers(fake_config)
info_calls = [str(c) for c in mock_logger.info.call_args_list]
assert any("2 tool(s)" in c and "1 server(s)" in c for c in info_calls), (
f"Summary should report 2 tools from 1 server, got: {info_calls}"
)
_servers.pop("srv", None)

View file

@ -1845,6 +1845,20 @@ def register_mcp_servers(servers: Dict[str, dict]) -> List[str]:
_sync_mcp_toolsets(list(servers.keys()))
# Log a summary so ACP callers get visibility into what was registered.
with _lock:
connected = [n for n in new_servers if n in _servers]
new_tool_count = sum(
len(getattr(_servers[n], "_registered_tool_names", []))
for n in connected
)
failed = len(new_servers) - len(connected)
if new_tool_count or failed:
summary = f"MCP: registered {new_tool_count} tool(s) from {len(connected)} server(s)"
if failed:
summary += f" ({failed} failed)"
logger.info(summary)
return _existing_tool_names()