mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-08 03:01:47 +00:00
fix(mcp): coerce numeric tool args defensively
This commit is contained in:
parent
43cf72a458
commit
62c2f5d8d2
2 changed files with 132 additions and 1 deletions
32
mcp_serve.py
32
mcp_serve.py
|
|
@ -115,6 +115,25 @@ def _load_channel_directory() -> dict:
|
|||
return {}
|
||||
|
||||
|
||||
def _coerce_int(
|
||||
value,
|
||||
*,
|
||||
default: int,
|
||||
minimum: int,
|
||||
maximum: int,
|
||||
) -> int:
|
||||
"""Coerce value to int with fallback and clamping.
|
||||
|
||||
Used at MCP tool boundaries to handle invalid types from external clients.
|
||||
Returns default if value cannot be converted to int.
|
||||
"""
|
||||
try:
|
||||
coerced = int(value)
|
||||
except (TypeError, ValueError):
|
||||
coerced = default
|
||||
return max(minimum, min(coerced, maximum))
|
||||
|
||||
|
||||
def _extract_message_content(msg: dict) -> str:
|
||||
"""Extract text content from a message, handling multi-part content."""
|
||||
content = msg.get("content", "")
|
||||
|
|
@ -465,6 +484,7 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
|
|||
limit: Maximum number of conversations to return (default 50)
|
||||
search: Optional text to filter conversations by name
|
||||
"""
|
||||
limit = _coerce_int(limit, default=50, minimum=1, maximum=200)
|
||||
entries = _load_sessions_index()
|
||||
conversations = []
|
||||
|
||||
|
|
@ -552,6 +572,7 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
|
|||
session_key: The session key from conversations_list
|
||||
limit: Maximum number of messages to return (default 50, most recent)
|
||||
"""
|
||||
limit = _coerce_int(limit, default=50, minimum=1, maximum=200)
|
||||
entries = _load_sessions_index()
|
||||
entry = entries.get(session_key)
|
||||
if not entry:
|
||||
|
|
@ -664,6 +685,8 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
|
|||
session_key: Optional filter to one conversation
|
||||
limit: Maximum events to return (default 20)
|
||||
"""
|
||||
after_cursor = _coerce_int(after_cursor, default=0, minimum=0, maximum=10**18)
|
||||
limit = _coerce_int(limit, default=20, minimum=1, maximum=200)
|
||||
result = bridge.poll_events(
|
||||
after_cursor=after_cursor,
|
||||
session_key=session_key,
|
||||
|
|
@ -689,10 +712,17 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
|
|||
session_key: Optional filter to one conversation
|
||||
timeout_ms: Maximum wait time in milliseconds (default 30000)
|
||||
"""
|
||||
after_cursor = _coerce_int(after_cursor, default=0, minimum=0, maximum=10**18)
|
||||
timeout_ms = _coerce_int(
|
||||
timeout_ms,
|
||||
default=30000,
|
||||
minimum=0,
|
||||
maximum=300000,
|
||||
) # Cap at 5 minutes
|
||||
event = bridge.wait_for_event(
|
||||
after_cursor=after_cursor,
|
||||
session_key=session_key,
|
||||
timeout_ms=min(timeout_ms, 300000), # Cap at 5 minutes
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
if event:
|
||||
return json.dumps({"event": event}, indent=2)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ Three layers of tests:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
|
|
@ -207,6 +208,54 @@ def mock_session_db(tmp_path, populated_sessions_dir):
|
|||
return TestSessionDB()
|
||||
|
||||
|
||||
class _FakeTool:
|
||||
def __init__(self, fn):
|
||||
self.name = fn.__name__
|
||||
self.description = inspect.getdoc(fn) or ""
|
||||
self.fn = fn
|
||||
|
||||
|
||||
class _FakeToolManager:
|
||||
def __init__(self):
|
||||
self._tools = {}
|
||||
|
||||
def add_tool(self, fn):
|
||||
self._tools[fn.__name__] = _FakeTool(fn)
|
||||
|
||||
async def call_tool(self, name, args=None):
|
||||
return self._tools[name].fn(**(args or {}))
|
||||
|
||||
def list_tools(self):
|
||||
return list(self._tools.values())
|
||||
|
||||
|
||||
class _FakeFastMCP:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._tool_manager = _FakeToolManager()
|
||||
|
||||
def tool(self):
|
||||
def decorator(fn):
|
||||
self._tool_manager.add_tool(fn)
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_mcp_server(populated_sessions_dir, mock_session_db, monkeypatch):
|
||||
import mcp_serve
|
||||
|
||||
monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: populated_sessions_dir)
|
||||
monkeypatch.setattr(mcp_serve, "_get_session_db", lambda: mock_session_db)
|
||||
monkeypatch.setattr(mcp_serve, "_load_channel_directory", lambda: {})
|
||||
monkeypatch.setattr(mcp_serve, "_MCP_SERVER_AVAILABLE", True)
|
||||
monkeypatch.setattr(mcp_serve, "FastMCP", _FakeFastMCP)
|
||||
|
||||
bridge = mcp_serve.EventBridge()
|
||||
server = mcp_serve.create_mcp_server(event_bridge=bridge)
|
||||
return server, bridge
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. UNIT TESTS — helpers, extraction, attachments
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -229,6 +278,15 @@ class TestHelpers:
|
|||
result = _get_sessions_dir()
|
||||
assert result == tmp_path / "sessions"
|
||||
|
||||
def test_coerce_int_handles_invalid_and_out_of_range_values(self):
|
||||
from mcp_serve import _coerce_int
|
||||
|
||||
assert _coerce_int(None, default=50, minimum=1, maximum=200) == 50
|
||||
assert _coerce_int("20", default=50, minimum=1, maximum=200) == 20
|
||||
assert _coerce_int("bad", default=50, minimum=1, maximum=200) == 50
|
||||
assert _coerce_int(999, default=50, minimum=1, maximum=200) == 200
|
||||
assert _coerce_int(-5, default=50, minimum=1, maximum=200) == 1
|
||||
|
||||
def test_load_sessions_index_empty(self, sessions_dir, monkeypatch):
|
||||
import mcp_serve
|
||||
monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: sessions_dir)
|
||||
|
|
@ -689,6 +747,49 @@ class TestE2EEventsWait:
|
|||
result = _run_tool(server, "events_wait", {"timeout_ms": 999999})
|
||||
assert result["event"] is not None
|
||||
|
||||
class TestMCPToolParameterCoercion:
|
||||
def test_conversations_list_coerces_string_limit(self, fake_mcp_server, _event_loop):
|
||||
server, _ = fake_mcp_server
|
||||
result = _run_tool(server, "conversations_list", {"limit": "2"})
|
||||
assert result["count"] == 2
|
||||
|
||||
def test_messages_read_coerces_string_limit(self, fake_mcp_server, _event_loop):
|
||||
server, _ = fake_mcp_server
|
||||
result = _run_tool(
|
||||
server,
|
||||
"messages_read",
|
||||
{"session_key": "agent:main:telegram:dm:123456", "limit": "2"},
|
||||
)
|
||||
assert result["count"] == 2
|
||||
|
||||
def test_events_poll_coerces_string_cursor_and_limit(self, fake_mcp_server, _event_loop):
|
||||
from mcp_serve import QueueEvent
|
||||
|
||||
server, bridge = fake_mcp_server
|
||||
bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="a"))
|
||||
bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="b"))
|
||||
|
||||
result = _run_tool(server, "events_poll", {"after_cursor": "0", "limit": "1"})
|
||||
assert len(result["events"]) == 1
|
||||
assert result["next_cursor"] == 1
|
||||
|
||||
def test_events_wait_coerces_invalid_timeout(self, fake_mcp_server, _event_loop):
|
||||
from mcp_serve import QueueEvent
|
||||
|
||||
server, bridge = fake_mcp_server
|
||||
bridge._enqueue(
|
||||
QueueEvent(
|
||||
cursor=0,
|
||||
type="message",
|
||||
session_key="test",
|
||||
data={"content": "waiting for this"},
|
||||
)
|
||||
)
|
||||
|
||||
result = _run_tool(server, "events_wait", {"after_cursor": "0", "timeout_ms": "bad"})
|
||||
assert result["event"] is not None
|
||||
assert result["event"]["content"] == "waiting for this"
|
||||
|
||||
|
||||
class TestE2EMessagesSend:
|
||||
def test_send_missing_args(self, mcp_server_e2e, _event_loop):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue