mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-13 03:52:00 +00:00
fix(mcp): coerce numeric tool args defensively
This commit is contained in:
parent
43cf72a458
commit
62c2f5d8d2
2 changed files with 132 additions and 1 deletions
32
mcp_serve.py
32
mcp_serve.py
|
|
@ -115,6 +115,25 @@ def _load_channel_directory() -> dict:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_int(
|
||||||
|
value,
|
||||||
|
*,
|
||||||
|
default: int,
|
||||||
|
minimum: int,
|
||||||
|
maximum: int,
|
||||||
|
) -> int:
|
||||||
|
"""Coerce value to int with fallback and clamping.
|
||||||
|
|
||||||
|
Used at MCP tool boundaries to handle invalid types from external clients.
|
||||||
|
Returns default if value cannot be converted to int.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
coerced = int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
coerced = default
|
||||||
|
return max(minimum, min(coerced, maximum))
|
||||||
|
|
||||||
|
|
||||||
def _extract_message_content(msg: dict) -> str:
|
def _extract_message_content(msg: dict) -> str:
|
||||||
"""Extract text content from a message, handling multi-part content."""
|
"""Extract text content from a message, handling multi-part content."""
|
||||||
content = msg.get("content", "")
|
content = msg.get("content", "")
|
||||||
|
|
@ -465,6 +484,7 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
|
||||||
limit: Maximum number of conversations to return (default 50)
|
limit: Maximum number of conversations to return (default 50)
|
||||||
search: Optional text to filter conversations by name
|
search: Optional text to filter conversations by name
|
||||||
"""
|
"""
|
||||||
|
limit = _coerce_int(limit, default=50, minimum=1, maximum=200)
|
||||||
entries = _load_sessions_index()
|
entries = _load_sessions_index()
|
||||||
conversations = []
|
conversations = []
|
||||||
|
|
||||||
|
|
@ -552,6 +572,7 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
|
||||||
session_key: The session key from conversations_list
|
session_key: The session key from conversations_list
|
||||||
limit: Maximum number of messages to return (default 50, most recent)
|
limit: Maximum number of messages to return (default 50, most recent)
|
||||||
"""
|
"""
|
||||||
|
limit = _coerce_int(limit, default=50, minimum=1, maximum=200)
|
||||||
entries = _load_sessions_index()
|
entries = _load_sessions_index()
|
||||||
entry = entries.get(session_key)
|
entry = entries.get(session_key)
|
||||||
if not entry:
|
if not entry:
|
||||||
|
|
@ -664,6 +685,8 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
|
||||||
session_key: Optional filter to one conversation
|
session_key: Optional filter to one conversation
|
||||||
limit: Maximum events to return (default 20)
|
limit: Maximum events to return (default 20)
|
||||||
"""
|
"""
|
||||||
|
after_cursor = _coerce_int(after_cursor, default=0, minimum=0, maximum=10**18)
|
||||||
|
limit = _coerce_int(limit, default=20, minimum=1, maximum=200)
|
||||||
result = bridge.poll_events(
|
result = bridge.poll_events(
|
||||||
after_cursor=after_cursor,
|
after_cursor=after_cursor,
|
||||||
session_key=session_key,
|
session_key=session_key,
|
||||||
|
|
@ -689,10 +712,17 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
|
||||||
session_key: Optional filter to one conversation
|
session_key: Optional filter to one conversation
|
||||||
timeout_ms: Maximum wait time in milliseconds (default 30000)
|
timeout_ms: Maximum wait time in milliseconds (default 30000)
|
||||||
"""
|
"""
|
||||||
|
after_cursor = _coerce_int(after_cursor, default=0, minimum=0, maximum=10**18)
|
||||||
|
timeout_ms = _coerce_int(
|
||||||
|
timeout_ms,
|
||||||
|
default=30000,
|
||||||
|
minimum=0,
|
||||||
|
maximum=300000,
|
||||||
|
) # Cap at 5 minutes
|
||||||
event = bridge.wait_for_event(
|
event = bridge.wait_for_event(
|
||||||
after_cursor=after_cursor,
|
after_cursor=after_cursor,
|
||||||
session_key=session_key,
|
session_key=session_key,
|
||||||
timeout_ms=min(timeout_ms, 300000), # Cap at 5 minutes
|
timeout_ms=timeout_ms,
|
||||||
)
|
)
|
||||||
if event:
|
if event:
|
||||||
return json.dumps({"event": event}, indent=2)
|
return json.dumps({"event": event}, indent=2)
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ Three layers of tests:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
|
@ -207,6 +208,54 @@ def mock_session_db(tmp_path, populated_sessions_dir):
|
||||||
return TestSessionDB()
|
return TestSessionDB()
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeTool:
|
||||||
|
def __init__(self, fn):
|
||||||
|
self.name = fn.__name__
|
||||||
|
self.description = inspect.getdoc(fn) or ""
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeToolManager:
|
||||||
|
def __init__(self):
|
||||||
|
self._tools = {}
|
||||||
|
|
||||||
|
def add_tool(self, fn):
|
||||||
|
self._tools[fn.__name__] = _FakeTool(fn)
|
||||||
|
|
||||||
|
async def call_tool(self, name, args=None):
|
||||||
|
return self._tools[name].fn(**(args or {}))
|
||||||
|
|
||||||
|
def list_tools(self):
|
||||||
|
return list(self._tools.values())
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeFastMCP:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self._tool_manager = _FakeToolManager()
|
||||||
|
|
||||||
|
def tool(self):
|
||||||
|
def decorator(fn):
|
||||||
|
self._tool_manager.add_tool(fn)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fake_mcp_server(populated_sessions_dir, mock_session_db, monkeypatch):
|
||||||
|
import mcp_serve
|
||||||
|
|
||||||
|
monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: populated_sessions_dir)
|
||||||
|
monkeypatch.setattr(mcp_serve, "_get_session_db", lambda: mock_session_db)
|
||||||
|
monkeypatch.setattr(mcp_serve, "_load_channel_directory", lambda: {})
|
||||||
|
monkeypatch.setattr(mcp_serve, "_MCP_SERVER_AVAILABLE", True)
|
||||||
|
monkeypatch.setattr(mcp_serve, "FastMCP", _FakeFastMCP)
|
||||||
|
|
||||||
|
bridge = mcp_serve.EventBridge()
|
||||||
|
server = mcp_serve.create_mcp_server(event_bridge=bridge)
|
||||||
|
return server, bridge
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# 1. UNIT TESTS — helpers, extraction, attachments
|
# 1. UNIT TESTS — helpers, extraction, attachments
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -229,6 +278,15 @@ class TestHelpers:
|
||||||
result = _get_sessions_dir()
|
result = _get_sessions_dir()
|
||||||
assert result == tmp_path / "sessions"
|
assert result == tmp_path / "sessions"
|
||||||
|
|
||||||
|
def test_coerce_int_handles_invalid_and_out_of_range_values(self):
|
||||||
|
from mcp_serve import _coerce_int
|
||||||
|
|
||||||
|
assert _coerce_int(None, default=50, minimum=1, maximum=200) == 50
|
||||||
|
assert _coerce_int("20", default=50, minimum=1, maximum=200) == 20
|
||||||
|
assert _coerce_int("bad", default=50, minimum=1, maximum=200) == 50
|
||||||
|
assert _coerce_int(999, default=50, minimum=1, maximum=200) == 200
|
||||||
|
assert _coerce_int(-5, default=50, minimum=1, maximum=200) == 1
|
||||||
|
|
||||||
def test_load_sessions_index_empty(self, sessions_dir, monkeypatch):
|
def test_load_sessions_index_empty(self, sessions_dir, monkeypatch):
|
||||||
import mcp_serve
|
import mcp_serve
|
||||||
monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: sessions_dir)
|
monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: sessions_dir)
|
||||||
|
|
@ -689,6 +747,49 @@ class TestE2EEventsWait:
|
||||||
result = _run_tool(server, "events_wait", {"timeout_ms": 999999})
|
result = _run_tool(server, "events_wait", {"timeout_ms": 999999})
|
||||||
assert result["event"] is not None
|
assert result["event"] is not None
|
||||||
|
|
||||||
|
class TestMCPToolParameterCoercion:
|
||||||
|
def test_conversations_list_coerces_string_limit(self, fake_mcp_server, _event_loop):
|
||||||
|
server, _ = fake_mcp_server
|
||||||
|
result = _run_tool(server, "conversations_list", {"limit": "2"})
|
||||||
|
assert result["count"] == 2
|
||||||
|
|
||||||
|
def test_messages_read_coerces_string_limit(self, fake_mcp_server, _event_loop):
|
||||||
|
server, _ = fake_mcp_server
|
||||||
|
result = _run_tool(
|
||||||
|
server,
|
||||||
|
"messages_read",
|
||||||
|
{"session_key": "agent:main:telegram:dm:123456", "limit": "2"},
|
||||||
|
)
|
||||||
|
assert result["count"] == 2
|
||||||
|
|
||||||
|
def test_events_poll_coerces_string_cursor_and_limit(self, fake_mcp_server, _event_loop):
|
||||||
|
from mcp_serve import QueueEvent
|
||||||
|
|
||||||
|
server, bridge = fake_mcp_server
|
||||||
|
bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="a"))
|
||||||
|
bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="b"))
|
||||||
|
|
||||||
|
result = _run_tool(server, "events_poll", {"after_cursor": "0", "limit": "1"})
|
||||||
|
assert len(result["events"]) == 1
|
||||||
|
assert result["next_cursor"] == 1
|
||||||
|
|
||||||
|
def test_events_wait_coerces_invalid_timeout(self, fake_mcp_server, _event_loop):
|
||||||
|
from mcp_serve import QueueEvent
|
||||||
|
|
||||||
|
server, bridge = fake_mcp_server
|
||||||
|
bridge._enqueue(
|
||||||
|
QueueEvent(
|
||||||
|
cursor=0,
|
||||||
|
type="message",
|
||||||
|
session_key="test",
|
||||||
|
data={"content": "waiting for this"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _run_tool(server, "events_wait", {"after_cursor": "0", "timeout_ms": "bad"})
|
||||||
|
assert result["event"] is not None
|
||||||
|
assert result["event"]["content"] == "waiting for this"
|
||||||
|
|
||||||
|
|
||||||
class TestE2EMessagesSend:
|
class TestE2EMessagesSend:
|
||||||
def test_send_missing_args(self, mcp_server_e2e, _event_loop):
|
def test_send_missing_args(self, mcp_server_e2e, _event_loop):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue