hermes-agent/tests/gateway/test_gateway_shutdown.py
Teknium 327b57da91
fix(gateway): kill tool subprocesses before adapter disconnect on drain timeout (#14728)
Closes #8202.

Root cause: stop() reclaimed tool-call bash/sleep children only at the
very end of the shutdown sequence — after a 60s drain, 5s interrupt
grace, and per-adapter disconnect. Under systemd (TimeoutStopSec bounded
by drain_timeout), that meant the cgroup SIGKILL escalation fired first,
and systemd reaped the bash/sleep children instead of us.

Fix:
- Extract tool-subprocess cleanup into a local helper
  _kill_tool_subprocesses() in _stop_impl().
- Invoke it eagerly right after _interrupt_running_agents() on the
  drain-timeout path, before adapter disconnect.
- Keep the existing catch-all call at the end for the graceful path
  and defense in depth against mid-teardown respawns.
- Bump generated systemd unit TimeoutStopSec to drain_timeout + 30s
  so cleanup + disconnect + DB close has headroom above the drain
  budget, matching the 'subprocess timeout > TimeoutStopSec + margin'
  rule from the skill.

Tests:
- New: test_gateway_stop_kills_tool_subprocesses_before_adapter_disconnect_on_timeout
  asserts kill_all() runs before disconnect() when drain times out.
- New: test_gateway_stop_kills_tool_subprocesses_on_graceful_path
  guards that the final catch-all still fires when drain succeeds
  (regression guard against accidental removal during refactor).
- Updated: existing systemd unit generator tests expect TimeoutStopSec=90
  (= 60s drain + 30s headroom) with explanatory comment.
2026-04-23 13:59:29 -07:00

230 lines
7.9 KiB
Python

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.platforms.base import MessageEvent
from gateway.restart import GATEWAY_SERVICE_RESTART_EXIT_CODE
from gateway.session import build_session_key
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
@pytest.mark.asyncio
async def test_cancel_background_tasks_cancels_inflight_message_processing():
    """cancel_background_tasks() must tear down an in-flight handler and
    leave the adapter's per-session bookkeeping completely empty."""
    _runner, adapter = make_restart_runner()
    gate = asyncio.Event()

    async def _hang(_event):
        # The gate is never opened, so only cancellation ends this handler.
        await gate.wait()
        return None

    adapter.set_message_handler(_hang)
    incoming = MessageEvent(text="work", source=make_restart_source(), message_id="1")
    await adapter.handle_message(incoming)
    # Yield to the event loop once before inspecting adapter state.
    await asyncio.sleep(0)

    key = build_session_key(incoming.source)
    assert key in adapter._active_sessions
    assert adapter._background_tasks

    await adapter.cancel_background_tasks()

    assert adapter._background_tasks == set()
    assert adapter._active_sessions == {}
    assert adapter._pending_messages == {}
@pytest.mark.asyncio
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
    """With a zero drain budget, stop() interrupts live agents, disconnects
    the adapter, and wipes adapters, agents, pending messages and approvals."""
    runner, adapter = make_restart_runner()
    runner._pending_messages = {"session": "pending text"}
    runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
    runner._restart_drain_timeout = 0.0  # no drain budget at all

    gate = asyncio.Event()

    async def _hang(_event):
        # Never released; the handler stays in flight until shutdown.
        await gate.wait()
        return None

    adapter.set_message_handler(_hang)
    incoming = MessageEvent(text="work", source=make_restart_source(), message_id="1")
    await adapter.handle_message(incoming)
    await asyncio.sleep(0)

    fake_disconnect = AsyncMock()
    adapter.disconnect = fake_disconnect

    agent = MagicMock()
    runner._running_agents = {build_session_key(incoming.source): agent}

    with patch("gateway.status.remove_pid_file"):
        with patch("gateway.status.write_runtime_status"):
            await runner.stop()

    agent.interrupt.assert_called_once_with("Gateway shutting down")
    fake_disconnect.assert_awaited_once()
    assert runner.adapters == {}
    assert runner._running_agents == {}
    assert runner._pending_messages == {}
    assert runner._pending_approvals == {}
    assert runner._shutdown_event.is_set() is True
@pytest.mark.asyncio
async def test_gateway_stop_drains_running_agents_before_disconnect():
    """If running agents finish within the drain window, stop() waits for
    them instead of interrupting, and only then disconnects the adapter."""
    runner, adapter = make_restart_runner()
    disconnect_mock = AsyncMock()
    adapter.disconnect = disconnect_mock
    running_agent = MagicMock()
    runner._running_agents = {"session": running_agent}

    async def finish_agent():
        # Simulate the agent completing shortly after shutdown begins.
        await asyncio.sleep(0.05)
        runner._running_agents.clear()

    # Keep a strong reference: the event loop holds only weak references to
    # tasks, so an unreferenced task can be garbage-collected mid-flight.
    finisher = asyncio.create_task(finish_agent())
    with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
        await runner.stop()
    # Surface any exception raised in the helper and ensure the task does not
    # outlive the test.
    await finisher

    running_agent.interrupt.assert_not_called()
    disconnect_mock.assert_awaited_once()
    assert runner._shutdown_event.is_set() is True
@pytest.mark.asyncio
async def test_gateway_stop_interrupts_after_drain_timeout():
    """An agent that never finishes is interrupted once the short drain
    budget expires, and shutdown still disconnects the adapter."""
    runner, adapter = make_restart_runner()
    runner._restart_drain_timeout = 0.05

    fake_disconnect = AsyncMock()
    adapter.disconnect = fake_disconnect

    # Nothing ever removes this agent, so the drain cannot succeed.
    stuck_agent = MagicMock()
    runner._running_agents = {"session": stuck_agent}

    with patch("gateway.status.remove_pid_file"):
        with patch("gateway.status.write_runtime_status"):
            await runner.stop()

    stuck_agent.interrupt.assert_called_once_with("Gateway shutting down")
    fake_disconnect.assert_awaited_once()
    assert runner._shutdown_event.is_set() is True
@pytest.mark.asyncio
async def test_gateway_stop_service_restart_sets_named_exit_code():
    """A service-managed restart must leave the runner with the dedicated
    GATEWAY_SERVICE_RESTART_EXIT_CODE as its exit code."""
    runner, adapter = make_restart_runner()
    adapter.disconnect = AsyncMock()

    with patch("gateway.status.remove_pid_file"):
        with patch("gateway.status.write_runtime_status"):
            await runner.stop(restart=True, service_restart=True)

    assert runner._exit_code == GATEWAY_SERVICE_RESTART_EXIT_CODE
@pytest.mark.asyncio
async def test_drain_active_agents_throttles_status_updates():
    """_drain_active_agents() should emit only a handful of runtime-status
    updates (start, agent-count change, end) rather than one per loop pass."""
    runner, _adapter = make_restart_runner()
    runner._update_runtime_status = MagicMock()
    runner._running_agents = {"a": MagicMock(), "b": MagicMock()}

    async def _retire_agents():
        # Remove one agent, then the other, roughly 120ms apart.
        await asyncio.sleep(0.12)
        runner._running_agents.pop("a")
        await asyncio.sleep(0.12)
        runner._running_agents.clear()

    retire_task = asyncio.create_task(_retire_agents())
    await runner._drain_active_agents(1.0)
    await retire_task

    # Expected: start + one count-change + final. One additional update is
    # tolerated when the loop sees the zero-agent state before exiting.
    updates = runner._update_runtime_status.call_count
    assert 3 <= updates <= 4
@pytest.mark.asyncio
async def test_gateway_stop_kills_tool_subprocesses_before_adapter_disconnect_on_timeout(monkeypatch):
    """When the drain times out, tool subprocesses have to be reaped before
    the adapter disconnects; otherwise systemd's TimeoutStopSec escalation
    can SIGKILL the cgroup while bash/sleep children are still attached
    (#8202)."""
    runner, adapter = make_restart_runner()
    runner._restart_drain_timeout = 0.01  # tiny budget forces the timeout path
    call_order: list[str] = []

    def _fake_kill_all(task_id=None):
        call_order.append("kill_all")
        return 2

    async def _fake_disconnect():
        call_order.append("disconnect")

    # stop()'s cleanup helper imports these lazily, so patching the module
    # attributes is enough to intercept the calls.
    import tools.process_registry as _pr
    import tools.terminal_tool as _tt
    import tools.browser_tool as _bt

    monkeypatch.setattr(_pr.process_registry, "kill_all", _fake_kill_all)
    monkeypatch.setattr(
        _tt, "cleanup_all_environments", lambda: call_order.append("cleanup_environments")
    )
    monkeypatch.setattr(
        _bt, "cleanup_all_browsers", lambda: call_order.append("cleanup_browsers")
    )
    adapter.disconnect = _fake_disconnect
    # An agent that never finishes keeps the drain from succeeding.
    runner._running_agents = {"session": MagicMock()}

    with patch("gateway.status.remove_pid_file"):
        with patch("gateway.status.write_runtime_status"):
            await runner.stop()

    assert "kill_all" in call_order
    assert "disconnect" in call_order
    kill_pos = call_order.index("kill_all")
    disconnect_pos = call_order.index("disconnect")
    assert kill_pos < disconnect_pos, (
        f"Tool subprocesses must be killed before adapter disconnect on "
        f"drain timeout, got order: {call_order}"
    )
    # Both the eager post-interrupt cleanup and the final catch-all invoke
    # the subprocess cleanup, so kill_all must appear at least twice.
    assert call_order.count("kill_all") >= 2
@pytest.mark.asyncio
async def test_gateway_stop_kills_tool_subprocesses_on_graceful_path(monkeypatch):
    """Graceful shutdown (no drain timeout) must still kill tool subprocesses
    exactly once via the final catch-all — regression guard against
    accidentally removing that call when refactoring."""
    runner, adapter = make_restart_runner()
    adapter.disconnect = AsyncMock()
    # Pin the precondition explicitly instead of relying on the helper's
    # default state: with no running agents the drain has nothing to wait
    # for, so it cannot time out and the eager cleanup must not fire.
    runner._running_agents = {}

    kill_count = 0

    def _fake_kill_all(task_id=None):
        nonlocal kill_count
        kill_count += 1
        return 0

    import tools.process_registry as _pr
    import tools.terminal_tool as _tt
    import tools.browser_tool as _bt

    monkeypatch.setattr(_pr.process_registry, "kill_all", _fake_kill_all)
    monkeypatch.setattr(_tt, "cleanup_all_environments", lambda: None)
    monkeypatch.setattr(_bt, "cleanup_all_browsers", lambda: None)

    with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
        await runner.stop()

    # Only the final catch-all fires on the graceful path.
    assert kill_count == 1