mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-08 03:01:47 +00:00
fix(mcp): coerce numeric tool args defensively
This commit is contained in:
parent
43cf72a458
commit
62c2f5d8d2
2 changed files with 132 additions and 1 deletions
|
|
@ -9,6 +9,7 @@ Three layers of tests:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
|
|
@ -207,6 +208,54 @@ def mock_session_db(tmp_path, populated_sessions_dir):
|
|||
return TestSessionDB()
|
||||
|
||||
|
||||
class _FakeTool:
|
||||
def __init__(self, fn):
|
||||
self.name = fn.__name__
|
||||
self.description = inspect.getdoc(fn) or ""
|
||||
self.fn = fn
|
||||
|
||||
|
||||
class _FakeToolManager:
|
||||
def __init__(self):
|
||||
self._tools = {}
|
||||
|
||||
def add_tool(self, fn):
|
||||
self._tools[fn.__name__] = _FakeTool(fn)
|
||||
|
||||
async def call_tool(self, name, args=None):
|
||||
return self._tools[name].fn(**(args or {}))
|
||||
|
||||
def list_tools(self):
|
||||
return list(self._tools.values())
|
||||
|
||||
|
||||
class _FakeFastMCP:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._tool_manager = _FakeToolManager()
|
||||
|
||||
def tool(self):
|
||||
def decorator(fn):
|
||||
self._tool_manager.add_tool(fn)
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_mcp_server(populated_sessions_dir, mock_session_db, monkeypatch):
|
||||
import mcp_serve
|
||||
|
||||
monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: populated_sessions_dir)
|
||||
monkeypatch.setattr(mcp_serve, "_get_session_db", lambda: mock_session_db)
|
||||
monkeypatch.setattr(mcp_serve, "_load_channel_directory", lambda: {})
|
||||
monkeypatch.setattr(mcp_serve, "_MCP_SERVER_AVAILABLE", True)
|
||||
monkeypatch.setattr(mcp_serve, "FastMCP", _FakeFastMCP)
|
||||
|
||||
bridge = mcp_serve.EventBridge()
|
||||
server = mcp_serve.create_mcp_server(event_bridge=bridge)
|
||||
return server, bridge
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. UNIT TESTS — helpers, extraction, attachments
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -229,6 +278,15 @@ class TestHelpers:
|
|||
result = _get_sessions_dir()
|
||||
assert result == tmp_path / "sessions"
|
||||
|
||||
def test_coerce_int_handles_invalid_and_out_of_range_values(self):
|
||||
from mcp_serve import _coerce_int
|
||||
|
||||
assert _coerce_int(None, default=50, minimum=1, maximum=200) == 50
|
||||
assert _coerce_int("20", default=50, minimum=1, maximum=200) == 20
|
||||
assert _coerce_int("bad", default=50, minimum=1, maximum=200) == 50
|
||||
assert _coerce_int(999, default=50, minimum=1, maximum=200) == 200
|
||||
assert _coerce_int(-5, default=50, minimum=1, maximum=200) == 1
|
||||
|
||||
def test_load_sessions_index_empty(self, sessions_dir, monkeypatch):
|
||||
import mcp_serve
|
||||
monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: sessions_dir)
|
||||
|
|
@ -689,6 +747,49 @@ class TestE2EEventsWait:
|
|||
result = _run_tool(server, "events_wait", {"timeout_ms": 999999})
|
||||
assert result["event"] is not None
|
||||
|
||||
class TestMCPToolParameterCoercion:
|
||||
def test_conversations_list_coerces_string_limit(self, fake_mcp_server, _event_loop):
|
||||
server, _ = fake_mcp_server
|
||||
result = _run_tool(server, "conversations_list", {"limit": "2"})
|
||||
assert result["count"] == 2
|
||||
|
||||
def test_messages_read_coerces_string_limit(self, fake_mcp_server, _event_loop):
|
||||
server, _ = fake_mcp_server
|
||||
result = _run_tool(
|
||||
server,
|
||||
"messages_read",
|
||||
{"session_key": "agent:main:telegram:dm:123456", "limit": "2"},
|
||||
)
|
||||
assert result["count"] == 2
|
||||
|
||||
def test_events_poll_coerces_string_cursor_and_limit(self, fake_mcp_server, _event_loop):
|
||||
from mcp_serve import QueueEvent
|
||||
|
||||
server, bridge = fake_mcp_server
|
||||
bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="a"))
|
||||
bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="b"))
|
||||
|
||||
result = _run_tool(server, "events_poll", {"after_cursor": "0", "limit": "1"})
|
||||
assert len(result["events"]) == 1
|
||||
assert result["next_cursor"] == 1
|
||||
|
||||
def test_events_wait_coerces_invalid_timeout(self, fake_mcp_server, _event_loop):
|
||||
from mcp_serve import QueueEvent
|
||||
|
||||
server, bridge = fake_mcp_server
|
||||
bridge._enqueue(
|
||||
QueueEvent(
|
||||
cursor=0,
|
||||
type="message",
|
||||
session_key="test",
|
||||
data={"content": "waiting for this"},
|
||||
)
|
||||
)
|
||||
|
||||
result = _run_tool(server, "events_wait", {"after_cursor": "0", "timeout_ms": "bad"})
|
||||
assert result["event"] is not None
|
||||
assert result["event"]["content"] == "waiting for this"
|
||||
|
||||
|
||||
class TestE2EMessagesSend:
|
||||
def test_send_missing_args(self, mcp_server_e2e, _event_loop):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue