mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-18 04:41:56 +00:00
fix(async): close unscheduled coroutines in all threadsafe bridges (#26584)
Wraps every sync->async coroutine-scheduling site in the codebase with a new agent.async_utils.safe_schedule_threadsafe() helper. On scheduling failure (closed loop, shutdown race, etc.) the helper closes the coroutine instead of leaking it as 'coroutine was never awaited' RuntimeWarnings plus reference leaks.

22 production call sites migrated across the codebase:
- acp_adapter/events.py, acp_adapter/permissions.py
- agent/lsp/manager.py
- cron/scheduler.py (media + text delivery paths)
- gateway/platforms/feishu.py (5 sites, via the existing _submit_on_loop helper, which now delegates to safe_schedule_threadsafe)
- gateway/run.py (10 sites: telegram rename, agent:step hook, status callback, interim+bg-review, clarify send, exec-approval button+text, temp-bubble cleanup, channel-directory refresh)
- plugins/memory/hindsight, plugins/platforms/google_chat
- tools/browser_supervisor.py (3), browser_cdp_tool.py, computer_use/cua_backend.py, slash_confirm.py
- tools/environments/modal.py (_AsyncWorker)
- tools/mcp_tool.py (2 + 8 _run_on_mcp_loop callers converted to factory-style so the coroutine is never constructed on a dead loop)
- tui_gateway/ws.py

Tests: new tests/agent/test_async_utils.py covers helper behavior under a live loop, a dead loop, a None loop, and scheduling exceptions. Regression tests added at three PR-original sites (acp events, acp permissions, mcp loop runner), mirroring the contributor's intent.
Live-tested end-to-end:
- Helper stress test: 1500 schedules across live/dead/race scenarios, zero leaked coroutines
- Race exercised: 5000 schedules with the loop killed mid-flight, 100 ok / 4900 None returns, zero leaks
- hermes chat -q with a terminal tool call (exercises the step_callback bridge)
- MCP probe against failing subprocess servers + factory path
- Real gateway daemon boot + SIGINT shutdown across multiple platform adapter inits
- WSTransport: 100 live + 50 dead-loop writes
- Cron delivery path: live + dead loop

Salvages PR #2657 — adopts the contributor's intent over a much wider site list, using a single centralized helper instead of an inline try/except at each site. 3 of the original PR's 6 sites no longer exist on main (environments/patches.py deleted, DingTalk refactored to native async); the equivalent fix lives in tools/environments/modal.py instead.

Co-authored-by: JithendraNara <jithendranaidunara@gmail.com>
This commit is contained in:
parent
931caf2b2d
commit
4e89c53082
23 changed files with 690 additions and 186 deletions
|
|
@ -1,6 +1,8 @@
|
|||
"""Tests for acp_adapter.events — callback factories for ACP notifications."""
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import warnings
|
||||
from concurrent.futures import Future
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
|
@ -10,6 +12,7 @@ import acp
|
|||
from acp.schema import ToolCallStart, ToolCallProgress, AgentThoughtChunk, AgentMessageChunk
|
||||
|
||||
from acp_adapter.events import (
|
||||
_send_update,
|
||||
make_message_cb,
|
||||
make_step_cb,
|
||||
make_thinking_cb,
|
||||
|
|
@ -325,3 +328,46 @@ class TestMessageCallback:
|
|||
cb("")
|
||||
|
||||
mock_rcts.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scheduler-failure regression
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendUpdate:
    """Regression coverage for ``_send_update`` when scheduling fails."""

    def test_scheduler_failure_closes_update_coroutine(self, event_loop_fixture):
        """A raising run_coroutine_threadsafe must leave the update coroutine closed."""
        update_coro = None

        async def _session_update(session_id, update):
            return None

        def _capture_update(session_id, update):
            # Keep a handle on the coroutine so its state can be inspected
            # after _send_update has returned.
            nonlocal update_coro
            update_coro = _session_update(session_id, update)
            return update_coro

        conn = MagicMock()
        conn.session_update = _capture_update

        failing_scheduler = patch(
            "agent.async_utils.asyncio.run_coroutine_threadsafe",
            side_effect=RuntimeError("scheduler down"),
        )

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            with failing_scheduler:
                _send_update(conn, "session-1", event_loop_fixture, {"type": "noop"})
            # Collect while warnings are still being recorded so an
            # unawaited-coroutine warning would land in ``caught``.
            gc.collect()

        assert update_coro is not None
        # A closed coroutine no longer has a frame.
        assert update_coro.cr_frame is None

        # Only warnings naming THIS test's coroutine count; other tests in
        # the same xdist worker (or stdlib mock internals) may emit unrelated
        # "coroutine was never awaited" warnings that bleed through.
        leaked = []
        for w in caught:
            if not issubclass(w.category, RuntimeWarning):
                continue
            text = str(w.message)
            if "was never awaited" in text and "_session_update" in text:
                leaked.append(w)
        assert leaked == []
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ def _invoke_callback(
|
|||
scheduled["loop"] = passed_loop
|
||||
return future
|
||||
|
||||
with patch("acp_adapter.permissions.asyncio.run_coroutine_threadsafe", side_effect=_schedule):
|
||||
with patch("agent.async_utils.asyncio.run_coroutine_threadsafe", side_effect=_schedule):
|
||||
cb = make_approval_callback(request_permission, loop, session_id="s1", timeout=timeout)
|
||||
if use_prompt_path:
|
||||
result = prompt_dangerous_approval(
|
||||
|
|
@ -135,7 +135,7 @@ class TestApprovalBridge:
|
|||
scheduled["loop"] = passed_loop
|
||||
return future
|
||||
|
||||
with patch("acp_adapter.permissions.asyncio.run_coroutine_threadsafe", side_effect=_schedule):
|
||||
with patch("agent.async_utils.asyncio.run_coroutine_threadsafe", side_effect=_schedule):
|
||||
cb = make_approval_callback(request_permission, loop, session_id="s1", timeout=0.01)
|
||||
result = cb("rm -rf /", "dangerous command")
|
||||
|
||||
|
|
@ -159,10 +159,53 @@ class TestApprovalBridge:
|
|||
scheduled["loop"] = passed_loop
|
||||
return future
|
||||
|
||||
with patch("acp_adapter.permissions.asyncio.run_coroutine_threadsafe", side_effect=_schedule):
|
||||
with patch("agent.async_utils.asyncio.run_coroutine_threadsafe", side_effect=_schedule):
|
||||
cb = make_approval_callback(request_permission, loop, session_id="s1", timeout=1.0)
|
||||
result = cb("echo hi", "demo")
|
||||
|
||||
scheduled["coro"].close()
|
||||
|
||||
assert result == "deny"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scheduler-failure regression
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import gc # noqa: E402
|
||||
import warnings # noqa: E402
|
||||
|
||||
|
||||
class TestSchedulerFailure:
    """Regression coverage for the approval bridge when scheduling fails."""

    def test_scheduler_failure_closes_permission_coroutine(self):
        """A raising run_coroutine_threadsafe closes the coroutine and yields 'deny'."""
        permission_coro = None

        async def _response_coro(**kwargs):
            return _make_response(AllowedOutcome(option_id="allow_once", outcome="selected"))

        def _request_permission(**kwargs):
            # Remember the coroutine so the test can inspect it afterwards.
            nonlocal permission_coro
            permission_coro = _response_coro(**kwargs)
            return permission_coro

        fake_loop = MagicMock(spec=asyncio.AbstractEventLoop)
        failing_scheduler = patch(
            "agent.async_utils.asyncio.run_coroutine_threadsafe",
            side_effect=RuntimeError("scheduler down"),
        )

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            with failing_scheduler:
                cb = make_approval_callback(
                    _request_permission, fake_loop, session_id="s1", timeout=0.01
                )
                result = cb("rm -rf /", "dangerous")
            # Collect while warnings are still being recorded.
            gc.collect()

        assert result == "deny"
        assert permission_coro is not None
        # A closed coroutine no longer has a frame.
        assert permission_coro.cr_frame is None

        # Count only warnings that name this test's coroutine.
        leaked = []
        for w in caught:
            if not issubclass(w.category, RuntimeWarning):
                continue
            text = str(w.message)
            if "was never awaited" in text and "_response_coro" in text:
                leaked.append(w)
        assert leaked == []
|
||||
|
|
|
|||
157
tests/agent/test_async_utils.py
Normal file
157
tests/agent/test_async_utils.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
"""Tests for agent.async_utils.safe_schedule_threadsafe."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import warnings
|
||||
from concurrent.futures import Future
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.async_utils import safe_schedule_threadsafe
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _no_unawaited_warnings(
    caught: list[warnings.WarningMessage], *, coro_name: str = ""
) -> bool:
    """Return True if no "X was never awaited" warning slipped through.

    Args:
        caught: Recorded warnings, e.g. the list yielded by
            ``warnings.catch_warnings(record=True)``. Each entry must expose
            ``category`` and ``message``.
        coro_name: When non-empty, only warnings whose message contains this
            substring are counted — xdist workers may emit unrelated
            unawaited-coroutine warnings (e.g.
            ``AsyncMockMixin._execute_mock_call``) from concurrent tests.

    Returns:
        ``True`` when no matching ``RuntimeWarning`` was recorded.
    """

    def _is_offender(w) -> bool:
        if not issubclass(w.category, RuntimeWarning):
            return False
        text = str(w.message)
        if "was never awaited" not in text:
            return False
        # An empty coro_name means "count every unawaited-coroutine warning".
        return not coro_name or coro_name in text

    # Short-circuit on the first offender instead of materialising a list.
    return not any(_is_offender(w) for w in caught)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSafeScheduleThreadsafe:
    """Behavioural tests for ``safe_schedule_threadsafe``.

    Covers a live loop, a closed loop, a ``None`` loop, a scheduler that
    raises, custom logging arguments, and a non-coroutine argument.
    """

    def test_returns_future_on_success(self):
        """On a loop running in another thread, a Future with the result comes back."""
        import threading

        loop = asyncio.new_event_loop()
        ready = threading.Event()
        stop = threading.Event()

        async def _wait_for_stop(ev):
            # Keep the loop alive until the test signals completion.
            while not ev.is_set():
                await asyncio.sleep(0.005)

        def _runner():
            asyncio.set_event_loop(loop)
            ready.set()
            loop.run_until_complete(_wait_for_stop(stop))

        async def _sample():
            return 42

        t = threading.Thread(target=_runner, daemon=True)
        t.start()
        try:
            # Fail fast if the worker thread never came up.
            assert ready.wait(timeout=2)

            fut = safe_schedule_threadsafe(_sample(), loop)
            assert isinstance(fut, Future)
            assert fut.result(timeout=2) == 42
        finally:
            # Always release the worker, even when an assertion above failed;
            # otherwise the loop keeps running and close() would raise,
            # masking the original failure.
            stop.set()
            t.join(timeout=2)
            if loop.is_running():
                loop.call_soon_threadsafe(loop.stop)
                t.join(timeout=2)
            # close() raises RuntimeError on a running loop, so only close
            # once the loop has actually stopped.
            if not loop.is_running():
                loop.close()

    def test_closed_loop_returns_none_and_closes_coroutine(self):
        """A closed loop yields None and no unawaited warning for the coroutine."""
        loop = asyncio.new_event_loop()
        loop.close()

        async def _sample():
            return "ok"

        coro = _sample()
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            result = safe_schedule_threadsafe(coro, loop)
            # Drop the last reference and collect while warnings are still
            # recorded, so a never-awaited warning would be captured.
            del coro
            gc.collect()

        assert result is None
        assert _no_unawaited_warnings(caught, coro_name='_sample')

    def test_none_loop_returns_none_and_closes_coroutine(self):
        """A None loop yields None and no unawaited warning for the coroutine."""
        async def _sample():
            return "ok"

        coro = _sample()
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            result = safe_schedule_threadsafe(coro, None)
            del coro
            gc.collect()

        assert result is None
        assert _no_unawaited_warnings(caught, coro_name='_sample')

    def test_scheduling_exception_closes_coroutine(self):
        """If run_coroutine_threadsafe raises, close the coroutine and return None."""
        # A loop that *looks* open but raises on submission
        loop = asyncio.new_event_loop()
        try:
            async def _sample():
                return "ok"

            coro = _sample()
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")
                with patch(
                    "agent.async_utils.asyncio.run_coroutine_threadsafe",
                    side_effect=RuntimeError("scheduler down"),
                ):
                    result = safe_schedule_threadsafe(coro, loop)
                del coro
                gc.collect()

            assert result is None
            assert _no_unawaited_warnings(caught, coro_name='_sample')
        finally:
            loop.close()

    def test_logs_at_specified_level(self, caplog):
        """The supplied logger, message and level are used when scheduling fails."""
        import logging

        loop = asyncio.new_event_loop()
        loop.close()

        async def _sample():
            return None

        custom = logging.getLogger("test_async_utils")
        with caplog.at_level(logging.WARNING, logger="test_async_utils"):
            result = safe_schedule_threadsafe(
                _sample(), loop,
                logger=custom,
                log_message="custom-msg",
                log_level=logging.WARNING,
            )

        assert result is None
        assert any("custom-msg" in rec.message for rec in caplog.records)

    def test_non_coroutine_arg_does_not_crash(self):
        """Defensive: even if the caller hands us something weird, don't blow up."""
        loop = asyncio.new_event_loop()
        loop.close()

        # Pass a non-coroutine sentinel
        result = safe_schedule_threadsafe("not-a-coroutine", loop)  # type: ignore[arg-type]
        assert result is None
|
||||
|
|
@ -69,7 +69,8 @@ class TestProbeMcpServerTools:
|
|||
patch("tools.mcp_tool._stop_mcp_loop"):
|
||||
|
||||
# Simulate running the async probe
|
||||
def run_coro(coro, timeout=120):
|
||||
def run_coro(coro_or_factory, timeout=120):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
|
|
@ -110,7 +111,8 @@ class TestProbeMcpServerTools:
|
|||
patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \
|
||||
patch("tools.mcp_tool._stop_mcp_loop"):
|
||||
|
||||
def run_coro(coro, timeout=120):
|
||||
def run_coro(coro_or_factory, timeout=120):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
|
|
@ -144,7 +146,8 @@ class TestProbeMcpServerTools:
|
|||
patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \
|
||||
patch("tools.mcp_tool._stop_mcp_loop"):
|
||||
|
||||
def run_coro(coro, timeout=120):
|
||||
def run_coro(coro_or_factory, timeout=120):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
|
|
@ -198,7 +201,8 @@ class TestProbeMcpServerTools:
|
|||
patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \
|
||||
patch("tools.mcp_tool._stop_mcp_loop"):
|
||||
|
||||
def run_coro(coro, timeout=120):
|
||||
def run_coro(coro_or_factory, timeout=120):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@ class _FakeCallToolResult:
|
|||
self.structuredContent = structuredContent
|
||||
|
||||
|
||||
def _fake_run_on_mcp_loop(coro, timeout=30):
|
||||
def _fake_run_on_mcp_loop(coro_or_factory, timeout=30):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
"""Run an MCP coroutine directly in a fresh event loop."""
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -397,6 +397,77 @@ class TestCheckFunction:
|
|||
_servers.pop("test_server", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCP loop runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunOnMcpLoop:
    """Regression coverage for ``_run_on_mcp_loop`` when scheduling cannot happen."""

    def test_scheduler_failure_closes_factory_coroutine(self):
        """If run_coroutine_threadsafe raises, the factory's coroutine is closed."""
        import gc
        import warnings
        import tools.mcp_tool as mcp

        factory_coro = None

        async def _sample():
            return "ok"

        def factory():
            # Remember the coroutine the factory produced for later inspection.
            nonlocal factory_coro
            factory_coro = _sample()
            return factory_coro

        # A stand-in loop that reports itself as running.
        fake_loop = MagicMock()
        fake_loop.is_running.return_value = True

        failing_scheduler = patch(
            "agent.async_utils.asyncio.run_coroutine_threadsafe",
            side_effect=RuntimeError("scheduler down"),
        )

        with patch.object(mcp, "_mcp_loop", fake_loop), \
                warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            with failing_scheduler, pytest.raises(RuntimeError):
                mcp._run_on_mcp_loop(factory)
            # Collect while warnings are still being recorded.
            gc.collect()

        assert factory_coro is not None
        # A closed coroutine no longer has a frame.
        assert factory_coro.cr_frame is None

        leaked = []
        for w in caught:
            if not issubclass(w.category, RuntimeWarning):
                continue
            text = str(w.message)
            if "was never awaited" in text and "_sample" in text:
                leaked.append(w)
        assert leaked == []

    def test_dead_loop_closes_passed_coroutine(self):
        """If loop is None, a passed coroutine (not factory) is closed."""
        import gc
        import warnings
        import tools.mcp_tool as mcp

        async def _sample():
            return "ok"

        coro = _sample()

        with patch.object(mcp, "_mcp_loop", None), \
                warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            with pytest.raises(RuntimeError, match="not running"):
                mcp._run_on_mcp_loop(coro)
            # Collect while warnings are still being recorded.
            gc.collect()

        # A closed coroutine no longer has a frame.
        assert coro.cr_frame is None

        leaked = []
        for w in caught:
            if not issubclass(w.category, RuntimeWarning):
                continue
            text = str(w.message)
            if "was never awaited" in text and "_sample" in text:
                leaked.append(w)
        assert leaked == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -406,7 +477,8 @@ class TestToolHandler:
|
|||
|
||||
def _patch_mcp_loop(self, coro_side_effect=None):
|
||||
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
|
||||
def fake_run(coro, timeout=30):
|
||||
def fake_run(coro_or_factory, timeout=30):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
return asyncio.run(coro)
|
||||
if coro_side_effect:
|
||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect)
|
||||
|
|
@ -485,7 +557,8 @@ class TestToolHandler:
|
|||
|
||||
try:
|
||||
handler = _make_tool_handler("test_srv", "greet", 120)
|
||||
def _interrupting_run(coro, timeout=30):
|
||||
def _interrupting_run(coro_or_factory, timeout=30):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
coro.close()
|
||||
raise InterruptedError("User sent a new message")
|
||||
with patch(
|
||||
|
|
@ -1792,7 +1865,8 @@ class TestUtilityHandlers:
|
|||
|
||||
def _patch_mcp_loop(self):
|
||||
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
|
||||
def fake_run(coro, timeout=30):
|
||||
def fake_run(coro_or_factory, timeout=30):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
return asyncio.run(coro)
|
||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue