fix(gateway): address restart review feedback

This commit is contained in:
Kenny Xie 2026-04-10 14:00:21 -07:00 committed by Teknium
parent a55c044ca8
commit ecfae98152
8 changed files with 404 additions and 213 deletions

View file

@ -673,6 +673,32 @@ class SendResult:
retryable: bool = False # True for transient connection errors — base will retry automatically retryable: bool = False # True for transient connection errors — base will retry automatically
def merge_pending_message_event(
    pending_messages: Dict[str, MessageEvent],
    session_key: str,
    event: MessageEvent,
) -> None:
    """Queue ``event`` for ``session_key``, folding photo bursts together.

    Photo albums tend to arrive as several PHOTO events in quick succession.
    When the slot for ``session_key`` already holds a PHOTO event and the
    incoming ``event`` is a PHOTO too, the queued event is updated in place:
    its ``media_urls`` / ``media_types`` are extended and a non-empty caption
    is combined through ``BasePlatformAdapter._merge_caption``. In every
    other case ``event`` simply overwrites whatever was queued.

    Args:
        pending_messages: Mapping of session key to queued event; mutated.
        session_key: Key of the session the event belongs to.
        event: The newly received event.
    """
    queued = pending_messages.get(session_key)
    # ``bool(queued)`` short-circuits first, exactly like a plain truthiness
    # test, so an empty slot never touches the message-type comparison.
    is_photo_burst = bool(queued) and (
        getattr(queued, "message_type", None) == MessageType.PHOTO
        and event.message_type == MessageType.PHOTO
    )
    if not is_photo_burst:
        pending_messages[session_key] = event
        return

    queued.media_urls.extend(event.media_urls)
    queued.media_types.extend(event.media_types)
    if event.text:
        queued.text = BasePlatformAdapter._merge_caption(queued.text, event.text)
# Error substrings that indicate a transient *connection* failure worth retrying. # Error substrings that indicate a transient *connection* failure worth retrying.
# "timeout" / "timed out" / "readtimeout" / "writetimeout" are intentionally # "timeout" / "timed out" / "readtimeout" / "writetimeout" are intentionally
# excluded: a read/write timeout on a non-idempotent call (e.g. send_message) # excluded: a read/write timeout on a non-idempotent call (e.g. send_message)
@ -1432,14 +1458,7 @@ class BasePlatformAdapter(ABC):
# then process them immediately after the current task finishes. # then process them immediately after the current task finishes.
if event.message_type == MessageType.PHOTO: if event.message_type == MessageType.PHOTO:
logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key) logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key)
existing = self._pending_messages.get(session_key) merge_pending_message_event(self._pending_messages, session_key, event)
if existing and existing.message_type == MessageType.PHOTO:
existing.media_urls.extend(event.media_urls)
existing.media_types.extend(event.media_types)
if event.text:
existing.text = self._merge_caption(existing.text, event.text)
else:
self._pending_messages[session_key] = event
return # Don't interrupt now - will run after current task completes return # Don't interrupt now - will run after current task completes
# Default behavior for non-photo follow-ups: interrupt the running agent # Default behavior for non-photo follow-ups: interrupt the running agent

20
gateway/restart.py Normal file
View file

@ -0,0 +1,20 @@
"""Shared gateway restart constants and parsing helpers."""
from hermes_cli.config import DEFAULT_CONFIG
# EX_TEMPFAIL (75) from sysexits.h. The gateway runner exits with this code
# when a restart was requested through the service manager, and the generated
# systemd units list it under RestartForceExitStatus, so the exit asks the
# service manager to restart the gateway after a graceful drain/reload path
# completes.
GATEWAY_SERVICE_RESTART_EXIT_CODE = 75
# Fallback drain timeout in seconds, read from the shipped config defaults so
# every importer of this module shares one value instead of a local literal.
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT = float(
DEFAULT_CONFIG["agent"]["restart_drain_timeout"]
)
def parse_restart_drain_timeout(raw: object) -> float:
    """Parse a configured drain timeout, falling back to the shared default.

    Args:
        raw: The configured value. May be a number, a numeric string, or
            ``None`` / blank text when nothing was configured.

    Returns:
        The timeout as a non-negative float. Negative values are clamped to
        ``0.0``. ``DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT`` is returned when
        ``raw`` is unset or cannot be converted to a float.
    """
    # Only None, False and blank text mean "not configured". A numeric zero
    # is a real setting (do not wait at all), so a plain truthiness test must
    # not be used here: it would silently turn 0 into the default.
    if raw is None or raw is False:
        return DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
    if not str(raw).strip():
        return DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
    try:
        value = float(raw)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers integers too large to represent as a float.
        return DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
    return max(0.0, value)

View file

@ -241,7 +241,17 @@ from gateway.session import (
build_session_key, build_session_key,
) )
from gateway.delivery import DeliveryRouter from gateway.delivery import DeliveryRouter
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
MessageType,
merge_pending_message_event,
)
from gateway.restart import (
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
GATEWAY_SERVICE_RESTART_EXIT_CODE,
parse_restart_drain_timeout,
)
def _normalize_whatsapp_identifier(value: str) -> str: def _normalize_whatsapp_identifier(value: str) -> str:
@ -478,7 +488,7 @@ class GatewayRunner:
# blow up on attribute access. # blow up on attribute access.
_running_agents_ts: Dict[str, float] = {} _running_agents_ts: Dict[str, float] = {}
_busy_input_mode: str = "interrupt" _busy_input_mode: str = "interrupt"
_restart_drain_timeout: float = 60.0 _restart_drain_timeout: float = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
_exit_code: Optional[int] = None _exit_code: Optional[int] = None
_draining: bool = False _draining: bool = False
_restart_requested: bool = False _restart_requested: bool = False
@ -486,6 +496,7 @@ class GatewayRunner:
_restart_detached: bool = False _restart_detached: bool = False
_restart_via_service: bool = False _restart_via_service: bool = False
_stop_task: Optional[asyncio.Task] = None _stop_task: Optional[asyncio.Task] = None
_session_model_overrides: Dict[str, Dict[str, str]] = {}
def __init__(self, config: Optional[GatewayConfig] = None): def __init__(self, config: Optional[GatewayConfig] = None):
self.config = config or load_gateway_config() self.config = config or load_gateway_config()
@ -1076,12 +1087,17 @@ class GatewayRunner:
raw = str(cfg.get("agent", {}).get("restart_drain_timeout", "") or "").strip() raw = str(cfg.get("agent", {}).get("restart_drain_timeout", "") or "").strip()
except Exception: except Exception:
pass pass
value = parse_restart_drain_timeout(raw)
if raw and value == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT:
try: try:
value = float(raw) if raw else 60.0 float(raw)
except ValueError: except (TypeError, ValueError):
logger.warning("Invalid restart_drain_timeout '%s', using default 60s", raw) logger.warning(
return 60.0 "Invalid restart_drain_timeout '%s', using default %.0fs",
return max(0.0, value) raw,
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
)
return value
@staticmethod @staticmethod
def _load_background_notifications_mode() -> str: def _load_background_notifications_mode() -> str:
@ -1178,14 +1194,7 @@ class GatewayRunner:
adapter = self.adapters.get(event.source.platform) adapter = self.adapters.get(event.source.platform)
if not adapter: if not adapter:
return return
existing = adapter._pending_messages.get(session_key) merge_pending_message_event(adapter._pending_messages, session_key, event)
if existing and getattr(existing, "message_type", None) == MessageType.PHOTO and event.message_type == MessageType.PHOTO:
existing.media_urls.extend(event.media_urls)
existing.media_types.extend(event.media_types)
if event.text:
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
return
adapter._pending_messages[session_key] = event
async def _handle_active_session_busy_message(self, event: MessageEvent, session_key: str) -> bool: async def _handle_active_session_busy_message(self, event: MessageEvent, session_key: str) -> bool:
if not self._draining: if not self._draining:
@ -1212,20 +1221,32 @@ class GatewayRunner:
async def _drain_active_agents(self, timeout: float) -> tuple[Dict[str, Any], bool]: async def _drain_active_agents(self, timeout: float) -> tuple[Dict[str, Any], bool]:
snapshot = self._snapshot_running_agents() snapshot = self._snapshot_running_agents()
if not self._running_agents: last_active_count = self._running_agent_count()
last_status_at = 0.0
def _maybe_update_status(force: bool = False) -> None:
nonlocal last_active_count, last_status_at
now = asyncio.get_running_loop().time()
active_count = self._running_agent_count()
if force or active_count != last_active_count or (now - last_status_at) >= 1.0:
self._update_runtime_status("draining") self._update_runtime_status("draining")
last_active_count = active_count
last_status_at = now
if not self._running_agents:
_maybe_update_status(force=True)
return snapshot, False return snapshot, False
self._update_runtime_status("draining") _maybe_update_status(force=True)
if timeout <= 0: if timeout <= 0:
return snapshot, True return snapshot, True
deadline = asyncio.get_running_loop().time() + timeout deadline = asyncio.get_running_loop().time() + timeout
while self._running_agents and asyncio.get_running_loop().time() < deadline: while self._running_agents and asyncio.get_running_loop().time() < deadline:
self._update_runtime_status("draining") _maybe_update_status()
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
timed_out = bool(self._running_agents) timed_out = bool(self._running_agents)
self._update_runtime_status("draining") _maybe_update_status(force=True)
return snapshot, timed_out return snapshot, timed_out
def _interrupt_running_agents(self, reason: str) -> None: def _interrupt_running_agents(self, reason: str) -> None:
@ -1841,7 +1862,7 @@ class GatewayRunner:
remove_pid_file() remove_pid_file()
if self._restart_requested and self._restart_via_service: if self._restart_requested and self._restart_via_service:
self._exit_code = 75 self._exit_code = GATEWAY_SERVICE_RESTART_EXIT_CODE
self._exit_reason = self._exit_reason or "Gateway restart requested" self._exit_reason = self._exit_reason or "Gateway restart requested"
self._draining = False self._draining = False
@ -2338,18 +2359,7 @@ class GatewayRunner:
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20]) logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
adapter = self.adapters.get(source.platform) adapter = self.adapters.get(source.platform)
if adapter: if adapter:
# Reuse adapter queue semantics so photo bursts merge cleanly. merge_pending_message_event(adapter._pending_messages, _quick_key, event)
if _quick_key in adapter._pending_messages:
existing = adapter._pending_messages[_quick_key]
if getattr(existing, "message_type", None) == MessageType.PHOTO:
existing.media_urls.extend(event.media_urls)
existing.media_types.extend(event.media_types)
if event.text:
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
else:
adapter._pending_messages[_quick_key] = event
else:
adapter._pending_messages[_quick_key] = event
return None return None
running_agent = self._running_agents.get(_quick_key) running_agent = self._running_agents.get(_quick_key)
@ -3951,7 +3961,7 @@ class GatewayRunner:
# Check for session override # Check for session override
source = event.source source = event.source
session_key = self._session_key_for_source(source) session_key = self._session_key_for_source(source)
override = getattr(self, "_session_model_overrides", {}).get(session_key, {}) override = self._session_model_overrides.get(session_key, {})
if override: if override:
current_model = override.get("model", current_model) current_model = override.get("model", current_model)
current_provider = override.get("provider", current_provider) current_provider = override.get("provider", current_provider)
@ -4033,8 +4043,6 @@ class GatewayRunner:
f"via {result.provider_label or result.target_provider}. " f"via {result.provider_label or result.target_provider}. "
f"Adjust your self-identification accordingly.]" f"Adjust your self-identification accordingly.]"
) )
if not hasattr(_self, "_session_model_overrides"):
_self._session_model_overrides = {}
_self._session_model_overrides[_session_key] = { _self._session_model_overrides[_session_key] = {
"model": result.new_model, "model": result.new_model,
"provider": result.target_provider, "provider": result.target_provider,
@ -4148,8 +4156,6 @@ class GatewayRunner:
) )
# Store session override so next agent creation uses the new model # Store session override so next agent creation uses the new model
if not hasattr(self, "_session_model_overrides"):
self._session_model_overrides = {}
self._session_model_overrides[session_key] = { self._session_model_overrides[session_key] = {
"model": result.new_model, "model": result.new_model,
"provider": result.target_provider, "provider": result.target_provider,
@ -6828,7 +6834,7 @@ class GatewayRunner:
subsequent messages. Fields with ``None`` values are skipped so subsequent messages. Fields with ``None`` values are skipped so
partial overrides don't clobber valid config defaults. partial overrides don't clobber valid config defaults.
""" """
override = getattr(self, "_session_model_overrides", {}).get(session_key) override = self._session_model_overrides.get(session_key)
if not override: if not override:
return model, runtime_kwargs return model, runtime_kwargs
model = override.get("model", model) model = override.get("model", model)
@ -6840,7 +6846,7 @@ class GatewayRunner:
def _is_intentional_model_switch(self, session_key: str, agent_model: str) -> bool: def _is_intentional_model_switch(self, session_key: str, agent_model: str) -> bool:
"""Return True if *agent_model* matches an active /model session override.""" """Return True if *agent_model* matches an active /model session override."""
override = getattr(self, "_session_model_overrides", {}).get(session_key) override = self._session_model_overrides.get(session_key)
return override is not None and override.get("model") == agent_model return override is not None and override.get("model") == agent_model
def _evict_cached_agent(self, session_key: str) -> None: def _evict_cached_agent(self, session_key: str) -> None:

View file

@ -15,8 +15,12 @@ from pathlib import Path
PROJECT_ROOT = Path(__file__).parent.parent.resolve() PROJECT_ROOT = Path(__file__).parent.parent.resolve()
from gateway.status import terminate_pid from gateway.status import terminate_pid
from gateway.restart import (
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
GATEWAY_SERVICE_RESTART_EXIT_CODE,
parse_restart_drain_timeout,
)
from hermes_cli.config import ( from hermes_cli.config import (
DEFAULT_CONFIG,
get_env_value, get_env_value,
get_hermes_home, get_hermes_home,
is_managed, is_managed,
@ -787,7 +791,7 @@ Environment="VIRTUAL_ENV={venv_dir}"
Environment="HERMES_HOME={hermes_home}" Environment="HERMES_HOME={hermes_home}"
Restart=on-failure Restart=on-failure
RestartSec=30 RestartSec=30
RestartForceExitStatus=75 RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}
KillMode=mixed KillMode=mixed
KillSignal=SIGTERM KillSignal=SIGTERM
ExecReload=/bin/kill -USR1 $MAINPID ExecReload=/bin/kill -USR1 $MAINPID
@ -819,7 +823,7 @@ Environment="VIRTUAL_ENV={venv_dir}"
Environment="HERMES_HOME={hermes_home}" Environment="HERMES_HOME={hermes_home}"
Restart=on-failure Restart=on-failure
RestartSec=30 RestartSec=30
RestartForceExitStatus=75 RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}
KillMode=mixed KillMode=mixed
KillSignal=SIGTERM KillSignal=SIGTERM
ExecReload=/bin/kill -USR1 $MAINPID ExecReload=/bin/kill -USR1 $MAINPID
@ -932,11 +936,12 @@ def _get_restart_drain_timeout() -> float:
if not raw: if not raw:
cfg = read_raw_config() cfg = read_raw_config()
agent_cfg = cfg.get("agent", {}) if isinstance(cfg, dict) else {} agent_cfg = cfg.get("agent", {}) if isinstance(cfg, dict) else {}
raw = str(agent_cfg.get("restart_drain_timeout", DEFAULT_CONFIG["agent"]["restart_drain_timeout"])) raw = str(
try: agent_cfg.get(
return max(0.0, float(raw)) "restart_drain_timeout", DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
except (TypeError, ValueError): )
return float(DEFAULT_CONFIG["agent"]["restart_drain_timeout"]) )
return parse_restart_drain_timeout(raw)
def systemd_install(force: bool = False, system: bool = False, run_as_user: str | None = None): def systemd_install(force: bool = False, system: bool = False, run_as_user: str | None = None):

View file

@ -0,0 +1,110 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
from gateway.run import GatewayRunner
from gateway.session import SessionSource
class RestartTestAdapter(BasePlatformAdapter):
    """Stub adapter registered as Telegram that records outgoing content."""

    def __init__(self):
        platform_config = PlatformConfig(enabled=True, token="***")
        super().__init__(platform_config, Platform.TELEGRAM)
        # Every ``content`` handed to ``send`` is appended here, in call order.
        self.sent: list[str] = []

    async def connect(self):
        """Pretend the connection succeeded."""
        return True

    async def disconnect(self):
        """No-op disconnect."""
        return None

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        """Record ``content`` and report a successful send."""
        self.sent.append(content)
        outcome = SendResult(success=True, message_id="1")
        return outcome

    async def send_typing(self, chat_id, metadata=None):
        """No-op typing indicator."""
        return None

    async def get_chat_info(self, chat_id):
        """Return a minimal chat-info mapping containing only the id."""
        return dict(id=chat_id)
def make_restart_source(chat_id: str = "123456", chat_type: str = "dm") -> SessionSource:
    """Build a Telegram ``SessionSource`` for the given chat id and type."""
    source_fields = {
        "platform": Platform.TELEGRAM,
        "chat_id": chat_id,
        "chat_type": chat_type,
    }
    return SessionSource(**source_fields)
def make_restart_runner(
    adapter: BasePlatformAdapter | None = None,
) -> tuple[GatewayRunner, BasePlatformAdapter]:
    """Create a ``GatewayRunner`` wired to one Telegram adapter for tests.

    The runner is allocated with ``object.__new__`` so ``__init__`` never
    runs; only the attributes assigned below exist on the instance.

    Args:
        adapter: Adapter to register under ``Platform.TELEGRAM``. A fresh
            ``RestartTestAdapter`` is created when a falsy value is given.

    Returns:
        The ``(runner, adapter)`` pair.
    """
    runner = object.__new__(GatewayRunner)
    telegram_config = PlatformConfig(enabled=True, token="***")
    runner.config = GatewayConfig(platforms={Platform.TELEGRAM: telegram_config})

    # Lifecycle and exit bookkeeping.
    runner._running = True
    runner._shutdown_event = asyncio.Event()
    runner._exit_reason = None
    runner._exit_code = None

    # Per-session working state.
    runner._running_agents = {}
    runner._running_agents_ts = {}
    runner._pending_messages = {}
    runner._pending_approvals = {}
    runner._pending_model_notes = {}
    runner._background_tasks = set()

    # Drain / restart flags.
    runner._draining = False
    runner._restart_requested = False
    runner._restart_task_started = False
    runner._restart_detached = False
    runner._restart_via_service = False
    runner._restart_drain_timeout = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
    runner._stop_task = None

    # Remaining instance state.
    runner._busy_input_mode = "interrupt"
    runner._update_prompt_pending = {}
    runner._voice_mode = {}
    runner._session_model_overrides = {}
    runner._shutdown_all_gateway_honcho = lambda: None
    runner._update_runtime_status = MagicMock()

    # Pin the real implementations onto the instance as bound methods.
    bound_method_names = (
        "_queue_or_replace_pending_event",
        "_session_key_for_source",
        "_handle_active_session_busy_message",
        "_handle_restart_command",
        "_status_action_label",
        "_status_action_gerund",
        "_queue_during_drain_enabled",
        "_running_agent_count",
        "_launch_detached_restart_command",
        "request_restart",
    )
    for method_name in bound_method_names:
        unbound = getattr(GatewayRunner, method_name)
        setattr(runner, method_name, unbound.__get__(runner, GatewayRunner))

    # Collaborators are replaced with permissive stand-ins.
    runner._is_user_authorized = lambda _source: True
    runner.hooks = MagicMock()
    runner.hooks.emit = AsyncMock()
    runner.pairing_store = MagicMock()
    runner.session_store = MagicMock()
    runner.delivery_router = MagicMock()

    platform_adapter = adapter if adapter else RestartTestAdapter()
    platform_adapter.set_message_handler(AsyncMock(return_value=None))
    platform_adapter.set_busy_session_handler(runner._handle_active_session_busy_message)
    runner.adapters = {Platform.TELEGRAM: platform_adapter}
    return runner, platform_adapter

View file

@ -3,67 +3,15 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig from gateway.platforms.base import MessageEvent
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult from gateway.restart import GATEWAY_SERVICE_RESTART_EXIT_CODE
from gateway.run import GatewayRunner from gateway.session import build_session_key
from gateway.session import SessionSource, build_session_key from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
class StubAdapter(BasePlatformAdapter):
def __init__(self):
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
async def connect(self):
return True
async def disconnect(self):
return None
async def send(self, chat_id, content, reply_to=None, metadata=None):
return SendResult(success=True, message_id="1")
async def send_typing(self, chat_id, metadata=None):
return None
async def get_chat_info(self, chat_id):
return {"id": chat_id}
def _source(chat_id="123456", chat_type="dm"):
return SessionSource(
platform=Platform.TELEGRAM,
chat_id=chat_id,
chat_type=chat_type,
)
def _make_runner() -> GatewayRunner:
runner = object.__new__(GatewayRunner)
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
runner._running = True
runner._shutdown_event = asyncio.Event()
runner._exit_reason = None
runner._exit_code = None
runner._pending_messages = {}
runner._pending_approvals = {}
runner._background_tasks = set()
runner._running_agents = {}
runner._running_agents_ts = {}
runner._draining = False
runner._restart_requested = False
runner._restart_task_started = False
runner._restart_detached = False
runner._restart_via_service = False
runner._restart_drain_timeout = 60.0
runner._stop_task = None
runner._shutdown_all_gateway_honcho = lambda: None
runner._update_runtime_status = MagicMock()
return runner
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cancel_background_tasks_cancels_inflight_message_processing(): async def test_cancel_background_tasks_cancels_inflight_message_processing():
adapter = StubAdapter() _runner, adapter = make_restart_runner()
release = asyncio.Event() release = asyncio.Event()
async def block_forever(_event): async def block_forever(_event):
@ -71,7 +19,7 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing():
return None return None
adapter.set_message_handler(block_forever) adapter.set_message_handler(block_forever)
event = MessageEvent(text="work", source=_source(), message_id="1") event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
await adapter.handle_message(event) await adapter.handle_message(event)
await asyncio.sleep(0) await asyncio.sleep(0)
@ -89,12 +37,11 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(): async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
runner = _make_runner() runner, adapter = make_restart_runner()
runner._pending_messages = {"session": "pending text"} runner._pending_messages = {"session": "pending text"}
runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}} runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
runner._restart_drain_timeout = 0.0 runner._restart_drain_timeout = 0.0
adapter = StubAdapter()
release = asyncio.Event() release = asyncio.Event()
async def block_forever(_event): async def block_forever(_event):
@ -102,7 +49,7 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
return None return None
adapter.set_message_handler(block_forever) adapter.set_message_handler(block_forever)
event = MessageEvent(text="work", source=_source(), message_id="1") event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
await adapter.handle_message(event) await adapter.handle_message(event)
await asyncio.sleep(0) await asyncio.sleep(0)
@ -112,7 +59,6 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
session_key = build_session_key(event.source) session_key = build_session_key(event.source)
running_agent = MagicMock() running_agent = MagicMock()
runner._running_agents = {session_key: running_agent} runner._running_agents = {session_key: running_agent}
runner.adapters = {Platform.TELEGRAM: adapter}
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"): with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
await runner.stop() await runner.stop()
@ -128,11 +74,9 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gateway_stop_drains_running_agents_before_disconnect(): async def test_gateway_stop_drains_running_agents_before_disconnect():
runner = _make_runner() runner, adapter = make_restart_runner()
adapter = StubAdapter()
disconnect_mock = AsyncMock() disconnect_mock = AsyncMock()
adapter.disconnect = disconnect_mock adapter.disconnect = disconnect_mock
runner.adapters = {Platform.TELEGRAM: adapter}
running_agent = MagicMock() running_agent = MagicMock()
runner._running_agents = {"session": running_agent} runner._running_agents = {"session": running_agent}
@ -153,13 +97,11 @@ async def test_gateway_stop_drains_running_agents_before_disconnect():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gateway_stop_interrupts_after_drain_timeout(): async def test_gateway_stop_interrupts_after_drain_timeout():
runner = _make_runner() runner, adapter = make_restart_runner()
runner._restart_drain_timeout = 0.05 runner._restart_drain_timeout = 0.05
adapter = StubAdapter()
disconnect_mock = AsyncMock() disconnect_mock = AsyncMock()
adapter.disconnect = disconnect_mock adapter.disconnect = disconnect_mock
runner.adapters = {Platform.TELEGRAM: adapter}
running_agent = MagicMock() running_agent = MagicMock()
runner._running_agents = {"session": running_agent} runner._running_agents = {"session": running_agent}
@ -170,3 +112,36 @@ async def test_gateway_stop_interrupts_after_drain_timeout():
running_agent.interrupt.assert_called_once_with("Gateway shutting down") running_agent.interrupt.assert_called_once_with("Gateway shutting down")
disconnect_mock.assert_awaited_once() disconnect_mock.assert_awaited_once()
assert runner._shutdown_event.is_set() is True assert runner._shutdown_event.is_set() is True
@pytest.mark.asyncio
async def test_gateway_stop_service_restart_sets_named_exit_code():
    """Stopping for a service-managed restart records the shared exit code."""
    gateway_runner, telegram_adapter = make_restart_runner()
    telegram_adapter.disconnect = AsyncMock()

    with patch("gateway.status.remove_pid_file"):
        with patch("gateway.status.write_runtime_status"):
            await gateway_runner.stop(restart=True, service_restart=True)

    observed_exit_code = gateway_runner._exit_code
    assert observed_exit_code == GATEWAY_SERVICE_RESTART_EXIT_CODE
@pytest.mark.asyncio
async def test_drain_active_agents_throttles_status_updates():
    """Draining writes status on start, on count changes, and at the end."""
    runner, _ = make_restart_runner()
    status_mock = MagicMock()
    runner._update_runtime_status = status_mock
    runner._running_agents = {"a": MagicMock(), "b": MagicMock()}

    async def retire_agents_in_two_steps():
        # Drop one agent, then the rest, each after a short pause.
        await asyncio.sleep(0.12)
        runner._running_agents.pop("a")
        await asyncio.sleep(0.12)
        runner._running_agents.clear()

    retire_task = asyncio.create_task(retire_agents_in_two_steps())
    await runner._drain_active_agents(1.0)
    await retire_task

    # Expected writes: the initial one, one for the agent-count change, and
    # the final one. One extra is tolerated in case the loop observes the
    # zero-agent state before exiting.
    assert 3 <= status_mock.call_count <= 4

View file

@ -1,95 +1,27 @@
import asyncio import asyncio
import shutil
import subprocess
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig import gateway.run as gateway_run
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult from gateway.platforms.base import MessageEvent, MessageType
from gateway.run import GatewayRunner from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
from gateway.session import SessionSource, build_session_key from gateway.session import build_session_key
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
class RecordingAdapter(BasePlatformAdapter):
def __init__(self):
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
self.sent: list[str] = []
async def connect(self):
return True
async def disconnect(self):
return None
async def send(self, chat_id, content, reply_to=None, metadata=None):
self.sent.append(content)
return SendResult(success=True, message_id="1")
async def send_typing(self, chat_id, metadata=None):
return None
async def get_chat_info(self, chat_id):
return {"id": chat_id}
def _source(chat_id="123456"):
return SessionSource(
platform=Platform.TELEGRAM,
chat_id=chat_id,
chat_type="dm",
)
def _make_runner() -> tuple[GatewayRunner, RecordingAdapter]:
runner = object.__new__(GatewayRunner)
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
runner.adapters = {}
runner._running = True
runner._shutdown_event = asyncio.Event()
runner._exit_reason = None
runner._exit_code = None
runner._running_agents = {}
runner._running_agents_ts = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._background_tasks = set()
runner._draining = False
runner._restart_requested = False
runner._restart_task_started = False
runner._restart_detached = False
runner._restart_via_service = False
runner._restart_drain_timeout = 60.0
runner._stop_task = None
runner._busy_input_mode = "interrupt"
runner._update_prompt_pending = {}
runner._voice_mode = {}
runner._update_runtime_status = MagicMock()
runner._queue_or_replace_pending_event = GatewayRunner._queue_or_replace_pending_event.__get__(runner, GatewayRunner)
runner._session_key_for_source = GatewayRunner._session_key_for_source.__get__(runner, GatewayRunner)
runner._handle_active_session_busy_message = GatewayRunner._handle_active_session_busy_message.__get__(runner, GatewayRunner)
runner._handle_restart_command = GatewayRunner._handle_restart_command.__get__(runner, GatewayRunner)
runner._status_action_label = GatewayRunner._status_action_label.__get__(runner, GatewayRunner)
runner._status_action_gerund = GatewayRunner._status_action_gerund.__get__(runner, GatewayRunner)
runner._queue_during_drain_enabled = GatewayRunner._queue_during_drain_enabled.__get__(runner, GatewayRunner)
runner._running_agent_count = GatewayRunner._running_agent_count.__get__(runner, GatewayRunner)
runner.request_restart = MagicMock(return_value=True)
runner._is_user_authorized = lambda _source: True
runner.hooks = MagicMock()
runner.hooks.emit = AsyncMock()
runner.pairing_store = MagicMock()
runner.session_store = MagicMock()
runner.delivery_router = MagicMock()
adapter = RecordingAdapter()
adapter.set_message_handler(AsyncMock(return_value=None))
adapter.set_busy_session_handler(runner._handle_active_session_busy_message)
runner.adapters = {Platform.TELEGRAM: adapter}
return runner, adapter
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_restart_command_while_busy_requests_drain_without_interrupt(): async def test_restart_command_while_busy_requests_drain_without_interrupt():
runner, _adapter = _make_runner() runner, _adapter = make_restart_runner()
event = MessageEvent(text="/restart", message_type=MessageType.TEXT, source=_source(), message_id="m1") runner.request_restart = MagicMock(return_value=True)
event = MessageEvent(
text="/restart",
message_type=MessageType.TEXT,
source=make_restart_source(),
message_id="m1",
)
session_key = build_session_key(event.source) session_key = build_session_key(event.source)
running_agent = MagicMock() running_agent = MagicMock()
runner._running_agents[session_key] = running_agent runner._running_agents[session_key] = running_agent
@ -103,12 +35,17 @@ async def test_restart_command_while_busy_requests_drain_without_interrupt():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_drain_queue_mode_queues_follow_up_without_interrupt(): async def test_drain_queue_mode_queues_follow_up_without_interrupt():
runner, adapter = _make_runner() runner, adapter = make_restart_runner()
runner._draining = True runner._draining = True
runner._restart_requested = True runner._restart_requested = True
runner._busy_input_mode = "queue" runner._busy_input_mode = "queue"
event = MessageEvent(text="follow up", message_type=MessageType.TEXT, source=_source(), message_id="m2") event = MessageEvent(
text="follow up",
message_type=MessageType.TEXT,
source=make_restart_source(),
message_id="m2",
)
session_key = build_session_key(event.source) session_key = build_session_key(event.source)
adapter._active_sessions[session_key] = asyncio.Event() adapter._active_sessions[session_key] = asyncio.Event()
@ -122,12 +59,102 @@ async def test_drain_queue_mode_queues_follow_up_without_interrupt():
@pytest.mark.asyncio
async def test_draining_rejects_new_session_messages():
    """While draining for a restart, `_handle_message` answers with the refusal notice.

    The runner is flagged as draining with a restart requested, then handed a
    text message built from the "fresh" source; the returned string must be
    the "not accepting new work" notice.
    """
    runner, _ = make_restart_runner()
    runner._draining = True
    runner._restart_requested = True

    fresh_source = make_restart_source("fresh")
    incoming = MessageEvent(
        text="hello",
        message_type=MessageType.TEXT,
        source=fresh_source,
        message_id="m3",
    )

    reply = await runner._handle_message(incoming)

    assert reply == "⏳ Gateway is restarting and is not accepting new work right now."
def test_load_busy_input_mode_prefers_env_then_config_then_default(tmp_path, monkeypatch):
    """Exercise `GatewayRunner._load_busy_input_mode` across three setups.

    In order: no env var and no config file, then a config file only, then
    the env var set while that config file is still present.
    """
    env_name = "HERMES_GATEWAY_BUSY_INPUT_MODE"
    load_mode = gateway_run.GatewayRunner._load_busy_input_mode

    # Point the module at an empty temp home and clear the env override.
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv(env_name, raising=False)

    # Nothing configured: the loader reports "interrupt".
    assert load_mode() == "interrupt"

    # A config file alone switches the result to "queue".
    config_file = tmp_path / "config.yaml"
    config_file.write_text("display:\n busy_input_mode: queue\n", encoding="utf-8")
    assert load_mode() == "queue"

    # With the env var set, its value is returned instead of the file's.
    monkeypatch.setenv(env_name, "interrupt")
    assert load_mode() == "interrupt"
def test_load_restart_drain_timeout_prefers_env_then_config_then_default(
    tmp_path, monkeypatch, caplog
):
    """Exercise `GatewayRunner._load_restart_drain_timeout` across four setups.

    In order: nothing configured, a config file only, a numeric env var on
    top of the config file, and finally a non-numeric env var, which must
    yield the default again and leave an "Invalid restart_drain_timeout"
    message in the captured log.
    """
    env_name = "HERMES_RESTART_DRAIN_TIMEOUT"
    load_timeout = gateway_run.GatewayRunner._load_restart_drain_timeout

    # Point the module at an empty temp home and clear the env override.
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv(env_name, raising=False)

    unconfigured = load_timeout()
    assert unconfigured == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT

    config_file = tmp_path / "config.yaml"
    config_file.write_text("agent:\n restart_drain_timeout: 12\n", encoding="utf-8")
    assert load_timeout() == 12.0

    monkeypatch.setenv(env_name, "7")
    assert load_timeout() == 7.0

    # An unparseable env value yields the default and is logged.
    monkeypatch.setenv(env_name, "invalid")
    after_invalid = load_timeout()
    assert after_invalid == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
    assert "Invalid restart_drain_timeout" in caplog.text
@pytest.mark.asyncio
async def test_request_restart_is_idempotent():
    """A second `request_restart` call returns False and `stop` is awaited once.

    `runner.stop` is replaced with an `AsyncMock`; after awaiting the first
    task found in `runner._background_tasks`, the mock must have been awaited
    exactly once with the restart keyword arguments.
    """
    runner, _ = make_restart_runner()
    runner.stop = AsyncMock()
    restart_kwargs = {"detached": True, "via_service": False}

    accepted = runner.request_restart(**restart_kwargs)
    assert accepted is True

    # Grab the task registered by the first call before asking again.
    first_task = next(iter(runner._background_tasks))

    repeated = runner.request_restart(**restart_kwargs)
    assert repeated is False

    await first_task

    runner.stop.assert_awaited_once_with(
        restart=True,
        detached_restart=True,
        service_restart=False,
    )
@pytest.mark.asyncio
async def test_launch_detached_restart_command_uses_setsid(monkeypatch):
    """Check the subprocess spawned by `_launch_detached_restart_command`.

    `shutil.which` is stubbed to find only `setsid`, `os.getpid` to return
    321, and `subprocess.Popen` to record its arguments. Exactly one spawn is
    expected: a `setsid bash ...` command whose final argument mentions
    "gateway restart" and "kill -0 321", started in a new session with stdout
    and stderr sent to `subprocess.DEVNULL`.
    """
    runner, _ = make_restart_runner()
    recorded = []

    def fake_which(cmd):
        # Only `setsid` is reported as available.
        if cmd == "setsid":
            return "/usr/bin/setsid"
        return None

    def fake_popen(cmd, **kwargs):
        # Record the call instead of spawning anything.
        recorded.append((cmd, kwargs))
        return MagicMock()

    monkeypatch.setattr(gateway_run, "_resolve_hermes_bin", lambda: ["/usr/bin/hermes"])
    monkeypatch.setattr(gateway_run.os, "getpid", lambda: 321)
    monkeypatch.setattr(shutil, "which", fake_which)
    monkeypatch.setattr(subprocess, "Popen", fake_popen)

    await runner._launch_detached_restart_command()

    assert len(recorded) == 1
    cmd, kwargs = recorded[0]
    assert cmd[:2] == ["/usr/bin/setsid", "bash"]
    script = cmd[-1]
    assert "gateway restart" in script
    assert "kill -0 321" in script
    assert kwargs["start_new_session"] is True
    for stream in ("stdout", "stderr"):
        assert kwargs[stream] is subprocess.DEVNULL

View file

@ -5,6 +5,10 @@ from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
import hermes_cli.gateway as gateway_cli import hermes_cli.gateway as gateway_cli
from gateway.restart import (
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
GATEWAY_SERVICE_RESTART_EXIT_CODE,
)
class TestSystemdServiceRefresh: class TestSystemdServiceRefresh:
@ -85,7 +89,7 @@ class TestGeneratedSystemdUnits:
assert "ExecStart=" in unit assert "ExecStart=" in unit
assert "ExecStop=" not in unit assert "ExecStop=" not in unit
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
assert "RestartForceExitStatus=75" in unit assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit
assert "TimeoutStopSec=60" in unit assert "TimeoutStopSec=60" in unit
def test_user_unit_includes_resolved_node_directory_in_path(self, monkeypatch): def test_user_unit_includes_resolved_node_directory_in_path(self, monkeypatch):
@ -101,7 +105,7 @@ class TestGeneratedSystemdUnits:
assert "ExecStart=" in unit assert "ExecStart=" in unit
assert "ExecStop=" not in unit assert "ExecStop=" not in unit
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
assert "RestartForceExitStatus=75" in unit assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit
assert "TimeoutStopSec=60" in unit assert "TimeoutStopSec=60" in unit
assert "WantedBy=multi-user.target" in unit assert "WantedBy=multi-user.target" in unit
@ -161,6 +165,31 @@ class TestGatewayStopCleanup:
class TestLaunchdServiceRecovery: class TestLaunchdServiceRecovery:
def test_get_restart_drain_timeout_prefers_env_then_config_then_default(self, monkeypatch):
    """Exercise `gateway_cli._get_restart_drain_timeout` across four setups.

    In order: empty raw config with no env var, a raw config carrying
    `agent.restart_drain_timeout`, a numeric env var on top of that config,
    and a non-numeric env var, which must yield the default again.
    """
    env_name = "HERMES_RESTART_DRAIN_TIMEOUT"
    get_timeout = gateway_cli._get_restart_drain_timeout

    # Start with no env override and an empty raw config.
    monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr(gateway_cli, "read_raw_config", lambda: {})
    unconfigured = get_timeout()
    assert unconfigured == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT

    # A config value alone is returned as a float.
    monkeypatch.setattr(
        gateway_cli,
        "read_raw_config",
        lambda: {"agent": {"restart_drain_timeout": 14}},
    )
    assert get_timeout() == 14.0

    # With the env var set, its value is returned instead of the config's.
    monkeypatch.setenv(env_name, "9")
    assert get_timeout() == 9.0

    # An unparseable env value yields the default.
    monkeypatch.setenv(env_name, "invalid")
    after_invalid = get_timeout()
    assert after_invalid == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
def test_launchd_install_repairs_outdated_plist_without_force(self, tmp_path, monkeypatch): def test_launchd_install_repairs_outdated_plist_without_force(self, tmp_path, monkeypatch):
plist_path = tmp_path / "ai.hermes.gateway.plist" plist_path = tmp_path / "ai.hermes.gateway.plist"
plist_path.write_text("<plist>old content</plist>", encoding="utf-8") plist_path.write_text("<plist>old content</plist>", encoding="utf-8")