fix(mcp): add half-open state to circuit breaker

The MCP circuit breaker previously had no path back to the closed
state: once _server_error_counts[srv] reached _CIRCUIT_BREAKER_THRESHOLD
the gate short-circuited every subsequent call, so the only reset
path (on successful call) was unreachable. A single transient
3-failure blip (bad network, server restart, expired token) permanently
disabled every tool on that MCP server for the rest of the agent
session.

Introduce a classic closed/open/half-open state machine:

- Track a per-server breaker-open timestamp in _server_breaker_opened_at
  alongside the existing failure count.
- Add _CIRCUIT_BREAKER_COOLDOWN_SEC (60s). Once the count reaches
  threshold, calls short-circuit for the cooldown window.
- After the cooldown elapses, the *next* call falls through as a
  half-open probe that actually hits the session. Success resets the
  breaker via _reset_server_error; failure re-bumps the count via
  _bump_server_error, which re-stamps the open timestamp and re-arms
  the cooldown.

The error message now includes the live failure count and an
"Auto-retry available in ~Ns" hint so the model knows the breaker
will self-heal rather than giving up on the tool for the whole
session.

Covers tests 1 (half-opens after cooldown) and 2 (reopens on probe
failure); test 3 (cleared on reconnect) still fails pending fix #2.
This commit is contained in:
Ben 2026-04-21 19:19:13 +10:00 committed by Teknium
parent 724377c429
commit 8cc3cebca2

View file

@ -1249,9 +1249,47 @@ _servers: Dict[str, MCPServerTask] = {}
# _CIRCUIT_BREAKER_THRESHOLD consecutive failures, the handler returns
# a "server unreachable" message that tells the model to stop retrying,
# preventing the 90-iteration burn loop described in #10447.
# Reset to 0 on any successful call.
#
# State machine:
# closed — error count below threshold; all calls go through.
# open — threshold reached; calls short-circuit until the
# cooldown elapses.
# half-open — cooldown elapsed; the next call is a probe that
# actually hits the session. Probe success → closed.
# Probe failure → reopens (cooldown re-armed).
#
# ``_server_breaker_opened_at`` records the monotonic timestamp when
# the breaker most recently transitioned into the open state. Use the
# ``_bump_server_error`` / ``_reset_server_error`` helpers to mutate
# this state — they keep the count and timestamp in sync.
_server_error_counts: Dict[str, int] = {}
_server_breaker_opened_at: Dict[str, float] = {}
_CIRCUIT_BREAKER_THRESHOLD = 3
_CIRCUIT_BREAKER_COOLDOWN_SEC = 60.0
def _bump_server_error(server_name: str) -> None:
"""Increment the consecutive-failure count for ``server_name``.
When the count crosses :data:`_CIRCUIT_BREAKER_THRESHOLD`, stamp the
breaker-open timestamp so the cooldown clock starts (or re-starts,
for probe failures in the half-open state).
"""
n = _server_error_counts.get(server_name, 0) + 1
_server_error_counts[server_name] = n
if n >= _CIRCUIT_BREAKER_THRESHOLD:
_server_breaker_opened_at[server_name] = time.monotonic()
def _reset_server_error(server_name: str) -> None:
"""Fully close the breaker for ``server_name``.
Clears both the failure count and the breaker-open timestamp. Call
this on any unambiguous success signal (successful tool call,
successful reconnect, manual /mcp refresh).
"""
_server_error_counts[server_name] = 0
_server_breaker_opened_at.pop(server_name, None)
# ---------------------------------------------------------------------------
# Auth-failure detection helpers (Task 6 of MCP OAuth consolidation)
@ -1396,10 +1434,10 @@ def _handle_auth_error_and_retry(
try:
parsed = json.loads(result)
if "error" not in parsed:
_server_error_counts[server_name] = 0
_reset_server_error(server_name)
return result
except (json.JSONDecodeError, TypeError):
_server_error_counts[server_name] = 0
_reset_server_error(server_name)
return result
except Exception as retry_exc:
logger.warning(
@ -1410,7 +1448,7 @@ def _handle_auth_error_and_retry(
# No recovery available, or retry also failed: surface a structured
# needs_reauth error. Bumps the circuit breaker so the model stops
# retrying the tool.
_server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
_bump_server_error(server_name)
return json.dumps({
"error": (
f"MCP server '{server_name}' requires re-authentication. "
@ -1614,20 +1652,33 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
# Circuit breaker: if this server has failed too many times
# consecutively, short-circuit with a clear message so the model
# stops retrying and uses alternative approaches (#10447).
#
# Once the cooldown elapses, the breaker transitions to
# half-open: we let the *next* call through as a probe. On
# success the success-path below resets the breaker; on
# failure the error paths below bump the count again, which
# re-stamps the open-time via _bump_server_error (re-arming
# the cooldown).
if _server_error_counts.get(server_name, 0) >= _CIRCUIT_BREAKER_THRESHOLD:
return json.dumps({
"error": (
f"MCP server '{server_name}' is unreachable after "
f"{_CIRCUIT_BREAKER_THRESHOLD} consecutive failures. "
f"Do NOT retry this tool — use alternative approaches "
f"or ask the user to check the MCP server."
)
}, ensure_ascii=False)
opened_at = _server_breaker_opened_at.get(server_name, 0.0)
age = time.monotonic() - opened_at
if age < _CIRCUIT_BREAKER_COOLDOWN_SEC:
remaining = max(1, int(_CIRCUIT_BREAKER_COOLDOWN_SEC - age))
return json.dumps({
"error": (
f"MCP server '{server_name}' is unreachable after "
f"{_server_error_counts[server_name]} consecutive "
f"failures. Auto-retry available in ~{remaining}s. "
f"Do NOT retry this tool yet — use alternative "
f"approaches or ask the user to check the MCP server."
)
}, ensure_ascii=False)
# Cooldown elapsed → fall through as a half-open probe.
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
_server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
_bump_server_error(server_name)
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
}, ensure_ascii=False)
@ -1676,11 +1727,11 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
try:
parsed = json.loads(result)
if "error" in parsed:
_server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
_bump_server_error(server_name)
else:
_server_error_counts[server_name] = 0 # success — reset
_reset_server_error(server_name) # success — reset
except (json.JSONDecodeError, TypeError):
_server_error_counts[server_name] = 0 # non-JSON = success
_reset_server_error(server_name) # non-JSON = success
return result
except InterruptedError:
return _interrupted_call_result()
@ -1695,7 +1746,7 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
if recovered is not None:
return recovered
_server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
_bump_server_error(server_name)
logger.error(
"MCP tool %s/%s call failed: %s",
server_name, tool_name, exc,