fix(mcp): coerce numeric tool args defensively

This commit is contained in:
qWaitCrypto 2026-05-07 14:05:26 +08:00 committed by Teknium
parent 43cf72a458
commit 62c2f5d8d2
2 changed files with 132 additions and 1 deletions

View file

@ -9,6 +9,7 @@ Three layers of tests:
"""
import asyncio
import inspect
import json
import os
import sqlite3
@ -207,6 +208,54 @@ def mock_session_db(tmp_path, populated_sessions_dir):
return TestSessionDB()
class _FakeTool:
def __init__(self, fn):
self.name = fn.__name__
self.description = inspect.getdoc(fn) or ""
self.fn = fn
class _FakeToolManager:
def __init__(self):
self._tools = {}
def add_tool(self, fn):
self._tools[fn.__name__] = _FakeTool(fn)
async def call_tool(self, name, args=None):
return self._tools[name].fn(**(args or {}))
def list_tools(self):
return list(self._tools.values())
class _FakeFastMCP:
def __init__(self, *args, **kwargs):
self._tool_manager = _FakeToolManager()
def tool(self):
def decorator(fn):
self._tool_manager.add_tool(fn)
return fn
return decorator
@pytest.fixture
def fake_mcp_server(populated_sessions_dir, mock_session_db, monkeypatch):
import mcp_serve
monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: populated_sessions_dir)
monkeypatch.setattr(mcp_serve, "_get_session_db", lambda: mock_session_db)
monkeypatch.setattr(mcp_serve, "_load_channel_directory", lambda: {})
monkeypatch.setattr(mcp_serve, "_MCP_SERVER_AVAILABLE", True)
monkeypatch.setattr(mcp_serve, "FastMCP", _FakeFastMCP)
bridge = mcp_serve.EventBridge()
server = mcp_serve.create_mcp_server(event_bridge=bridge)
return server, bridge
# ---------------------------------------------------------------------------
# 1. UNIT TESTS — helpers, extraction, attachments
# ---------------------------------------------------------------------------
@ -229,6 +278,15 @@ class TestHelpers:
result = _get_sessions_dir()
assert result == tmp_path / "sessions"
def test_coerce_int_handles_invalid_and_out_of_range_values(self):
from mcp_serve import _coerce_int
assert _coerce_int(None, default=50, minimum=1, maximum=200) == 50
assert _coerce_int("20", default=50, minimum=1, maximum=200) == 20
assert _coerce_int("bad", default=50, minimum=1, maximum=200) == 50
assert _coerce_int(999, default=50, minimum=1, maximum=200) == 200
assert _coerce_int(-5, default=50, minimum=1, maximum=200) == 1
def test_load_sessions_index_empty(self, sessions_dir, monkeypatch):
import mcp_serve
monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: sessions_dir)
@ -689,6 +747,49 @@ class TestE2EEventsWait:
result = _run_tool(server, "events_wait", {"timeout_ms": 999999})
assert result["event"] is not None
class TestMCPToolParameterCoercion:
def test_conversations_list_coerces_string_limit(self, fake_mcp_server, _event_loop):
server, _ = fake_mcp_server
result = _run_tool(server, "conversations_list", {"limit": "2"})
assert result["count"] == 2
def test_messages_read_coerces_string_limit(self, fake_mcp_server, _event_loop):
server, _ = fake_mcp_server
result = _run_tool(
server,
"messages_read",
{"session_key": "agent:main:telegram:dm:123456", "limit": "2"},
)
assert result["count"] == 2
def test_events_poll_coerces_string_cursor_and_limit(self, fake_mcp_server, _event_loop):
from mcp_serve import QueueEvent
server, bridge = fake_mcp_server
bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="a"))
bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="b"))
result = _run_tool(server, "events_poll", {"after_cursor": "0", "limit": "1"})
assert len(result["events"]) == 1
assert result["next_cursor"] == 1
def test_events_wait_coerces_invalid_timeout(self, fake_mcp_server, _event_loop):
from mcp_serve import QueueEvent
server, bridge = fake_mcp_server
bridge._enqueue(
QueueEvent(
cursor=0,
type="message",
session_key="test",
data={"content": "waiting for this"},
)
)
result = _run_tool(server, "events_wait", {"after_cursor": "0", "timeout_ms": "bad"})
assert result["event"] is not None
assert result["event"]["content"] == "waiting for this"
class TestE2EMessagesSend:
def test_send_missing_args(self, mcp_server_e2e, _event_loop):