hermes-agent/tests/tools/test_mcp_tool_session_expired.py
Brian D. Evans e87a2100f6 fix(mcp): auto-reconnect + retry once when the transport session expires (#13383)
Streamable HTTP MCP servers may garbage-collect their server-side
session state while the OAuth token remains valid — idle TTL, server
restart, pod rotation, etc.  Before this fix, the tool-call handler
treated the resulting "Invalid or expired session" error as a plain
tool failure with no recovery path, so **every subsequent call on
the affected server failed until the gateway was manually
restarted**.  Reporter: #13383.

The OAuth-based recovery path (``_handle_auth_error_and_retry``)
already exists for 401s, but it only fires on auth errors.  Session
expiry slipped through because the access token is still valid —
nothing 401'd, so the existing recovery branch was skipped.

Fix
---
Add a sibling function ``_handle_session_expired_and_retry`` that
detects MCP session-expiry via ``_is_session_expired_error`` (a
narrow allow-list of known-stable substrings: ``"invalid or expired
session"``, ``"session expired"``, ``"session not found"``,
``"unknown session"``, etc.) and then uses the existing transport
reconnect mechanism:

* Sets ``MCPServerTask._reconnect_event`` — the server task's
  lifecycle loop already interprets this as "tear down the current
  ``streamablehttp_client`` + ``ClientSession`` and rebuild them,
  reusing the existing OAuth provider instance".
* Waits up to 15 s for the new session to come back ready.
* Retries the original call once.  If the retry succeeds, returns
  its result and resets the circuit-breaker error count.  If the
  retry raises, or if the reconnect does not become ready in time,
  falls through to the caller's generic error path.
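
The steps above can be sketched as follows.  This is a minimal
illustration under stated assumptions, not the code in
``tools/mcp_tool.py``: ``StubServer``, ``reconnect_and_retry_once``
and the use of ``threading.Event`` for both signals are hypothetical
stand-ins.  Only the 15 s wait, the single retry, and the ``None``
fall-through come from the description above; the circuit-breaker
reset is left out.

```python
import threading
from typing import Callable, Optional


class StubServer:
    """Hypothetical stand-in for a server task; not the real MCPServerTask."""

    def __init__(self) -> None:
        self.reconnect_event = threading.Event()  # asks the lifecycle loop to rebuild
        self.ready = threading.Event()            # set once the new session is usable


def reconnect_and_retry_once(
    server: Optional[StubServer],
    retry_call: Callable[[], str],
    timeout: float = 15.0,
) -> Optional[str]:
    """Signal a reconnect, wait for readiness, then retry exactly once.

    Returns the retry's result, or None so the caller can fall through
    to its generic error path.
    """
    if server is None:
        return None  # torn-down server: nothing to reconnect to

    server.ready.clear()          # assumption: readiness is re-signalled after the rebuild
    server.reconnect_event.set()  # step 1: request teardown + rebuild

    if not server.ready.wait(timeout):  # step 2: bounded wait for the new session
        return None

    try:
        return retry_call()  # step 3: one retry, no loop
    except Exception:
        return None
```

Returning ``None`` on every failure branch keeps the helper composable:
the caller only has to check for ``None`` before running its existing
error handling.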

Unlike the 401 path, this does **not** call ``handle_401`` — the
access token is already valid and running an OAuth refresh would be
a pointless round-trip.

All 5 MCP handlers (``call_tool``, ``list_resources``, ``read_resource``,
``list_prompts``, ``get_prompt``) now consult both recovery paths
before falling through:

    recovered = _handle_auth_error_and_retry(...)          # 401 path
    if recovered is not None: return recovered
    recovered = _handle_session_expired_and_retry(...)     # new
    if recovered is not None: return recovered
    # generic error response

Narrow scope — explicitly not changed
-------------------------------------
* **Detection is string-based on a 5-entry allow-list.**  The MCP
  SDK wraps JSON-RPC errors in ``McpError`` whose exception type +
  attributes vary across SDK versions, so matching on message
  substrings is the durable path.  Kept narrow to avoid false
  positives — a regular ``RuntimeError("Tool failed")`` will NOT
  trigger spurious reconnects (pinned by
  ``test_is_session_expired_rejects_unrelated_errors``).
* **No change to the existing 401 recovery flow.**  The new path is
  consulted only after the auth path declines (returns ``None``).
* **Retry count stays at 1.**  If the reconnect-then-retry also
  fails, we don't loop — the error surfaces normally so the model
  sees a failed tool call rather than a hang.
* **``InterruptedError`` is explicitly excluded** from session-expired
  detection so user-cancel signals always short-circuit the same
  way they did before (pinned by
  ``test_is_session_expired_rejects_interrupted_error``).
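
The detection rules above can be illustrated with a short sketch.
The function name and the marker tuple are assumptions for the
example: the tuple holds only the four substrings quoted in this
message, not the full allow-list in ``tools/mcp_tool.py``.

```python
# Hypothetical marker list: just the substrings quoted in this message.
SESSION_EXPIRED_MARKERS = (
    "invalid or expired session",
    "session expired",
    "session not found",
    "unknown session",
)


def looks_like_session_expired(exc: BaseException) -> bool:
    """Return True only for messages containing a known session-expiry marker."""
    # User-cancel signals must never be routed through the reconnect path.
    if isinstance(exc, InterruptedError):
        return False
    message = str(exc).lower()  # case-insensitive substring match
    if not message:
        return False
    return any(marker in message for marker in SESSION_EXPIRED_MARKERS)
```

Checking the exception type before the message is what keeps an
``InterruptedError`` carrying session-expiry text from being treated
as a reconnect trigger.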

Regression coverage
-------------------
``tests/tools/test_mcp_tool_session_expired.py`` (new, 16 cases):

Unit tests for ``_is_session_expired_error``:
* ``test_is_session_expired_detects_invalid_or_expired_session`` —
  reporter's exact wpcom-mcp text.
* ``test_is_session_expired_detects_expired_session_variant`` —
  "Session expired" / "expired session" variants.
* ``test_is_session_expired_detects_session_not_found`` — server GC
  variant ("session not found", "unknown session").
* ``test_is_session_expired_is_case_insensitive``.
* ``test_is_session_expired_rejects_unrelated_errors`` — narrow-scope
  canary: random RuntimeError / ValueError / 401 don't trigger.
* ``test_is_session_expired_rejects_interrupted_error`` — user cancel
  must never route through reconnect.
* ``test_is_session_expired_rejects_empty_message``.

Handler integration tests:
* ``test_call_tool_handler_reconnects_on_session_expired`` — reporter's
  full repro: first call raises "Invalid or expired session", handler
  signals ``_reconnect_event``, retries once, returns the retry's
  success result with no ``error`` key.
* ``test_call_tool_handler_non_session_expired_error_falls_through``
  — preserved-behaviour canary: random tool failures do NOT trigger
  reconnect.
* ``test_session_expired_handler_returns_none_without_loop`` —
  defensive: cold-start / shutdown race.
* ``test_session_expired_handler_returns_none_without_server_record``
  — torn-down server falls through cleanly.
* ``test_session_expired_handler_returns_none_when_retry_also_fails``
  — no retry loop on repeated failure.

Parametrised across all 4 non-``tools/call`` handlers:
* ``test_non_tool_handlers_also_reconnect_on_session_expired``
  [list_resources / read_resource / list_prompts / get_prompt].

**15 of 16 fail on clean ``origin/main`` (``6fb69229``)** with
``ImportError: cannot import name '_is_session_expired_error'``
— the fix's surface symbols don't exist there yet.  The 1 passing
test is an ordering artefact of pytest-xdist worker collection.

Validation
----------
``source venv/bin/activate && python -m pytest
tests/tools/test_mcp_tool_session_expired.py -q`` → **16 passed**.

Broader MCP suite (5 files:
``test_mcp_tool.py``, ``test_mcp_tool_401_handling.py``,
``test_mcp_tool_session_expired.py``, ``test_mcp_reconnect_signal.py``,
``test_mcp_oauth.py``) → **230 passed, 0 regressions**.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-24 05:28:45 -07:00


"""Tests for MCP tool-handler transport-session auto-reconnect.

When a Streamable HTTP MCP server garbage-collects its server-side
session (idle TTL, server restart, pod rotation, …) it rejects
subsequent requests with a JSON-RPC error containing phrases like
``"Invalid or expired session"``. The OAuth token remains valid —
only the transport session state needs rebuilding.

Before the #13383 fix, this class of failure fell through as a plain
tool error with no recovery path, so every subsequent call on the
affected MCP server failed until the gateway was manually restarted.
"""
import json
import threading
import time
from unittest.mock import AsyncMock, MagicMock

import pytest


# ---------------------------------------------------------------------------
# _is_session_expired_error — unit coverage
# ---------------------------------------------------------------------------


def test_is_session_expired_detects_invalid_or_expired_session():
    """Reporter's exact wpcom-mcp error message (#13383)."""
    from tools.mcp_tool import _is_session_expired_error

    exc = RuntimeError("Invalid params: Invalid or expired session")
    assert _is_session_expired_error(exc) is True


def test_is_session_expired_detects_expired_session_variant():
    """Generic ``session expired`` / ``expired session`` phrasings used
    by other SDK servers."""
    from tools.mcp_tool import _is_session_expired_error

    assert _is_session_expired_error(RuntimeError("Session expired")) is True
    assert _is_session_expired_error(RuntimeError("expired session: abc")) is True


def test_is_session_expired_detects_session_not_found():
    """Server-side GC produces ``session not found`` / ``unknown session``
    on some implementations."""
    from tools.mcp_tool import _is_session_expired_error

    assert _is_session_expired_error(RuntimeError("session not found")) is True
    assert _is_session_expired_error(RuntimeError("Unknown session: abc123")) is True


def test_is_session_expired_is_case_insensitive():
    """Match uses lower-cased comparison so servers that emit the
    message in different cases (SDK formatter quirks) still trigger."""
    from tools.mcp_tool import _is_session_expired_error

    assert _is_session_expired_error(RuntimeError("INVALID OR EXPIRED SESSION")) is True
    assert _is_session_expired_error(RuntimeError("Session Expired")) is True


def test_is_session_expired_rejects_unrelated_errors():
    """Narrow scope: only the specific session-expired markers trigger.
    A regular RuntimeError / ValueError does not."""
    from tools.mcp_tool import _is_session_expired_error

    assert _is_session_expired_error(RuntimeError("Tool failed to execute")) is False
    assert _is_session_expired_error(ValueError("Missing parameter")) is False
    assert _is_session_expired_error(Exception("Connection refused")) is False
    # 401 is handled by the sibling _is_auth_error path, not here.
    assert _is_session_expired_error(RuntimeError("401 Unauthorized")) is False


def test_is_session_expired_rejects_interrupted_error():
    """InterruptedError is the user-cancel signal — must never route
    through the session-reconnect path."""
    from tools.mcp_tool import _is_session_expired_error

    assert _is_session_expired_error(InterruptedError()) is False
    assert _is_session_expired_error(InterruptedError("Invalid or expired session")) is False


def test_is_session_expired_rejects_empty_message():
    """Bare exceptions with no message shouldn't match."""
    from tools.mcp_tool import _is_session_expired_error

    assert _is_session_expired_error(RuntimeError("")) is False
    assert _is_session_expired_error(Exception()) is False


# ---------------------------------------------------------------------------
# Handler integration — verify the recovery plumbing wires end-to-end
# ---------------------------------------------------------------------------


def _install_stub_server(name: str = "wpcom"):
    """Register a minimal server stub that _handle_session_expired_and_retry
    can signal via _reconnect_event, and that reports ready+session after
    the event fires."""
    from tools import mcp_tool

    mcp_tool._ensure_mcp_loop()

    server = MagicMock()
    server.name = name
    # _reconnect_event is called via loop.call_soon_threadsafe(…set); use
    # a threading-safe substitute.
    reconnect_flag = threading.Event()

    class _EventAdapter:
        def set(self):
            reconnect_flag.set()

    server._reconnect_event = _EventAdapter()

    # Immediately "ready" — simulates a fast reconnect (_ready.is_set()
    # is polled by _handle_session_expired_and_retry until the timeout).
    ready_flag = threading.Event()
    ready_flag.set()
    server._ready = MagicMock()
    server._ready.is_set = ready_flag.is_set

    # session attr must be truthy for the handler's initial check
    # (``if not server or not server.session``) and for the post-
    # reconnect readiness probe (``srv.session is not None``).
    server.session = MagicMock()
    return server, reconnect_flag


def test_call_tool_handler_reconnects_on_session_expired(monkeypatch, tmp_path):
    """Reporter's exact repro: call_tool raises "Invalid or expired
    session", handler triggers reconnect, retries once, and returns
    the retry's successful JSON (not the generic error)."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools import mcp_tool
    from tools.mcp_tool import _make_tool_handler

    server, reconnect_flag = _install_stub_server("wpcom")
    mcp_tool._servers["wpcom"] = server
    mcp_tool._server_error_counts.pop("wpcom", None)

    # First call raises session-expired; second call (post-reconnect)
    # returns a proper MCP tool result.
    call_count = {"n": 0}

    async def _call_sequence(*a, **kw):
        call_count["n"] += 1
        if call_count["n"] == 1:
            raise RuntimeError("Invalid params: Invalid or expired session")
        # Second call: mimic the MCP SDK's structured success response.
        result = MagicMock()
        result.isError = False
        result.content = [MagicMock(type="text", text="tool completed")]
        result.structuredContent = None
        return result

    server.session.call_tool = _call_sequence

    try:
        handler = _make_tool_handler("wpcom", "wpcom-mcp-content-authoring", 10.0)
        out = handler({"slug": "hello"})
        parsed = json.loads(out)
        # Retry succeeded — no error surfaced to caller.
        assert "error" not in parsed, (
            f"Expected retry to succeed after reconnect; got: {parsed}"
        )
        # _reconnect_event was signalled (the flag cannot count repeats).
        assert reconnect_flag.is_set(), (
            "Handler did not trigger transport reconnect on session-expired "
            "error — the reconnect flow is the whole point of this fix."
        )
        # Exactly 2 call attempts (original + one retry).
        assert call_count["n"] == 2, (
            f"Expected 1 original + 1 retry = 2 calls; got {call_count['n']}"
        )
    finally:
        mcp_tool._servers.pop("wpcom", None)
        mcp_tool._server_error_counts.pop("wpcom", None)


def test_call_tool_handler_non_session_expired_error_falls_through(
    monkeypatch, tmp_path
):
    """Preserved-behaviour canary: a non-session-expired exception must
    NOT trigger reconnect — it must fall through to the generic error
    path so the caller sees the real failure."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools import mcp_tool
    from tools.mcp_tool import _make_tool_handler

    server, reconnect_flag = _install_stub_server("srv")
    mcp_tool._servers["srv"] = server
    mcp_tool._server_error_counts.pop("srv", None)

    async def _raises(*a, **kw):
        raise RuntimeError("Tool execution failed — unrelated error")

    server.session.call_tool = _raises

    try:
        handler = _make_tool_handler("srv", "mytool", 10.0)
        out = handler({"arg": "v"})
        parsed = json.loads(out)
        # Generic error path surfaced the failure.
        assert "MCP call failed" in parsed.get("error", "")
        # Reconnect was NOT triggered for this unrelated failure.
        assert not reconnect_flag.is_set(), (
            "Reconnect must not fire for non-session-expired errors — "
            "this would cause spurious transport churn on every tool "
            "failure."
        )
    finally:
        mcp_tool._servers.pop("srv", None)
        mcp_tool._server_error_counts.pop("srv", None)


def test_session_expired_handler_returns_none_without_loop(monkeypatch):
    """Defensive: if the MCP loop isn't running (cold start / shutdown
    race), the handler must fall through cleanly instead of hanging
    or raising."""
    from tools import mcp_tool
    from tools.mcp_tool import _handle_session_expired_and_retry

    # Install a server stub but make the event loop unavailable.
    server = MagicMock()
    server._reconnect_event = MagicMock()
    server._ready = MagicMock()
    server._ready.is_set = MagicMock(return_value=True)
    server.session = MagicMock()
    mcp_tool._servers["srv-noloop"] = server
    monkeypatch.setattr(mcp_tool, "_mcp_loop", None)

    try:
        out = _handle_session_expired_and_retry(
            "srv-noloop",
            RuntimeError("Invalid or expired session"),
            lambda: '{"ok": true}',
            "tools/call",
        )
        assert out is None, (
            "Without an event loop, session-expired handler must fall "
            "through to caller's generic error path — not hang or raise."
        )
    finally:
        mcp_tool._servers.pop("srv-noloop", None)


def test_session_expired_handler_returns_none_without_server_record():
    """If the server has been torn down / isn't in _servers, fall
    through cleanly — nothing to reconnect to."""
    from tools.mcp_tool import _handle_session_expired_and_retry

    out = _handle_session_expired_and_retry(
        "does-not-exist",
        RuntimeError("Invalid or expired session"),
        lambda: '{"ok": true}',
        "tools/call",
    )
    assert out is None


def test_session_expired_handler_returns_none_when_retry_also_fails(
    monkeypatch, tmp_path
):
    """If the retry after reconnect also raises, fall through to the
    generic error path (don't loop forever, don't mask the second
    failure)."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools import mcp_tool
    from tools.mcp_tool import _handle_session_expired_and_retry

    server, _ = _install_stub_server("srv-retry-fail")
    mcp_tool._servers["srv-retry-fail"] = server

    def _retry_raises():
        raise RuntimeError("retry blew up too")

    try:
        out = _handle_session_expired_and_retry(
            "srv-retry-fail",
            RuntimeError("Invalid or expired session"),
            _retry_raises,
            "tools/call",
        )
        assert out is None, (
            "When the retry itself fails, the handler must return None "
            "so the caller's generic error path runs — no retry loop."
        )
    finally:
        mcp_tool._servers.pop("srv-retry-fail", None)


# ---------------------------------------------------------------------------
# Parallel coverage for resources/list, resources/read, prompts/list,
# prompts/get — all four handlers share the same exception path.
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    "handler_factory, handler_kwargs, session_method, op_label",
    [
        ("_make_list_resources_handler", {"tool_timeout": 10.0}, "list_resources", "list_resources"),
        ("_make_read_resource_handler", {"tool_timeout": 10.0}, "read_resource", "read_resource"),
        ("_make_list_prompts_handler", {"tool_timeout": 10.0}, "list_prompts", "list_prompts"),
        ("_make_get_prompt_handler", {"tool_timeout": 10.0}, "get_prompt", "get_prompt"),
    ],
)
def test_non_tool_handlers_also_reconnect_on_session_expired(
    monkeypatch, tmp_path, handler_factory, handler_kwargs, session_method, op_label
):
    """All four non-``tools/call`` MCP handlers share the recovery
    pattern and must reconnect the same way on session-expired."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools import mcp_tool

    server, reconnect_flag = _install_stub_server(f"srv-{op_label}")
    mcp_tool._servers[f"srv-{op_label}"] = server
    mcp_tool._server_error_counts.pop(f"srv-{op_label}", None)

    call_count = {"n": 0}

    async def _sequence(*a, **kw):
        call_count["n"] += 1
        if call_count["n"] == 1:
            raise RuntimeError("Invalid or expired session")
        # Return something with the shapes each handler expects.
        # Explicitly set primitive attrs — MagicMock's default auto-attr
        # behaviour surfaces ``MagicMock`` values for optional fields
        # like ``description``, which break ``json.dumps`` downstream.
        result = MagicMock()
        result.resources = []
        result.prompts = []
        result.contents = []
        result.messages = []  # get_prompt
        result.description = None  # get_prompt optional field
        return result

    setattr(server.session, session_method, _sequence)

    factory = getattr(mcp_tool, handler_factory)
    # All four factories are called the same way here:
    # (server_name, tool_timeout=...).
    try:
        handler = factory(f"srv-{op_label}", **handler_kwargs)
        if op_label == "read_resource":
            out = handler({"uri": "file://foo"})
        elif op_label == "get_prompt":
            out = handler({"name": "p1"})
        else:
            out = handler({})
        parsed = json.loads(out)
        assert "error" not in parsed, (
            f"{op_label}: expected retry success, got {parsed}"
        )
        assert reconnect_flag.is_set(), (
            f"{op_label}: reconnect should fire for session-expired"
        )
        assert call_count["n"] == 2, (
            f"{op_label}: expected 1 original + 1 retry"
        )
    finally:
        mcp_tool._servers.pop(f"srv-{op_label}", None)
        mcp_tool._server_error_counts.pop(f"srv-{op_label}", None)