fix(mcp): coerce numeric tool args defensively

This commit is contained in:
qWaitCrypto 2026-05-07 14:05:26 +08:00 committed by Teknium
parent 43cf72a458
commit 62c2f5d8d2
2 changed files with 132 additions and 1 deletions

View file

@ -115,6 +115,25 @@ def _load_channel_directory() -> dict:
return {}
def _coerce_int(
value,
*,
default: int,
minimum: int,
maximum: int,
) -> int:
"""Coerce value to int with fallback and clamping.
Used at MCP tool boundaries to handle invalid types from external clients.
Returns default if value cannot be converted to int.
"""
try:
coerced = int(value)
except (TypeError, ValueError):
coerced = default
return max(minimum, min(coerced, maximum))
def _extract_message_content(msg: dict) -> str:
"""Extract text content from a message, handling multi-part content."""
content = msg.get("content", "")
@ -465,6 +484,7 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
limit: Maximum number of conversations to return (default 50)
search: Optional text to filter conversations by name
"""
limit = _coerce_int(limit, default=50, minimum=1, maximum=200)
entries = _load_sessions_index()
conversations = []
@ -552,6 +572,7 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
session_key: The session key from conversations_list
limit: Maximum number of messages to return (default 50, most recent)
"""
limit = _coerce_int(limit, default=50, minimum=1, maximum=200)
entries = _load_sessions_index()
entry = entries.get(session_key)
if not entry:
@ -664,6 +685,8 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
session_key: Optional filter to one conversation
limit: Maximum events to return (default 20)
"""
after_cursor = _coerce_int(after_cursor, default=0, minimum=0, maximum=10**18)
limit = _coerce_int(limit, default=20, minimum=1, maximum=200)
result = bridge.poll_events(
after_cursor=after_cursor,
session_key=session_key,
@ -689,10 +712,17 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
session_key: Optional filter to one conversation
timeout_ms: Maximum wait time in milliseconds (default 30000)
"""
after_cursor = _coerce_int(after_cursor, default=0, minimum=0, maximum=10**18)
timeout_ms = _coerce_int(
timeout_ms,
default=30000,
minimum=0,
maximum=300000,
) # Cap at 5 minutes
event = bridge.wait_for_event(
after_cursor=after_cursor,
session_key=session_key,
timeout_ms=min(timeout_ms, 300000), # Cap at 5 minutes
timeout_ms=timeout_ms,
)
if event:
return json.dumps({"event": event}, indent=2)

View file

@ -9,6 +9,7 @@ Three layers of tests:
"""
import asyncio
import inspect
import json
import os
import sqlite3
@ -207,6 +208,54 @@ def mock_session_db(tmp_path, populated_sessions_dir):
return TestSessionDB()
class _FakeTool:
def __init__(self, fn):
self.name = fn.__name__
self.description = inspect.getdoc(fn) or ""
self.fn = fn
class _FakeToolManager:
def __init__(self):
self._tools = {}
def add_tool(self, fn):
self._tools[fn.__name__] = _FakeTool(fn)
async def call_tool(self, name, args=None):
return self._tools[name].fn(**(args or {}))
def list_tools(self):
return list(self._tools.values())
class _FakeFastMCP:
def __init__(self, *args, **kwargs):
self._tool_manager = _FakeToolManager()
def tool(self):
def decorator(fn):
self._tool_manager.add_tool(fn)
return fn
return decorator
@pytest.fixture
def fake_mcp_server(populated_sessions_dir, mock_session_db, monkeypatch):
import mcp_serve
monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: populated_sessions_dir)
monkeypatch.setattr(mcp_serve, "_get_session_db", lambda: mock_session_db)
monkeypatch.setattr(mcp_serve, "_load_channel_directory", lambda: {})
monkeypatch.setattr(mcp_serve, "_MCP_SERVER_AVAILABLE", True)
monkeypatch.setattr(mcp_serve, "FastMCP", _FakeFastMCP)
bridge = mcp_serve.EventBridge()
server = mcp_serve.create_mcp_server(event_bridge=bridge)
return server, bridge
# ---------------------------------------------------------------------------
# 1. UNIT TESTS — helpers, extraction, attachments
# ---------------------------------------------------------------------------
@ -229,6 +278,15 @@ class TestHelpers:
result = _get_sessions_dir()
assert result == tmp_path / "sessions"
def test_coerce_int_handles_invalid_and_out_of_range_values(self):
from mcp_serve import _coerce_int
assert _coerce_int(None, default=50, minimum=1, maximum=200) == 50
assert _coerce_int("20", default=50, minimum=1, maximum=200) == 20
assert _coerce_int("bad", default=50, minimum=1, maximum=200) == 50
assert _coerce_int(999, default=50, minimum=1, maximum=200) == 200
assert _coerce_int(-5, default=50, minimum=1, maximum=200) == 1
def test_load_sessions_index_empty(self, sessions_dir, monkeypatch):
import mcp_serve
monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: sessions_dir)
@ -689,6 +747,49 @@ class TestE2EEventsWait:
result = _run_tool(server, "events_wait", {"timeout_ms": 999999})
assert result["event"] is not None
class TestMCPToolParameterCoercion:
def test_conversations_list_coerces_string_limit(self, fake_mcp_server, _event_loop):
server, _ = fake_mcp_server
result = _run_tool(server, "conversations_list", {"limit": "2"})
assert result["count"] == 2
def test_messages_read_coerces_string_limit(self, fake_mcp_server, _event_loop):
server, _ = fake_mcp_server
result = _run_tool(
server,
"messages_read",
{"session_key": "agent:main:telegram:dm:123456", "limit": "2"},
)
assert result["count"] == 2
def test_events_poll_coerces_string_cursor_and_limit(self, fake_mcp_server, _event_loop):
from mcp_serve import QueueEvent
server, bridge = fake_mcp_server
bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="a"))
bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="b"))
result = _run_tool(server, "events_poll", {"after_cursor": "0", "limit": "1"})
assert len(result["events"]) == 1
assert result["next_cursor"] == 1
def test_events_wait_coerces_invalid_timeout(self, fake_mcp_server, _event_loop):
from mcp_serve import QueueEvent
server, bridge = fake_mcp_server
bridge._enqueue(
QueueEvent(
cursor=0,
type="message",
session_key="test",
data={"content": "waiting for this"},
)
)
result = _run_tool(server, "events_wait", {"after_cursor": "0", "timeout_ms": "bad"})
assert result["event"] is not None
assert result["event"]["content"] == "waiting for this"
class TestE2EMessagesSend:
def test_send_missing_args(self, mcp_server_e2e, _event_loop):