From 62c2f5d8d2a6a21adfdea2d8d1f28fd8f04b5dd7 Mon Sep 17 00:00:00 2001 From: qWaitCrypto Date: Thu, 7 May 2026 14:05:26 +0800 Subject: [PATCH] fix(mcp): coerce numeric tool args defensively --- mcp_serve.py | 32 ++++++++++++- tests/test_mcp_serve.py | 101 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 1 deletion(-) diff --git a/mcp_serve.py b/mcp_serve.py index e0aeb70619..d895120b18 100644 --- a/mcp_serve.py +++ b/mcp_serve.py @@ -115,6 +115,25 @@ def _load_channel_directory() -> dict: return {} +def _coerce_int( + value, + *, + default: int, + minimum: int, + maximum: int, +) -> int: + """Coerce value to int with fallback and clamping. + + Used at MCP tool boundaries to handle invalid types from external clients. + Returns default if value cannot be converted to int. + """ + try: + coerced = int(value) + except (TypeError, ValueError): + coerced = default + return max(minimum, min(coerced, maximum)) + + def _extract_message_content(msg: dict) -> str: """Extract text content from a message, handling multi-part content.""" content = msg.get("content", "") @@ -465,6 +484,7 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP": limit: Maximum number of conversations to return (default 50) search: Optional text to filter conversations by name """ + limit = _coerce_int(limit, default=50, minimum=1, maximum=200) entries = _load_sessions_index() conversations = [] @@ -552,6 +572,7 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP": session_key: The session key from conversations_list limit: Maximum number of messages to return (default 50, most recent) """ + limit = _coerce_int(limit, default=50, minimum=1, maximum=200) entries = _load_sessions_index() entry = entries.get(session_key) if not entry: @@ -664,6 +685,8 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP": session_key: Optional filter to one conversation limit: Maximum events to return (default 20) """ + after_cursor = _coerce_int(after_cursor, default=0, minimum=0, maximum=10**18) + limit = _coerce_int(limit, default=20, minimum=1, maximum=200) result = bridge.poll_events( after_cursor=after_cursor, session_key=session_key, @@ -689,10 +712,17 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP": session_key: Optional filter to one conversation timeout_ms: Maximum wait time in milliseconds (default 30000) """ + after_cursor = _coerce_int(after_cursor, default=0, minimum=0, maximum=10**18) + timeout_ms = _coerce_int( + timeout_ms, + default=30000, + minimum=0, + maximum=300000, + ) # Cap at 5 minutes event = bridge.wait_for_event( after_cursor=after_cursor, session_key=session_key, - timeout_ms=min(timeout_ms, 300000), # Cap at 5 minutes + timeout_ms=timeout_ms, ) if event: return json.dumps({"event": event}, indent=2) diff --git a/tests/test_mcp_serve.py b/tests/test_mcp_serve.py index 9dc013cace..db82fa7882 100644 --- a/tests/test_mcp_serve.py +++ b/tests/test_mcp_serve.py @@ -9,6 +9,7 @@ Three layers of tests: """ import asyncio +import inspect import json import os import sqlite3 @@ -207,6 +208,54 @@ def mock_session_db(tmp_path, populated_sessions_dir): return TestSessionDB() +class _FakeTool: + def __init__(self, fn): + self.name = fn.__name__ + self.description = inspect.getdoc(fn) or "" + self.fn = fn + + +class _FakeToolManager: + def __init__(self): + self._tools = {} + + def add_tool(self, fn): + self._tools[fn.__name__] = _FakeTool(fn) + + async def call_tool(self, name, args=None): + return self._tools[name].fn(**(args or {})) + + def list_tools(self): + return list(self._tools.values()) + + +class _FakeFastMCP: + def __init__(self, *args, **kwargs): + self._tool_manager = _FakeToolManager() + + def tool(self): + def decorator(fn): + self._tool_manager.add_tool(fn) + return fn + + return decorator + + +@pytest.fixture +def fake_mcp_server(populated_sessions_dir, mock_session_db, monkeypatch): + import mcp_serve + + monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: populated_sessions_dir) + monkeypatch.setattr(mcp_serve, "_get_session_db", lambda: mock_session_db) + monkeypatch.setattr(mcp_serve, "_load_channel_directory", lambda: {}) + monkeypatch.setattr(mcp_serve, "_MCP_SERVER_AVAILABLE", True) + monkeypatch.setattr(mcp_serve, "FastMCP", _FakeFastMCP) + + bridge = mcp_serve.EventBridge() + server = mcp_serve.create_mcp_server(event_bridge=bridge) + return server, bridge + + # --------------------------------------------------------------------------- # 1. UNIT TESTS — helpers, extraction, attachments # --------------------------------------------------------------------------- @@ -229,6 +278,15 @@ class TestHelpers: result = _get_sessions_dir() assert result == tmp_path / "sessions" + def test_coerce_int_handles_invalid_and_out_of_range_values(self): + from mcp_serve import _coerce_int + + assert _coerce_int(None, default=50, minimum=1, maximum=200) == 50 + assert _coerce_int("20", default=50, minimum=1, maximum=200) == 20 + assert _coerce_int("bad", default=50, minimum=1, maximum=200) == 50 + assert _coerce_int(999, default=50, minimum=1, maximum=200) == 200 + assert _coerce_int(-5, default=50, minimum=1, maximum=200) == 1 + def test_load_sessions_index_empty(self, sessions_dir, monkeypatch): import mcp_serve monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: sessions_dir) @@ -689,6 +747,49 @@ class TestE2EEventsWait: result = _run_tool(server, "events_wait", {"timeout_ms": 999999}) assert result["event"] is not None +class TestMCPToolParameterCoercion: + def test_conversations_list_coerces_string_limit(self, fake_mcp_server, _event_loop): + server, _ = fake_mcp_server + result = _run_tool(server, "conversations_list", {"limit": "2"}) + assert result["count"] == 2 + + def test_messages_read_coerces_string_limit(self, fake_mcp_server, _event_loop): + server, _ = fake_mcp_server + result = _run_tool( + server, + "messages_read", + {"session_key": "agent:main:telegram:dm:123456", "limit": "2"}, + ) + assert result["count"] == 2 + + def test_events_poll_coerces_string_cursor_and_limit(self, fake_mcp_server, _event_loop): + from mcp_serve import QueueEvent + + server, bridge = fake_mcp_server + bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="a")) + bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="b")) + + result = _run_tool(server, "events_poll", {"after_cursor": "0", "limit": "1"}) + assert len(result["events"]) == 1 + assert result["next_cursor"] == 1 + + def test_events_wait_coerces_invalid_timeout(self, fake_mcp_server, _event_loop): + from mcp_serve import QueueEvent + + server, bridge = fake_mcp_server + bridge._enqueue( + QueueEvent( + cursor=0, + type="message", + session_key="test", + data={"content": "waiting for this"}, + ) + ) + + result = _run_tool(server, "events_wait", {"after_cursor": "0", "timeout_ms": "bad"}) + assert result["event"] is not None + assert result["event"]["content"] == "waiting for this" + class TestE2EMessagesSend: def test_send_missing_args(self, mcp_server_e2e, _event_loop):