From ecfae9815296ea0391fe68c9c3b84538e5ff6174 Mon Sep 17 00:00:00 2001 From: Kenny Xie Date: Fri, 10 Apr 2026 14:00:21 -0700 Subject: [PATCH] fix(gateway): address restart review feedback --- gateway/platforms/base.py | 35 +++- gateway/restart.py | 20 +++ gateway/run.py | 86 +++++----- hermes_cli/gateway.py | 21 ++- tests/gateway/restart_test_helpers.py | 110 +++++++++++++ tests/gateway/test_gateway_shutdown.py | 111 +++++-------- tests/gateway/test_restart_drain.py | 201 +++++++++++++---------- tests/hermes_cli/test_gateway_service.py | 33 +++- 8 files changed, 404 insertions(+), 213 deletions(-) create mode 100644 gateway/restart.py create mode 100644 tests/gateway/restart_test_helpers.py diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 34aacc7a3..04f0c1deb 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -673,6 +673,32 @@ class SendResult: retryable: bool = False # True for transient connection errors — base will retry automatically +def merge_pending_message_event( + pending_messages: Dict[str, MessageEvent], + session_key: str, + event: MessageEvent, +) -> None: + """Store or merge a pending event for a session. + + Photo bursts/albums often arrive as multiple near-simultaneous PHOTO + events. Merge those into the existing queued event so the next turn sees + the whole burst, while non-photo follow-ups still replace the pending + event normally. + """ + existing = pending_messages.get(session_key) + if ( + existing + and getattr(existing, "message_type", None) == MessageType.PHOTO + and event.message_type == MessageType.PHOTO + ): + existing.media_urls.extend(event.media_urls) + existing.media_types.extend(event.media_types) + if event.text: + existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text) + return + pending_messages[session_key] = event + + # Error substrings that indicate a transient *connection* failure worth retrying. # "timeout" / "timed out" / "readtimeout" / "writetimeout" are intentionally # excluded: a read/write timeout on a non-idempotent call (e.g. send_message) @@ -1432,14 +1458,7 @@ class BasePlatformAdapter(ABC): # then process them immediately after the current task finishes. if event.message_type == MessageType.PHOTO: logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key) - existing = self._pending_messages.get(session_key) - if existing and existing.message_type == MessageType.PHOTO: - existing.media_urls.extend(event.media_urls) - existing.media_types.extend(event.media_types) - if event.text: - existing.text = self._merge_caption(existing.text, event.text) - else: - self._pending_messages[session_key] = event + merge_pending_message_event(self._pending_messages, session_key, event) return # Don't interrupt now - will run after current task completes # Default behavior for non-photo follow-ups: interrupt the running agent diff --git a/gateway/restart.py b/gateway/restart.py new file mode 100644 index 000000000..fe9b70022 --- /dev/null +++ b/gateway/restart.py @@ -0,0 +1,20 @@ +"""Shared gateway restart constants and parsing helpers.""" + +from hermes_cli.config import DEFAULT_CONFIG + +# EX_TEMPFAIL from sysexits.h — used to ask the service manager to restart +# the gateway after a graceful drain/reload path completes. +GATEWAY_SERVICE_RESTART_EXIT_CODE = 75 + +DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT = float( + DEFAULT_CONFIG["agent"]["restart_drain_timeout"] +) + + +def parse_restart_drain_timeout(raw: object) -> float: + """Parse a configured drain timeout, falling back to the shared default.""" + try: + value = float(raw) if str(raw or "").strip() else DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT + except (TypeError, ValueError): + return DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT + return max(0.0, value) diff --git a/gateway/run.py b/gateway/run.py index 7f950b297..b370060fc 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -241,7 +241,17 @@ from gateway.session import ( build_session_key, ) from gateway.delivery import DeliveryRouter -from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + merge_pending_message_event, +) +from gateway.restart import ( + DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT, + GATEWAY_SERVICE_RESTART_EXIT_CODE, + parse_restart_drain_timeout, +) def _normalize_whatsapp_identifier(value: str) -> str: @@ -478,7 +488,7 @@ class GatewayRunner: # blow up on attribute access. _running_agents_ts: Dict[str, float] = {} _busy_input_mode: str = "interrupt" - _restart_drain_timeout: float = 60.0 + _restart_drain_timeout: float = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT _exit_code: Optional[int] = None _draining: bool = False _restart_requested: bool = False @@ -486,6 +496,7 @@ class GatewayRunner: _restart_detached: bool = False _restart_via_service: bool = False _stop_task: Optional[asyncio.Task] = None + _session_model_overrides: Dict[str, Dict[str, str]] = {} def __init__(self, config: Optional[GatewayConfig] = None): self.config = config or load_gateway_config() @@ -1076,12 +1087,17 @@ class GatewayRunner: raw = str(cfg.get("agent", {}).get("restart_drain_timeout", "") or "").strip() except Exception: pass - try: - value = float(raw) if raw else 60.0 - except ValueError: - logger.warning("Invalid restart_drain_timeout '%s', using default 60s", raw) - return 60.0 - return max(0.0, value) + value = parse_restart_drain_timeout(raw) + if raw and value == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT: + try: + float(raw) + except (TypeError, ValueError): + logger.warning( + "Invalid restart_drain_timeout '%s', using default %.0fs", + raw, + DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT, + ) + return value @staticmethod def _load_background_notifications_mode() -> str: @@ -1178,14 +1194,7 @@ class GatewayRunner: adapter = self.adapters.get(event.source.platform) if not adapter: return - existing = adapter._pending_messages.get(session_key) - if existing and getattr(existing, "message_type", None) == MessageType.PHOTO and event.message_type == MessageType.PHOTO: - existing.media_urls.extend(event.media_urls) - existing.media_types.extend(event.media_types) - if event.text: - existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text) - return - adapter._pending_messages[session_key] = event + merge_pending_message_event(adapter._pending_messages, session_key, event) async def _handle_active_session_busy_message(self, event: MessageEvent, session_key: str) -> bool: if not self._draining: @@ -1212,20 +1221,32 @@ class GatewayRunner: async def _drain_active_agents(self, timeout: float) -> tuple[Dict[str, Any], bool]: snapshot = self._snapshot_running_agents() + last_active_count = self._running_agent_count() + last_status_at = 0.0 + + def _maybe_update_status(force: bool = False) -> None: + nonlocal last_active_count, last_status_at + now = asyncio.get_running_loop().time() + active_count = self._running_agent_count() + if force or active_count != last_active_count or (now - last_status_at) >= 1.0: + self._update_runtime_status("draining") + last_active_count = active_count + last_status_at = now + if not self._running_agents: - self._update_runtime_status("draining") + _maybe_update_status(force=True) return snapshot, False - self._update_runtime_status("draining") + _maybe_update_status(force=True) if timeout <= 0: return snapshot, True deadline = asyncio.get_running_loop().time() + timeout while self._running_agents and asyncio.get_running_loop().time() < deadline: - self._update_runtime_status("draining") + _maybe_update_status() await asyncio.sleep(0.1) timed_out = bool(self._running_agents) - self._update_runtime_status("draining") + _maybe_update_status(force=True) return snapshot, timed_out def _interrupt_running_agents(self, reason: str) -> None: @@ -1841,7 +1862,7 @@ class GatewayRunner: remove_pid_file() if self._restart_requested and self._restart_via_service: - self._exit_code = 75 + self._exit_code = GATEWAY_SERVICE_RESTART_EXIT_CODE self._exit_reason = self._exit_reason or "Gateway restart requested" self._draining = False @@ -2338,18 +2359,7 @@ class GatewayRunner: logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20]) adapter = self.adapters.get(source.platform) if adapter: - # Reuse adapter queue semantics so photo bursts merge cleanly. - if _quick_key in adapter._pending_messages: - existing = adapter._pending_messages[_quick_key] - if getattr(existing, "message_type", None) == MessageType.PHOTO: - existing.media_urls.extend(event.media_urls) - existing.media_types.extend(event.media_types) - if event.text: - existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text) - else: - adapter._pending_messages[_quick_key] = event - else: - adapter._pending_messages[_quick_key] = event + merge_pending_message_event(adapter._pending_messages, _quick_key, event) return None running_agent = self._running_agents.get(_quick_key) @@ -3951,7 +3961,7 @@ class GatewayRunner: # Check for session override source = event.source session_key = self._session_key_for_source(source) - override = getattr(self, "_session_model_overrides", {}).get(session_key, {}) + override = self._session_model_overrides.get(session_key, {}) if override: current_model = override.get("model", current_model) current_provider = override.get("provider", current_provider) @@ -4033,8 +4043,6 @@ class GatewayRunner: f"via {result.provider_label or result.target_provider}. " f"Adjust your self-identification accordingly.]" ) - if not hasattr(_self, "_session_model_overrides"): - _self._session_model_overrides = {} _self._session_model_overrides[_session_key] = { "model": result.new_model, "provider": result.target_provider, @@ -4148,8 +4156,6 @@ class GatewayRunner: ) # Store session override so next agent creation uses the new model - if not hasattr(self, "_session_model_overrides"): - self._session_model_overrides = {} self._session_model_overrides[session_key] = { "model": result.new_model, "provider": result.target_provider, @@ -6828,7 +6834,7 @@ class GatewayRunner: subsequent messages. Fields with ``None`` values are skipped so partial overrides don't clobber valid config defaults. """ - override = getattr(self, "_session_model_overrides", {}).get(session_key) + override = self._session_model_overrides.get(session_key) if not override: return model, runtime_kwargs model = override.get("model", model) @@ -6840,7 +6846,7 @@ class GatewayRunner: def _is_intentional_model_switch(self, session_key: str, agent_model: str) -> bool: """Return True if *agent_model* matches an active /model session override.""" - override = getattr(self, "_session_model_overrides", {}).get(session_key) + override = self._session_model_overrides.get(session_key) return override is not None and override.get("model") == agent_model def _evict_cached_agent(self, session_key: str) -> None: diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index 689164e15..b29511dd5 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -15,8 +15,12 @@ from pathlib import Path PROJECT_ROOT = Path(__file__).parent.parent.resolve() from gateway.status import terminate_pid +from gateway.restart import ( + DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT, + GATEWAY_SERVICE_RESTART_EXIT_CODE, + parse_restart_drain_timeout, +) from hermes_cli.config import ( - DEFAULT_CONFIG, get_env_value, get_hermes_home, is_managed, @@ -787,7 +791,7 @@ Environment="VIRTUAL_ENV={venv_dir}" Environment="HERMES_HOME={hermes_home}" Restart=on-failure RestartSec=30 -RestartForceExitStatus=75 +RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE} KillMode=mixed KillSignal=SIGTERM ExecReload=/bin/kill -USR1 $MAINPID @@ -819,7 +823,7 @@ Environment="VIRTUAL_ENV={venv_dir}" Environment="HERMES_HOME={hermes_home}" Restart=on-failure RestartSec=30 -RestartForceExitStatus=75 +RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE} KillMode=mixed KillSignal=SIGTERM ExecReload=/bin/kill -USR1 $MAINPID @@ -932,11 +936,12 @@ def _get_restart_drain_timeout() -> float: if not raw: cfg = read_raw_config() agent_cfg = cfg.get("agent", {}) if isinstance(cfg, dict) else {} - raw = str(agent_cfg.get("restart_drain_timeout", DEFAULT_CONFIG["agent"]["restart_drain_timeout"])) - try: - return max(0.0, float(raw)) - except (TypeError, ValueError): - return float(DEFAULT_CONFIG["agent"]["restart_drain_timeout"]) + raw = str( + agent_cfg.get( + "restart_drain_timeout", DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT + ) + ) + return parse_restart_drain_timeout(raw) def systemd_install(force: bool = False, system: bool = False, run_as_user: str | None = None): diff --git a/tests/gateway/restart_test_helpers.py b/tests/gateway/restart_test_helpers.py new file mode 100644 index 000000000..54dcd69b9 --- /dev/null +++ b/tests/gateway/restart_test_helpers.py @@ -0,0 +1,110 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock + +from gateway.config import GatewayConfig, Platform, PlatformConfig +from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult +from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT +from gateway.run import GatewayRunner +from gateway.session import SessionSource + + +class RestartTestAdapter(BasePlatformAdapter): + def __init__(self): + super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM) + self.sent: list[str] = [] + + async def connect(self): + return True + + async def disconnect(self): + return None + + async def send(self, chat_id, content, reply_to=None, metadata=None): + self.sent.append(content) + return SendResult(success=True, message_id="1") + + async def send_typing(self, chat_id, metadata=None): + return None + + async def get_chat_info(self, chat_id): + return {"id": chat_id} + + +def make_restart_source(chat_id: str = "123456", chat_type: str = "dm") -> SessionSource: + return SessionSource( + platform=Platform.TELEGRAM, + chat_id=chat_id, + chat_type=chat_type, + ) + + +def make_restart_runner( + adapter: BasePlatformAdapter | None = None, +) -> tuple[GatewayRunner, BasePlatformAdapter]: + runner = object.__new__(GatewayRunner) + runner.config = GatewayConfig( + platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")} + ) + runner._running = True + runner._shutdown_event = asyncio.Event() + runner._exit_reason = None + runner._exit_code = None + runner._running_agents = {} + runner._running_agents_ts = {} + runner._pending_messages = {} + runner._pending_approvals = {} + runner._pending_model_notes = {} + runner._background_tasks = set() + runner._draining = False + runner._restart_requested = False + runner._restart_task_started = False + runner._restart_detached = False + runner._restart_via_service = False + runner._restart_drain_timeout = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT + runner._stop_task = None + runner._busy_input_mode = "interrupt" + runner._update_prompt_pending = {} + runner._voice_mode = {} + runner._session_model_overrides = {} + runner._shutdown_all_gateway_honcho = lambda: None + runner._update_runtime_status = MagicMock() + runner._queue_or_replace_pending_event = GatewayRunner._queue_or_replace_pending_event.__get__( + runner, GatewayRunner + ) + runner._session_key_for_source = GatewayRunner._session_key_for_source.__get__( + runner, GatewayRunner + ) + runner._handle_active_session_busy_message = ( + GatewayRunner._handle_active_session_busy_message.__get__(runner, GatewayRunner) + ) + runner._handle_restart_command = GatewayRunner._handle_restart_command.__get__( + runner, GatewayRunner + ) + runner._status_action_label = GatewayRunner._status_action_label.__get__( + runner, GatewayRunner + ) + runner._status_action_gerund = GatewayRunner._status_action_gerund.__get__( + runner, GatewayRunner + ) + runner._queue_during_drain_enabled = GatewayRunner._queue_during_drain_enabled.__get__( + runner, GatewayRunner + ) + runner._running_agent_count = GatewayRunner._running_agent_count.__get__( + runner, GatewayRunner + ) + runner._launch_detached_restart_command = GatewayRunner._launch_detached_restart_command.__get__( + runner, GatewayRunner + ) + runner.request_restart = GatewayRunner.request_restart.__get__(runner, GatewayRunner) + runner._is_user_authorized = lambda _source: True + runner.hooks = MagicMock() + runner.hooks.emit = AsyncMock() + runner.pairing_store = MagicMock() + runner.session_store = MagicMock() + runner.delivery_router = MagicMock() + + platform_adapter = adapter or RestartTestAdapter() + platform_adapter.set_message_handler(AsyncMock(return_value=None)) + platform_adapter.set_busy_session_handler(runner._handle_active_session_busy_message) + runner.adapters = {Platform.TELEGRAM: platform_adapter} + return runner, platform_adapter diff --git a/tests/gateway/test_gateway_shutdown.py b/tests/gateway/test_gateway_shutdown.py index b6a7f8fa7..4dc9919bc 100644 --- a/tests/gateway/test_gateway_shutdown.py +++ b/tests/gateway/test_gateway_shutdown.py @@ -3,67 +3,15 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from gateway.config import GatewayConfig, Platform, PlatformConfig -from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult -from gateway.run import GatewayRunner -from gateway.session import SessionSource, build_session_key - - -class StubAdapter(BasePlatformAdapter): - def __init__(self): - super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM) - - async def connect(self): - return True - - async def disconnect(self): - return None - - async def send(self, chat_id, content, reply_to=None, metadata=None): - return SendResult(success=True, message_id="1") - - async def send_typing(self, chat_id, metadata=None): - return None - - async def get_chat_info(self, chat_id): - return {"id": chat_id} - - -def _source(chat_id="123456", chat_type="dm"): - return SessionSource( - platform=Platform.TELEGRAM, - chat_id=chat_id, - chat_type=chat_type, - ) - - -def _make_runner() -> GatewayRunner: - runner = object.__new__(GatewayRunner) - runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}) - runner._running = True - runner._shutdown_event = asyncio.Event() - runner._exit_reason = None - runner._exit_code = None - runner._pending_messages = {} - runner._pending_approvals = {} - runner._background_tasks = set() - runner._running_agents = {} - runner._running_agents_ts = {} - runner._draining = False - runner._restart_requested = False - runner._restart_task_started = False - runner._restart_detached = False - runner._restart_via_service = False - runner._restart_drain_timeout = 60.0 - runner._stop_task = None - runner._shutdown_all_gateway_honcho = lambda: None - runner._update_runtime_status = MagicMock() - return runner +from gateway.platforms.base import MessageEvent +from gateway.restart import GATEWAY_SERVICE_RESTART_EXIT_CODE +from gateway.session import build_session_key +from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source @pytest.mark.asyncio async def test_cancel_background_tasks_cancels_inflight_message_processing(): - adapter = StubAdapter() + _runner, adapter = make_restart_runner() release = asyncio.Event() async def block_forever(_event): @@ -71,7 +19,7 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing(): return None adapter.set_message_handler(block_forever) - event = MessageEvent(text="work", source=_source(), message_id="1") + event = MessageEvent(text="work", source=make_restart_source(), message_id="1") await adapter.handle_message(event) await asyncio.sleep(0) @@ -89,12 +37,11 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing(): @pytest.mark.asyncio async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(): - runner = _make_runner() + runner, adapter = make_restart_runner() runner._pending_messages = {"session": "pending text"} runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}} runner._restart_drain_timeout = 0.0 - adapter = StubAdapter() release = asyncio.Event() async def block_forever(_event): @@ -102,7 +49,7 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks( return None adapter.set_message_handler(block_forever) - event = MessageEvent(text="work", source=_source(), message_id="1") + event = MessageEvent(text="work", source=make_restart_source(), message_id="1") await adapter.handle_message(event) await asyncio.sleep(0) @@ -112,7 +59,6 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks( session_key = build_session_key(event.source) running_agent = MagicMock() runner._running_agents = {session_key: running_agent} - runner.adapters = {Platform.TELEGRAM: adapter} with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"): await runner.stop() @@ -128,11 +74,9 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks( @pytest.mark.asyncio async def test_gateway_stop_drains_running_agents_before_disconnect(): - runner = _make_runner() - adapter = StubAdapter() + runner, adapter = make_restart_runner() disconnect_mock = AsyncMock() adapter.disconnect = disconnect_mock - runner.adapters = {Platform.TELEGRAM: adapter} running_agent = MagicMock() runner._running_agents = {"session": running_agent} @@ -153,13 +97,11 @@ async def test_gateway_stop_drains_running_agents_before_disconnect(): @pytest.mark.asyncio async def test_gateway_stop_interrupts_after_drain_timeout(): - runner = _make_runner() + runner, adapter = make_restart_runner() runner._restart_drain_timeout = 0.05 - adapter = StubAdapter() disconnect_mock = AsyncMock() adapter.disconnect = disconnect_mock - runner.adapters = {Platform.TELEGRAM: adapter} running_agent = MagicMock() runner._running_agents = {"session": running_agent} @@ -170,3 +112,36 @@ async def test_gateway_stop_interrupts_after_drain_timeout(): running_agent.interrupt.assert_called_once_with("Gateway shutting down") disconnect_mock.assert_awaited_once() assert runner._shutdown_event.is_set() is True + + +@pytest.mark.asyncio +async def test_gateway_stop_service_restart_sets_named_exit_code(): + runner, adapter = make_restart_runner() + adapter.disconnect = AsyncMock() + + with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"): + await runner.stop(restart=True, service_restart=True) + + assert runner._exit_code == GATEWAY_SERVICE_RESTART_EXIT_CODE + + +@pytest.mark.asyncio +async def test_drain_active_agents_throttles_status_updates(): + runner, _adapter = make_restart_runner() + runner._update_runtime_status = MagicMock() + + runner._running_agents = {"a": MagicMock(), "b": MagicMock()} + + async def finish_agents(): + await asyncio.sleep(0.12) + runner._running_agents.pop("a") + await asyncio.sleep(0.12) + runner._running_agents.clear() + + task = asyncio.create_task(finish_agents()) + await runner._drain_active_agents(1.0) + await task + + # Start, one count-change update, and final update. Allow one extra update + # if the loop observes the zero-agent state before exiting. + assert 3 <= runner._update_runtime_status.call_count <= 4 diff --git a/tests/gateway/test_restart_drain.py b/tests/gateway/test_restart_drain.py index 2c59f9a97..0c1324664 100644 --- a/tests/gateway/test_restart_drain.py +++ b/tests/gateway/test_restart_drain.py @@ -1,95 +1,27 @@ import asyncio +import shutil +import subprocess from unittest.mock import AsyncMock, MagicMock import pytest -from gateway.config import GatewayConfig, Platform, PlatformConfig -from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult -from gateway.run import GatewayRunner -from gateway.session import SessionSource, build_session_key - - -class RecordingAdapter(BasePlatformAdapter): - def __init__(self): - super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM) - self.sent: list[str] = [] - - async def connect(self): - return True - - async def disconnect(self): - return None - - async def send(self, chat_id, content, reply_to=None, metadata=None): - self.sent.append(content) - return SendResult(success=True, message_id="1") - - async def send_typing(self, chat_id, metadata=None): - return None - - async def get_chat_info(self, chat_id): - return {"id": chat_id} - - -def _source(chat_id="123456"): - return SessionSource( - platform=Platform.TELEGRAM, - chat_id=chat_id, - chat_type="dm", - ) - - -def _make_runner() -> tuple[GatewayRunner, RecordingAdapter]: - runner = object.__new__(GatewayRunner) - runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}) - runner.adapters = {} - runner._running = True - runner._shutdown_event = asyncio.Event() - runner._exit_reason = None - runner._exit_code = None - runner._running_agents = {} - runner._running_agents_ts = {} - runner._pending_messages = {} - runner._pending_approvals = {} - runner._background_tasks = set() - runner._draining = False - runner._restart_requested = False - runner._restart_task_started = False - runner._restart_detached = False - runner._restart_via_service = False - runner._restart_drain_timeout = 60.0 - runner._stop_task = None - runner._busy_input_mode = "interrupt" - runner._update_prompt_pending = {} - runner._voice_mode = {} - runner._update_runtime_status = MagicMock() - runner._queue_or_replace_pending_event = GatewayRunner._queue_or_replace_pending_event.__get__(runner, GatewayRunner) - runner._session_key_for_source = GatewayRunner._session_key_for_source.__get__(runner, GatewayRunner) - runner._handle_active_session_busy_message = GatewayRunner._handle_active_session_busy_message.__get__(runner, GatewayRunner) - runner._handle_restart_command = GatewayRunner._handle_restart_command.__get__(runner, GatewayRunner) - runner._status_action_label = GatewayRunner._status_action_label.__get__(runner, GatewayRunner) - runner._status_action_gerund = GatewayRunner._status_action_gerund.__get__(runner, GatewayRunner) - runner._queue_during_drain_enabled = GatewayRunner._queue_during_drain_enabled.__get__(runner, GatewayRunner) - runner._running_agent_count = GatewayRunner._running_agent_count.__get__(runner, GatewayRunner) - runner.request_restart = MagicMock(return_value=True) - runner._is_user_authorized = lambda _source: True - runner.hooks = MagicMock() - runner.hooks.emit = AsyncMock() - runner.pairing_store = MagicMock() - runner.session_store = MagicMock() - runner.delivery_router = MagicMock() - - adapter = RecordingAdapter() - adapter.set_message_handler(AsyncMock(return_value=None)) - adapter.set_busy_session_handler(runner._handle_active_session_busy_message) - runner.adapters = {Platform.TELEGRAM: adapter} - return runner, adapter +import gateway.run as gateway_run +from gateway.platforms.base import MessageEvent, MessageType +from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT +from gateway.session import build_session_key +from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source @pytest.mark.asyncio async def test_restart_command_while_busy_requests_drain_without_interrupt(): - runner, _adapter = _make_runner() - event = MessageEvent(text="/restart", message_type=MessageType.TEXT, source=_source(), message_id="m1") + runner, _adapter = make_restart_runner() + runner.request_restart = MagicMock(return_value=True) + event = MessageEvent( + text="/restart", + message_type=MessageType.TEXT, + source=make_restart_source(), + message_id="m1", + ) session_key = build_session_key(event.source) running_agent = MagicMock() runner._running_agents[session_key] = running_agent @@ -103,12 +35,17 @@ async def test_restart_command_while_busy_requests_drain_without_interrupt(): @pytest.mark.asyncio async def test_drain_queue_mode_queues_follow_up_without_interrupt(): - runner, adapter = _make_runner() + runner, adapter = make_restart_runner() runner._draining = True runner._restart_requested = True runner._busy_input_mode = "queue" - event = MessageEvent(text="follow up", message_type=MessageType.TEXT, source=_source(), message_id="m2") + event = MessageEvent( + text="follow up", + message_type=MessageType.TEXT, + source=make_restart_source(), + message_id="m2", + ) session_key = build_session_key(event.source) adapter._active_sessions[session_key] = asyncio.Event() @@ -122,12 +59,102 @@ async def test_drain_queue_mode_queues_follow_up_without_interrupt(): @pytest.mark.asyncio async def test_draining_rejects_new_session_messages(): - runner, _adapter = _make_runner() + runner, _adapter = make_restart_runner() runner._draining = True runner._restart_requested = True - event = MessageEvent(text="hello", message_type=MessageType.TEXT, source=_source("fresh"), message_id="m3") + event = MessageEvent( + text="hello", + message_type=MessageType.TEXT, + source=make_restart_source("fresh"), + message_id="m3", + ) result = await runner._handle_message(event) assert result == "⏳ Gateway is restarting and is not accepting new work right now." + + +def test_load_busy_input_mode_prefers_env_then_config_then_default(tmp_path, monkeypatch): + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + monkeypatch.delenv("HERMES_GATEWAY_BUSY_INPUT_MODE", raising=False) + + assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt" + + (tmp_path / "config.yaml").write_text( + "display:\n busy_input_mode: queue\n", encoding="utf-8" + ) + assert gateway_run.GatewayRunner._load_busy_input_mode() == "queue" + + monkeypatch.setenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "interrupt") + assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt" + + +def test_load_restart_drain_timeout_prefers_env_then_config_then_default( + tmp_path, monkeypatch, caplog +): + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + monkeypatch.delenv("HERMES_RESTART_DRAIN_TIMEOUT", raising=False) + + assert ( + gateway_run.GatewayRunner._load_restart_drain_timeout() + == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT + ) + + (tmp_path / "config.yaml").write_text( + "agent:\n restart_drain_timeout: 12\n", encoding="utf-8" + ) + assert gateway_run.GatewayRunner._load_restart_drain_timeout() == 12.0 + + monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "7") + assert gateway_run.GatewayRunner._load_restart_drain_timeout() == 7.0 + + monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "invalid") + assert ( + gateway_run.GatewayRunner._load_restart_drain_timeout() + == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT + ) + assert "Invalid restart_drain_timeout" in caplog.text + + +@pytest.mark.asyncio +async def test_request_restart_is_idempotent(): + runner, _adapter = make_restart_runner() + runner.stop = AsyncMock() + + assert runner.request_restart(detached=True, via_service=False) is True + first_task = next(iter(runner._background_tasks)) + assert runner.request_restart(detached=True, via_service=False) is False + + await first_task + + runner.stop.assert_awaited_once_with( + restart=True, detached_restart=True, service_restart=False + ) + + +@pytest.mark.asyncio +async def test_launch_detached_restart_command_uses_setsid(monkeypatch): + runner, _adapter = make_restart_runner() + popen_calls = [] + + monkeypatch.setattr(gateway_run, "_resolve_hermes_bin", lambda: ["/usr/bin/hermes"]) + monkeypatch.setattr(gateway_run.os, "getpid", lambda: 321) + monkeypatch.setattr(shutil, "which", lambda cmd: "/usr/bin/setsid" if cmd == "setsid" else None) + + def fake_popen(cmd, **kwargs): + popen_calls.append((cmd, kwargs)) + return MagicMock() + + monkeypatch.setattr(subprocess, "Popen", fake_popen) + + await runner._launch_detached_restart_command() + + assert len(popen_calls) == 1 + cmd, kwargs = popen_calls[0] + assert cmd[:2] == ["/usr/bin/setsid", "bash"] + assert "gateway restart" in cmd[-1] + assert "kill -0 321" in cmd[-1] + assert kwargs["start_new_session"] is True + assert kwargs["stdout"] is subprocess.DEVNULL + assert kwargs["stderr"] is subprocess.DEVNULL diff --git a/tests/hermes_cli/test_gateway_service.py b/tests/hermes_cli/test_gateway_service.py index 26919608d..c5d4cb4f5 100644 --- a/tests/hermes_cli/test_gateway_service.py +++ b/tests/hermes_cli/test_gateway_service.py @@ -5,6 +5,10 @@ from pathlib import Path from types import SimpleNamespace import hermes_cli.gateway as gateway_cli +from gateway.restart import ( + DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT, + GATEWAY_SERVICE_RESTART_EXIT_CODE, +) class TestSystemdServiceRefresh: @@ -85,7 +89,7 @@ class TestGeneratedSystemdUnits: assert "ExecStart=" in unit assert "ExecStop=" not in unit assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit - assert "RestartForceExitStatus=75" in unit + assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit assert "TimeoutStopSec=60" in unit def test_user_unit_includes_resolved_node_directory_in_path(self, monkeypatch): @@ -101,7 +105,7 @@ class TestGeneratedSystemdUnits: assert "ExecStart=" in unit assert "ExecStop=" not in unit assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit - assert "RestartForceExitStatus=75" in unit + assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit assert "TimeoutStopSec=60" in unit assert "WantedBy=multi-user.target" in unit @@ -161,6 +165,31 @@ class TestGatewayStopCleanup: class TestLaunchdServiceRecovery: + def test_get_restart_drain_timeout_prefers_env_then_config_then_default(self, monkeypatch): + monkeypatch.delenv("HERMES_RESTART_DRAIN_TIMEOUT", raising=False) + monkeypatch.setattr(gateway_cli, "read_raw_config", lambda: {}) + + assert ( + gateway_cli._get_restart_drain_timeout() + == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT + ) + + monkeypatch.setattr( + gateway_cli, + "read_raw_config", + lambda: {"agent": {"restart_drain_timeout": 14}}, + ) + assert gateway_cli._get_restart_drain_timeout() == 14.0 + + monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "9") + assert gateway_cli._get_restart_drain_timeout() == 9.0 + + monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "invalid") + assert ( + gateway_cli._get_restart_drain_timeout() + == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT + ) + def test_launchd_install_repairs_outdated_plist_without_force(self, tmp_path, monkeypatch): plist_path = tmp_path / "ai.hermes.gateway.plist" plist_path.write_text("old content", encoding="utf-8")