hermes-agent/tests/gateway/test_cancel_background_drain.py
Teknium 62ce6a38ae
fix(gateway): cancel_background_tasks must drain late-arrivals (#12471)
During gateway shutdown, a message arriving while
cancel_background_tasks is mid-await (inside asyncio.gather) spawns
a fresh _process_message_background task via handle_message and adds
it to self._background_tasks.  The original implementation's
_background_tasks.clear() at the end of cancel_background_tasks
dropped the reference; the task ran untracked against a disconnecting
adapter, logged send-failures, and lingered until it completed on
its own.

Fix: wrap the cancel+gather in a bounded loop (MAX_DRAIN_ROUNDS=5).
If new tasks appeared during the gather, cancel them in the next
round.  The .clear() at the end is preserved as a safety net for
any task that appeared after MAX_DRAIN_ROUNDS — but in practice the
drain stabilizes in 1-2 rounds.

Tests: tests/gateway/test_cancel_background_drain.py — 3 cases.
- test_cancel_background_tasks_drains_late_arrivals: spawn M1, start
  cancel, inject M2 during M1's shielded cleanup, verify M2 is
  cancelled.
- test_cancel_background_tasks_handles_no_tasks: no-op path still
  terminates cleanly.
- test_cancel_background_tasks_bounded_rounds: baseline — single
  task cancels in one round, loop terminates.

Regression-guard validated: against the unpatched implementation,
the late-arrival test fails with exactly the expected message
('task leaked').  With the fix it passes.

Blast radius is shutdown-only; the audit classified this as MED.
Shipping because the fix is small and the hygiene is worth it.

While investigating the audit's other MEDs (busy-handler double-ack,
Discord ExecApprovalView double-resolve, UpdatePromptView
double-resolve), I verified all three were false positives — the
check-and-set patterns have no await between them, so they're
atomic on single-threaded asyncio.  No fix needed for those.
2026-04-19 01:48:42 -07:00

148 lines
5 KiB
Python

"""Regression test: cancel_background_tasks must drain late-arrival tasks.
During gateway shutdown, a message arriving while
cancel_background_tasks is mid-await can spawn a fresh
_process_message_background task via handle_message, which is added
to self._background_tasks. Without the re-drain loop, the subsequent
_background_tasks.clear() drops the reference; the task runs
untracked against a disconnecting adapter.
"""
import asyncio
from unittest.mock import AsyncMock
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
from gateway.session import SessionSource, build_session_key
class _StubAdapter(BasePlatformAdapter):
    """Inert concrete adapter: every transport hook is a no-op.

    Exists only so the test can instantiate ``BasePlatformAdapter`` and
    exercise its task bookkeeping without any real network transport.
    """

    async def connect(self):
        """Nothing to open."""
        return None

    async def disconnect(self):
        """Nothing to close."""
        return None

    async def send(self, chat_id, text, **kwargs):
        """Discard the outbound message and report nothing."""
        return None

    async def get_chat_info(self, chat_id):
        """Describe every chat as an empty mapping."""
        return {}
def _make_adapter():
    """Return a Telegram-flavoured ``_StubAdapter`` with sending mocked out.

    ``_send_with_retry`` is replaced by an ``AsyncMock`` resolving to None,
    so any outbound delivery attempted by the adapter is a silent no-op.
    """
    config = PlatformConfig(enabled=True, token="t")
    stub = _StubAdapter(config, Platform.TELEGRAM)
    stub._send_with_retry = AsyncMock(return_value=None)
    return stub
def _event(text, cid="42"):
    """Build a plain-text Telegram DM ``MessageEvent`` carrying *text*.

    *cid* selects the chat id; it defaults to ``"42"``, the chat the tests
    use for their session key.
    """
    dm_source = SessionSource(platform=Platform.TELEGRAM, chat_id=cid, chat_type="dm")
    return MessageEvent(text=text, message_type=MessageType.TEXT, source=dm_source)
@pytest.mark.asyncio
async def test_cancel_background_tasks_drains_late_arrivals():
    """A message that arrives during the gather window must be picked
    up by the re-drain loop, not leaked as an untracked task.

    Sequence:
    1. M1 is dispatched and blocks in the handler.
    2. ``cancel_background_tasks`` starts and cancels M1. M1's handler then
       sits in a shielded 0.2 s sleep, which keeps the shutdown's
       first-round gather pending.
    3. While that gather is pending, M2 is dispatched as a fresh
       background task.
    4. When shutdown finishes, M2 must have been cancelled and the tracked
       task set must be empty.
    """
    adapter = _make_adapter()
    # Session key for chat "42" (the default chat of _event); used below to
    # drop M1's _active_sessions entry so M2 is not treated as busy.
    sk = build_session_key(
        SessionSource(platform=Platform.TELEGRAM, chat_id="42", chat_type="dm")
    )
    # Checkpoints the test awaits to sequence the race deterministically.
    m1_started = asyncio.Event()
    m1_cleanup_running = asyncio.Event()
    m2_started = asyncio.Event()
    m2_cancelled = asyncio.Event()

    async def handler(event):
        # Stand-in message handler. Both messages park in a long sleep;
        # their CancelledError branches record that cancellation reached them.
        if event.text == "M1":
            m1_started.set()
            try:
                await asyncio.sleep(10)
            except asyncio.CancelledError:
                m1_cleanup_running.set()
                # Widen the gather window with a shielded cleanup
                # delay so M2 can get injected during it.
                await asyncio.shield(asyncio.sleep(0.2))
                raise
        else:  # M2 — the late arrival
            m2_started.set()
            try:
                await asyncio.sleep(10)
            except asyncio.CancelledError:
                m2_cancelled.set()
                raise

    # Install the handler directly. This presumes handle_message's
    # background task ends up calling _message_handler — confirm against
    # BasePlatformAdapter.
    adapter._message_handler = handler
    # Spawn M1.
    await adapter.handle_message(_event("M1"))
    await asyncio.wait_for(m1_started.wait(), timeout=1.0)
    # Kick off shutdown. This will cancel M1 and await its cleanup.
    cancel_task = asyncio.create_task(adapter.cancel_background_tasks())
    # Wait until M1's cleanup is running (inside the shielded sleep).
    # This is the race window: cancel_task is awaiting gather, and M1 is
    # still in its shielded cleanup. Whatever wrapper-level cleanup would
    # release M1's _active_sessions entry has therefore NOT run yet
    # (presumably in _process_message_background's finally — confirm).
    await asyncio.wait_for(m1_cleanup_running.wait(), timeout=1.0)
    # Drop the stale active-session entry by hand. In production, a message
    # delivered by the platform dispatcher at this point would take the
    # no-active-session spawn path. Popping the entry here makes the repro
    # deterministic instead of depending on M1's cleanup timing.
    adapter._active_sessions.pop(sk, None)
    # Inject late arrival — spawns a fresh _process_message_background
    # task and adds it to _background_tasks while cancel_task is still
    # in gather.
    await adapter.handle_message(_event("M2"))
    await asyncio.wait_for(m2_started.wait(), timeout=1.0)
    # Let cancel_task finish. Round 1's gather completes when M1's
    # shielded cleanup finishes. Round 2 should pick up M2.
    await asyncio.wait_for(cancel_task, timeout=5.0)
    # Assert M2 was drained, not leaked.
    assert m2_cancelled.is_set(), (
        "Late-arrival M2 was NOT cancelled by cancel_background_tasks — "
        "the re-drain loop is missing and the task leaked"
    )
    # The tracked-task set must be empty after shutdown.
    assert adapter._background_tasks == set()
@pytest.mark.asyncio
async def test_cancel_background_tasks_handles_no_tasks():
    """Regression guard: no tasks, no hang, no error.

    The call runs under a timeout. If a regression makes the empty-set
    drain path block, the test fails fast with ``TimeoutError`` instead
    of hanging the whole test session.
    """
    adapter = _make_adapter()
    await asyncio.wait_for(adapter.cancel_background_tasks(), timeout=5.0)
    assert adapter._background_tasks == set()
@pytest.mark.asyncio
async def test_cancel_background_tasks_bounded_rounds():
    """Regression guard: the drain loop is bounded — it does not spin
    forever even if late-arrival tasks keep getting spawned.

    Covers two scenarios:

    1. Baseline: a single well-behaved task is cancelled, and the call
       returns with ``_background_tasks`` emptied.
    2. Respawn storm: every task, when cancelled, spawns a replacement
       into ``_background_tasks``, so the set never drains on its own.
       ``cancel_background_tasks`` must still return within the timeout
       (bounded rounds) and leave the tracked set empty.
    """
    adapter = _make_adapter()

    # 1. Baseline: one task that cancels cleanly.
    async def quick():
        await asyncio.sleep(10)

    task = asyncio.create_task(quick())
    adapter._background_tasks.add(task)
    await asyncio.wait_for(adapter.cancel_background_tasks(), timeout=5.0)
    assert task.done()
    assert adapter._background_tasks == set()

    # 2. Respawn storm: each cancellation spawns a fresh tracked task.
    spawned = []
    respawn = True  # switched off during teardown to break the chain

    async def stubborn():
        try:
            await asyncio.sleep(10)
        except asyncio.CancelledError:
            if respawn:
                spawn()
            raise

    def spawn():
        new_task = asyncio.create_task(stubborn())
        spawned.append(new_task)
        adapter._background_tasks.add(new_task)

    spawn()
    # Yield once so the first task enters its try block. A task cancelled
    # before it first runs never executes its except handler, so without
    # this yield the storm would not start.
    await asyncio.sleep(0)
    try:
        # An unbounded drain loop would chase the respawns forever and
        # trip this timeout.
        await asyncio.wait_for(adapter.cancel_background_tasks(), timeout=5.0)
        assert len(spawned) > 1, "respawn storm never started"
        assert adapter._background_tasks == set()
    finally:
        # Stop the chain, then reap every task created here. The last one
        # may still be pending after the safety-net clear() dropped it.
        respawn = False
        for pending in spawned:
            pending.cancel()
        await asyncio.gather(*spawned, return_exceptions=True)