mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gateway): address restart review feedback
This commit is contained in:
parent
a55c044ca8
commit
ecfae98152
8 changed files with 404 additions and 213 deletions
|
|
@ -673,6 +673,32 @@ class SendResult:
|
|||
retryable: bool = False # True for transient connection errors — base will retry automatically
|
||||
|
||||
|
||||
def merge_pending_message_event(
|
||||
pending_messages: Dict[str, MessageEvent],
|
||||
session_key: str,
|
||||
event: MessageEvent,
|
||||
) -> None:
|
||||
"""Store or merge a pending event for a session.
|
||||
|
||||
Photo bursts/albums often arrive as multiple near-simultaneous PHOTO
|
||||
events. Merge those into the existing queued event so the next turn sees
|
||||
the whole burst, while non-photo follow-ups still replace the pending
|
||||
event normally.
|
||||
"""
|
||||
existing = pending_messages.get(session_key)
|
||||
if (
|
||||
existing
|
||||
and getattr(existing, "message_type", None) == MessageType.PHOTO
|
||||
and event.message_type == MessageType.PHOTO
|
||||
):
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
|
||||
return
|
||||
pending_messages[session_key] = event
|
||||
|
||||
|
||||
# Error substrings that indicate a transient *connection* failure worth retrying.
|
||||
# "timeout" / "timed out" / "readtimeout" / "writetimeout" are intentionally
|
||||
# excluded: a read/write timeout on a non-idempotent call (e.g. send_message)
|
||||
|
|
@ -1432,14 +1458,7 @@ class BasePlatformAdapter(ABC):
|
|||
# then process them immediately after the current task finishes.
|
||||
if event.message_type == MessageType.PHOTO:
|
||||
logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key)
|
||||
existing = self._pending_messages.get(session_key)
|
||||
if existing and existing.message_type == MessageType.PHOTO:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
existing.text = self._merge_caption(existing.text, event.text)
|
||||
else:
|
||||
self._pending_messages[session_key] = event
|
||||
merge_pending_message_event(self._pending_messages, session_key, event)
|
||||
return # Don't interrupt now - will run after current task completes
|
||||
|
||||
# Default behavior for non-photo follow-ups: interrupt the running agent
|
||||
|
|
|
|||
20
gateway/restart.py
Normal file
20
gateway/restart.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
"""Shared gateway restart constants and parsing helpers."""
|
||||
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
|
||||
# EX_TEMPFAIL from sysexits.h — used to ask the service manager to restart
|
||||
# the gateway after a graceful drain/reload path completes.
|
||||
GATEWAY_SERVICE_RESTART_EXIT_CODE = 75
|
||||
|
||||
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT = float(
|
||||
DEFAULT_CONFIG["agent"]["restart_drain_timeout"]
|
||||
)
|
||||
|
||||
|
||||
def parse_restart_drain_timeout(raw: object) -> float:
|
||||
"""Parse a configured drain timeout, falling back to the shared default."""
|
||||
try:
|
||||
value = float(raw) if str(raw or "").strip() else DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
except (TypeError, ValueError):
|
||||
return DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
return max(0.0, value)
|
||||
|
|
@ -241,7 +241,17 @@ from gateway.session import (
|
|||
build_session_key,
|
||||
)
|
||||
from gateway.delivery import DeliveryRouter
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
merge_pending_message_event,
|
||||
)
|
||||
from gateway.restart import (
|
||||
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||
GATEWAY_SERVICE_RESTART_EXIT_CODE,
|
||||
parse_restart_drain_timeout,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_whatsapp_identifier(value: str) -> str:
|
||||
|
|
@ -478,7 +488,7 @@ class GatewayRunner:
|
|||
# blow up on attribute access.
|
||||
_running_agents_ts: Dict[str, float] = {}
|
||||
_busy_input_mode: str = "interrupt"
|
||||
_restart_drain_timeout: float = 60.0
|
||||
_restart_drain_timeout: float = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
_exit_code: Optional[int] = None
|
||||
_draining: bool = False
|
||||
_restart_requested: bool = False
|
||||
|
|
@ -486,6 +496,7 @@ class GatewayRunner:
|
|||
_restart_detached: bool = False
|
||||
_restart_via_service: bool = False
|
||||
_stop_task: Optional[asyncio.Task] = None
|
||||
_session_model_overrides: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
def __init__(self, config: Optional[GatewayConfig] = None):
|
||||
self.config = config or load_gateway_config()
|
||||
|
|
@ -1076,12 +1087,17 @@ class GatewayRunner:
|
|||
raw = str(cfg.get("agent", {}).get("restart_drain_timeout", "") or "").strip()
|
||||
except Exception:
|
||||
pass
|
||||
value = parse_restart_drain_timeout(raw)
|
||||
if raw and value == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT:
|
||||
try:
|
||||
value = float(raw) if raw else 60.0
|
||||
except ValueError:
|
||||
logger.warning("Invalid restart_drain_timeout '%s', using default 60s", raw)
|
||||
return 60.0
|
||||
return max(0.0, value)
|
||||
float(raw)
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
"Invalid restart_drain_timeout '%s', using default %.0fs",
|
||||
raw,
|
||||
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||
)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _load_background_notifications_mode() -> str:
|
||||
|
|
@ -1178,14 +1194,7 @@ class GatewayRunner:
|
|||
adapter = self.adapters.get(event.source.platform)
|
||||
if not adapter:
|
||||
return
|
||||
existing = adapter._pending_messages.get(session_key)
|
||||
if existing and getattr(existing, "message_type", None) == MessageType.PHOTO and event.message_type == MessageType.PHOTO:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
|
||||
return
|
||||
adapter._pending_messages[session_key] = event
|
||||
merge_pending_message_event(adapter._pending_messages, session_key, event)
|
||||
|
||||
async def _handle_active_session_busy_message(self, event: MessageEvent, session_key: str) -> bool:
|
||||
if not self._draining:
|
||||
|
|
@ -1212,20 +1221,32 @@ class GatewayRunner:
|
|||
|
||||
async def _drain_active_agents(self, timeout: float) -> tuple[Dict[str, Any], bool]:
|
||||
snapshot = self._snapshot_running_agents()
|
||||
if not self._running_agents:
|
||||
last_active_count = self._running_agent_count()
|
||||
last_status_at = 0.0
|
||||
|
||||
def _maybe_update_status(force: bool = False) -> None:
|
||||
nonlocal last_active_count, last_status_at
|
||||
now = asyncio.get_running_loop().time()
|
||||
active_count = self._running_agent_count()
|
||||
if force or active_count != last_active_count or (now - last_status_at) >= 1.0:
|
||||
self._update_runtime_status("draining")
|
||||
last_active_count = active_count
|
||||
last_status_at = now
|
||||
|
||||
if not self._running_agents:
|
||||
_maybe_update_status(force=True)
|
||||
return snapshot, False
|
||||
|
||||
self._update_runtime_status("draining")
|
||||
_maybe_update_status(force=True)
|
||||
if timeout <= 0:
|
||||
return snapshot, True
|
||||
|
||||
deadline = asyncio.get_running_loop().time() + timeout
|
||||
while self._running_agents and asyncio.get_running_loop().time() < deadline:
|
||||
self._update_runtime_status("draining")
|
||||
_maybe_update_status()
|
||||
await asyncio.sleep(0.1)
|
||||
timed_out = bool(self._running_agents)
|
||||
self._update_runtime_status("draining")
|
||||
_maybe_update_status(force=True)
|
||||
return snapshot, timed_out
|
||||
|
||||
def _interrupt_running_agents(self, reason: str) -> None:
|
||||
|
|
@ -1841,7 +1862,7 @@ class GatewayRunner:
|
|||
remove_pid_file()
|
||||
|
||||
if self._restart_requested and self._restart_via_service:
|
||||
self._exit_code = 75
|
||||
self._exit_code = GATEWAY_SERVICE_RESTART_EXIT_CODE
|
||||
self._exit_reason = self._exit_reason or "Gateway restart requested"
|
||||
|
||||
self._draining = False
|
||||
|
|
@ -2338,18 +2359,7 @@ class GatewayRunner:
|
|||
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
# Reuse adapter queue semantics so photo bursts merge cleanly.
|
||||
if _quick_key in adapter._pending_messages:
|
||||
existing = adapter._pending_messages[_quick_key]
|
||||
if getattr(existing, "message_type", None) == MessageType.PHOTO:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
|
||||
else:
|
||||
adapter._pending_messages[_quick_key] = event
|
||||
else:
|
||||
adapter._pending_messages[_quick_key] = event
|
||||
merge_pending_message_event(adapter._pending_messages, _quick_key, event)
|
||||
return None
|
||||
|
||||
running_agent = self._running_agents.get(_quick_key)
|
||||
|
|
@ -3951,7 +3961,7 @@ class GatewayRunner:
|
|||
# Check for session override
|
||||
source = event.source
|
||||
session_key = self._session_key_for_source(source)
|
||||
override = getattr(self, "_session_model_overrides", {}).get(session_key, {})
|
||||
override = self._session_model_overrides.get(session_key, {})
|
||||
if override:
|
||||
current_model = override.get("model", current_model)
|
||||
current_provider = override.get("provider", current_provider)
|
||||
|
|
@ -4033,8 +4043,6 @@ class GatewayRunner:
|
|||
f"via {result.provider_label or result.target_provider}. "
|
||||
f"Adjust your self-identification accordingly.]"
|
||||
)
|
||||
if not hasattr(_self, "_session_model_overrides"):
|
||||
_self._session_model_overrides = {}
|
||||
_self._session_model_overrides[_session_key] = {
|
||||
"model": result.new_model,
|
||||
"provider": result.target_provider,
|
||||
|
|
@ -4148,8 +4156,6 @@ class GatewayRunner:
|
|||
)
|
||||
|
||||
# Store session override so next agent creation uses the new model
|
||||
if not hasattr(self, "_session_model_overrides"):
|
||||
self._session_model_overrides = {}
|
||||
self._session_model_overrides[session_key] = {
|
||||
"model": result.new_model,
|
||||
"provider": result.target_provider,
|
||||
|
|
@ -6828,7 +6834,7 @@ class GatewayRunner:
|
|||
subsequent messages. Fields with ``None`` values are skipped so
|
||||
partial overrides don't clobber valid config defaults.
|
||||
"""
|
||||
override = getattr(self, "_session_model_overrides", {}).get(session_key)
|
||||
override = self._session_model_overrides.get(session_key)
|
||||
if not override:
|
||||
return model, runtime_kwargs
|
||||
model = override.get("model", model)
|
||||
|
|
@ -6840,7 +6846,7 @@ class GatewayRunner:
|
|||
|
||||
def _is_intentional_model_switch(self, session_key: str, agent_model: str) -> bool:
|
||||
"""Return True if *agent_model* matches an active /model session override."""
|
||||
override = getattr(self, "_session_model_overrides", {}).get(session_key)
|
||||
override = self._session_model_overrides.get(session_key)
|
||||
return override is not None and override.get("model") == agent_model
|
||||
|
||||
def _evict_cached_agent(self, session_key: str) -> None:
|
||||
|
|
|
|||
|
|
@ -15,8 +15,12 @@ from pathlib import Path
|
|||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
from gateway.status import terminate_pid
|
||||
from gateway.restart import (
|
||||
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||
GATEWAY_SERVICE_RESTART_EXIT_CODE,
|
||||
parse_restart_drain_timeout,
|
||||
)
|
||||
from hermes_cli.config import (
|
||||
DEFAULT_CONFIG,
|
||||
get_env_value,
|
||||
get_hermes_home,
|
||||
is_managed,
|
||||
|
|
@ -787,7 +791,7 @@ Environment="VIRTUAL_ENV={venv_dir}"
|
|||
Environment="HERMES_HOME={hermes_home}"
|
||||
Restart=on-failure
|
||||
RestartSec=30
|
||||
RestartForceExitStatus=75
|
||||
RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}
|
||||
KillMode=mixed
|
||||
KillSignal=SIGTERM
|
||||
ExecReload=/bin/kill -USR1 $MAINPID
|
||||
|
|
@ -819,7 +823,7 @@ Environment="VIRTUAL_ENV={venv_dir}"
|
|||
Environment="HERMES_HOME={hermes_home}"
|
||||
Restart=on-failure
|
||||
RestartSec=30
|
||||
RestartForceExitStatus=75
|
||||
RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}
|
||||
KillMode=mixed
|
||||
KillSignal=SIGTERM
|
||||
ExecReload=/bin/kill -USR1 $MAINPID
|
||||
|
|
@ -932,11 +936,12 @@ def _get_restart_drain_timeout() -> float:
|
|||
if not raw:
|
||||
cfg = read_raw_config()
|
||||
agent_cfg = cfg.get("agent", {}) if isinstance(cfg, dict) else {}
|
||||
raw = str(agent_cfg.get("restart_drain_timeout", DEFAULT_CONFIG["agent"]["restart_drain_timeout"]))
|
||||
try:
|
||||
return max(0.0, float(raw))
|
||||
except (TypeError, ValueError):
|
||||
return float(DEFAULT_CONFIG["agent"]["restart_drain_timeout"])
|
||||
raw = str(
|
||||
agent_cfg.get(
|
||||
"restart_drain_timeout", DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
)
|
||||
)
|
||||
return parse_restart_drain_timeout(raw)
|
||||
|
||||
|
||||
def systemd_install(force: bool = False, system: bool = False, run_as_user: str | None = None):
|
||||
|
|
|
|||
110
tests/gateway/restart_test_helpers.py
Normal file
110
tests/gateway/restart_test_helpers.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
class RestartTestAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
||||
self.sent: list[str] = []
|
||||
|
||||
async def connect(self):
|
||||
return True
|
||||
|
||||
async def disconnect(self):
|
||||
return None
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
self.sent.append(content)
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None):
|
||||
return None
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
def make_restart_source(chat_id: str = "123456", chat_type: str = "dm") -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
chat_type=chat_type,
|
||||
)
|
||||
|
||||
|
||||
def make_restart_runner(
|
||||
adapter: BasePlatformAdapter | None = None,
|
||||
) -> tuple[GatewayRunner, BasePlatformAdapter]:
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
runner._running = True
|
||||
runner._shutdown_event = asyncio.Event()
|
||||
runner._exit_reason = None
|
||||
runner._exit_code = None
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._pending_model_notes = {}
|
||||
runner._background_tasks = set()
|
||||
runner._draining = False
|
||||
runner._restart_requested = False
|
||||
runner._restart_task_started = False
|
||||
runner._restart_detached = False
|
||||
runner._restart_via_service = False
|
||||
runner._restart_drain_timeout = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
runner._stop_task = None
|
||||
runner._busy_input_mode = "interrupt"
|
||||
runner._update_prompt_pending = {}
|
||||
runner._voice_mode = {}
|
||||
runner._session_model_overrides = {}
|
||||
runner._shutdown_all_gateway_honcho = lambda: None
|
||||
runner._update_runtime_status = MagicMock()
|
||||
runner._queue_or_replace_pending_event = GatewayRunner._queue_or_replace_pending_event.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._session_key_for_source = GatewayRunner._session_key_for_source.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._handle_active_session_busy_message = (
|
||||
GatewayRunner._handle_active_session_busy_message.__get__(runner, GatewayRunner)
|
||||
)
|
||||
runner._handle_restart_command = GatewayRunner._handle_restart_command.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._status_action_label = GatewayRunner._status_action_label.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._status_action_gerund = GatewayRunner._status_action_gerund.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._queue_during_drain_enabled = GatewayRunner._queue_during_drain_enabled.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._running_agent_count = GatewayRunner._running_agent_count.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._launch_detached_restart_command = GatewayRunner._launch_detached_restart_command.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner.request_restart = GatewayRunner.request_restart.__get__(runner, GatewayRunner)
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
runner.pairing_store = MagicMock()
|
||||
runner.session_store = MagicMock()
|
||||
runner.delivery_router = MagicMock()
|
||||
|
||||
platform_adapter = adapter or RestartTestAdapter()
|
||||
platform_adapter.set_message_handler(AsyncMock(return_value=None))
|
||||
platform_adapter.set_busy_session_handler(runner._handle_active_session_busy_message)
|
||||
runner.adapters = {Platform.TELEGRAM: platform_adapter}
|
||||
return runner, platform_adapter
|
||||
|
|
@ -3,67 +3,15 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class StubAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self):
|
||||
return True
|
||||
|
||||
async def disconnect(self):
|
||||
return None
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None):
|
||||
return None
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
def _source(chat_id="123456", chat_type="dm"):
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
chat_type=chat_type,
|
||||
)
|
||||
|
||||
|
||||
def _make_runner() -> GatewayRunner:
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
|
||||
runner._running = True
|
||||
runner._shutdown_event = asyncio.Event()
|
||||
runner._exit_reason = None
|
||||
runner._exit_code = None
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._background_tasks = set()
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._draining = False
|
||||
runner._restart_requested = False
|
||||
runner._restart_task_started = False
|
||||
runner._restart_detached = False
|
||||
runner._restart_via_service = False
|
||||
runner._restart_drain_timeout = 60.0
|
||||
runner._stop_task = None
|
||||
runner._shutdown_all_gateway_honcho = lambda: None
|
||||
runner._update_runtime_status = MagicMock()
|
||||
return runner
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.restart import GATEWAY_SERVICE_RESTART_EXIT_CODE
|
||||
from gateway.session import build_session_key
|
||||
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
||||
adapter = StubAdapter()
|
||||
_runner, adapter = make_restart_runner()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def block_forever(_event):
|
||||
|
|
@ -71,7 +19,7 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
|||
return None
|
||||
|
||||
adapter.set_message_handler(block_forever)
|
||||
event = MessageEvent(text="work", source=_source(), message_id="1")
|
||||
event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
|
||||
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0)
|
||||
|
|
@ -89,12 +37,11 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
|
||||
runner = _make_runner()
|
||||
runner, adapter = make_restart_runner()
|
||||
runner._pending_messages = {"session": "pending text"}
|
||||
runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
|
||||
runner._restart_drain_timeout = 0.0
|
||||
|
||||
adapter = StubAdapter()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def block_forever(_event):
|
||||
|
|
@ -102,7 +49,7 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
|
|||
return None
|
||||
|
||||
adapter.set_message_handler(block_forever)
|
||||
event = MessageEvent(text="work", source=_source(), message_id="1")
|
||||
event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
|
@ -112,7 +59,6 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
|
|||
session_key = build_session_key(event.source)
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents = {session_key: running_agent}
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
|
||||
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
||||
await runner.stop()
|
||||
|
|
@ -128,11 +74,9 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_stop_drains_running_agents_before_disconnect():
|
||||
runner = _make_runner()
|
||||
adapter = StubAdapter()
|
||||
runner, adapter = make_restart_runner()
|
||||
disconnect_mock = AsyncMock()
|
||||
adapter.disconnect = disconnect_mock
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents = {"session": running_agent}
|
||||
|
|
@ -153,13 +97,11 @@ async def test_gateway_stop_drains_running_agents_before_disconnect():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_stop_interrupts_after_drain_timeout():
|
||||
runner = _make_runner()
|
||||
runner, adapter = make_restart_runner()
|
||||
runner._restart_drain_timeout = 0.05
|
||||
|
||||
adapter = StubAdapter()
|
||||
disconnect_mock = AsyncMock()
|
||||
adapter.disconnect = disconnect_mock
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents = {"session": running_agent}
|
||||
|
|
@ -170,3 +112,36 @@ async def test_gateway_stop_interrupts_after_drain_timeout():
|
|||
running_agent.interrupt.assert_called_once_with("Gateway shutting down")
|
||||
disconnect_mock.assert_awaited_once()
|
||||
assert runner._shutdown_event.is_set() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_stop_service_restart_sets_named_exit_code():
|
||||
runner, adapter = make_restart_runner()
|
||||
adapter.disconnect = AsyncMock()
|
||||
|
||||
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
||||
await runner.stop(restart=True, service_restart=True)
|
||||
|
||||
assert runner._exit_code == GATEWAY_SERVICE_RESTART_EXIT_CODE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_active_agents_throttles_status_updates():
|
||||
runner, _adapter = make_restart_runner()
|
||||
runner._update_runtime_status = MagicMock()
|
||||
|
||||
runner._running_agents = {"a": MagicMock(), "b": MagicMock()}
|
||||
|
||||
async def finish_agents():
|
||||
await asyncio.sleep(0.12)
|
||||
runner._running_agents.pop("a")
|
||||
await asyncio.sleep(0.12)
|
||||
runner._running_agents.clear()
|
||||
|
||||
task = asyncio.create_task(finish_agents())
|
||||
await runner._drain_active_agents(1.0)
|
||||
await task
|
||||
|
||||
# Start, one count-change update, and final update. Allow one extra update
|
||||
# if the loop observes the zero-agent state before exiting.
|
||||
assert 3 <= runner._update_runtime_status.call_count <= 4
|
||||
|
|
|
|||
|
|
@ -1,95 +1,27 @@
|
|||
import asyncio
|
||||
import shutil
|
||||
import subprocess
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class RecordingAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
||||
self.sent: list[str] = []
|
||||
|
||||
async def connect(self):
|
||||
return True
|
||||
|
||||
async def disconnect(self):
|
||||
return None
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
self.sent.append(content)
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None):
|
||||
return None
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
def _source(chat_id="123456"):
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def _make_runner() -> tuple[GatewayRunner, RecordingAdapter]:
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
|
||||
runner.adapters = {}
|
||||
runner._running = True
|
||||
runner._shutdown_event = asyncio.Event()
|
||||
runner._exit_reason = None
|
||||
runner._exit_code = None
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._background_tasks = set()
|
||||
runner._draining = False
|
||||
runner._restart_requested = False
|
||||
runner._restart_task_started = False
|
||||
runner._restart_detached = False
|
||||
runner._restart_via_service = False
|
||||
runner._restart_drain_timeout = 60.0
|
||||
runner._stop_task = None
|
||||
runner._busy_input_mode = "interrupt"
|
||||
runner._update_prompt_pending = {}
|
||||
runner._voice_mode = {}
|
||||
runner._update_runtime_status = MagicMock()
|
||||
runner._queue_or_replace_pending_event = GatewayRunner._queue_or_replace_pending_event.__get__(runner, GatewayRunner)
|
||||
runner._session_key_for_source = GatewayRunner._session_key_for_source.__get__(runner, GatewayRunner)
|
||||
runner._handle_active_session_busy_message = GatewayRunner._handle_active_session_busy_message.__get__(runner, GatewayRunner)
|
||||
runner._handle_restart_command = GatewayRunner._handle_restart_command.__get__(runner, GatewayRunner)
|
||||
runner._status_action_label = GatewayRunner._status_action_label.__get__(runner, GatewayRunner)
|
||||
runner._status_action_gerund = GatewayRunner._status_action_gerund.__get__(runner, GatewayRunner)
|
||||
runner._queue_during_drain_enabled = GatewayRunner._queue_during_drain_enabled.__get__(runner, GatewayRunner)
|
||||
runner._running_agent_count = GatewayRunner._running_agent_count.__get__(runner, GatewayRunner)
|
||||
runner.request_restart = MagicMock(return_value=True)
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
runner.pairing_store = MagicMock()
|
||||
runner.session_store = MagicMock()
|
||||
runner.delivery_router = MagicMock()
|
||||
|
||||
adapter = RecordingAdapter()
|
||||
adapter.set_message_handler(AsyncMock(return_value=None))
|
||||
adapter.set_busy_session_handler(runner._handle_active_session_busy_message)
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
return runner, adapter
|
||||
import gateway.run as gateway_run
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
from gateway.session import build_session_key
|
||||
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restart_command_while_busy_requests_drain_without_interrupt():
|
||||
runner, _adapter = _make_runner()
|
||||
event = MessageEvent(text="/restart", message_type=MessageType.TEXT, source=_source(), message_id="m1")
|
||||
runner, _adapter = make_restart_runner()
|
||||
runner.request_restart = MagicMock(return_value=True)
|
||||
event = MessageEvent(
|
||||
text="/restart",
|
||||
message_type=MessageType.TEXT,
|
||||
source=make_restart_source(),
|
||||
message_id="m1",
|
||||
)
|
||||
session_key = build_session_key(event.source)
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents[session_key] = running_agent
|
||||
|
|
@ -103,12 +35,17 @@ async def test_restart_command_while_busy_requests_drain_without_interrupt():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_queue_mode_queues_follow_up_without_interrupt():
|
||||
runner, adapter = _make_runner()
|
||||
runner, adapter = make_restart_runner()
|
||||
runner._draining = True
|
||||
runner._restart_requested = True
|
||||
runner._busy_input_mode = "queue"
|
||||
|
||||
event = MessageEvent(text="follow up", message_type=MessageType.TEXT, source=_source(), message_id="m2")
|
||||
event = MessageEvent(
|
||||
text="follow up",
|
||||
message_type=MessageType.TEXT,
|
||||
source=make_restart_source(),
|
||||
message_id="m2",
|
||||
)
|
||||
session_key = build_session_key(event.source)
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
|
|
@ -122,12 +59,102 @@ async def test_drain_queue_mode_queues_follow_up_without_interrupt():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draining_rejects_new_session_messages():
|
||||
runner, _adapter = _make_runner()
|
||||
runner, _adapter = make_restart_runner()
|
||||
runner._draining = True
|
||||
runner._restart_requested = True
|
||||
|
||||
event = MessageEvent(text="hello", message_type=MessageType.TEXT, source=_source("fresh"), message_id="m3")
|
||||
event = MessageEvent(
|
||||
text="hello",
|
||||
message_type=MessageType.TEXT,
|
||||
source=make_restart_source("fresh"),
|
||||
message_id="m3",
|
||||
)
|
||||
|
||||
result = await runner._handle_message(event)
|
||||
|
||||
assert result == "⏳ Gateway is restarting and is not accepting new work right now."
|
||||
|
||||
|
||||
def test_load_busy_input_mode_prefers_env_then_config_then_default(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.delenv("HERMES_GATEWAY_BUSY_INPUT_MODE", raising=False)
|
||||
|
||||
assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt"
|
||||
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"display:\n busy_input_mode: queue\n", encoding="utf-8"
|
||||
)
|
||||
assert gateway_run.GatewayRunner._load_busy_input_mode() == "queue"
|
||||
|
||||
monkeypatch.setenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "interrupt")
|
||||
assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt"
|
||||
|
||||
|
||||
def test_load_restart_drain_timeout_prefers_env_then_config_then_default(
|
||||
tmp_path, monkeypatch, caplog
|
||||
):
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.delenv("HERMES_RESTART_DRAIN_TIMEOUT", raising=False)
|
||||
|
||||
assert (
|
||||
gateway_run.GatewayRunner._load_restart_drain_timeout()
|
||||
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
)
|
||||
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"agent:\n restart_drain_timeout: 12\n", encoding="utf-8"
|
||||
)
|
||||
assert gateway_run.GatewayRunner._load_restart_drain_timeout() == 12.0
|
||||
|
||||
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "7")
|
||||
assert gateway_run.GatewayRunner._load_restart_drain_timeout() == 7.0
|
||||
|
||||
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "invalid")
|
||||
assert (
|
||||
gateway_run.GatewayRunner._load_restart_drain_timeout()
|
||||
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
)
|
||||
assert "Invalid restart_drain_timeout" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_restart_is_idempotent():
|
||||
runner, _adapter = make_restart_runner()
|
||||
runner.stop = AsyncMock()
|
||||
|
||||
assert runner.request_restart(detached=True, via_service=False) is True
|
||||
first_task = next(iter(runner._background_tasks))
|
||||
assert runner.request_restart(detached=True, via_service=False) is False
|
||||
|
||||
await first_task
|
||||
|
||||
runner.stop.assert_awaited_once_with(
|
||||
restart=True, detached_restart=True, service_restart=False
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_launch_detached_restart_command_uses_setsid(monkeypatch):
|
||||
runner, _adapter = make_restart_runner()
|
||||
popen_calls = []
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_resolve_hermes_bin", lambda: ["/usr/bin/hermes"])
|
||||
monkeypatch.setattr(gateway_run.os, "getpid", lambda: 321)
|
||||
monkeypatch.setattr(shutil, "which", lambda cmd: "/usr/bin/setsid" if cmd == "setsid" else None)
|
||||
|
||||
def fake_popen(cmd, **kwargs):
|
||||
popen_calls.append((cmd, kwargs))
|
||||
return MagicMock()
|
||||
|
||||
monkeypatch.setattr(subprocess, "Popen", fake_popen)
|
||||
|
||||
await runner._launch_detached_restart_command()
|
||||
|
||||
assert len(popen_calls) == 1
|
||||
cmd, kwargs = popen_calls[0]
|
||||
assert cmd[:2] == ["/usr/bin/setsid", "bash"]
|
||||
assert "gateway restart" in cmd[-1]
|
||||
assert "kill -0 321" in cmd[-1]
|
||||
assert kwargs["start_new_session"] is True
|
||||
assert kwargs["stdout"] is subprocess.DEVNULL
|
||||
assert kwargs["stderr"] is subprocess.DEVNULL
|
||||
|
|
|
|||
|
|
@ -5,6 +5,10 @@ from pathlib import Path
|
|||
from types import SimpleNamespace
|
||||
|
||||
import hermes_cli.gateway as gateway_cli
|
||||
from gateway.restart import (
|
||||
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||
GATEWAY_SERVICE_RESTART_EXIT_CODE,
|
||||
)
|
||||
|
||||
|
||||
class TestSystemdServiceRefresh:
|
||||
|
|
@ -85,7 +89,7 @@ class TestGeneratedSystemdUnits:
|
|||
assert "ExecStart=" in unit
|
||||
assert "ExecStop=" not in unit
|
||||
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
|
||||
assert "RestartForceExitStatus=75" in unit
|
||||
assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit
|
||||
assert "TimeoutStopSec=60" in unit
|
||||
|
||||
def test_user_unit_includes_resolved_node_directory_in_path(self, monkeypatch):
|
||||
|
|
@ -101,7 +105,7 @@ class TestGeneratedSystemdUnits:
|
|||
assert "ExecStart=" in unit
|
||||
assert "ExecStop=" not in unit
|
||||
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
|
||||
assert "RestartForceExitStatus=75" in unit
|
||||
assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit
|
||||
assert "TimeoutStopSec=60" in unit
|
||||
assert "WantedBy=multi-user.target" in unit
|
||||
|
||||
|
|
@ -161,6 +165,31 @@ class TestGatewayStopCleanup:
|
|||
|
||||
|
||||
class TestLaunchdServiceRecovery:
|
||||
def test_get_restart_drain_timeout_prefers_env_then_config_then_default(self, monkeypatch):
|
||||
monkeypatch.delenv("HERMES_RESTART_DRAIN_TIMEOUT", raising=False)
|
||||
monkeypatch.setattr(gateway_cli, "read_raw_config", lambda: {})
|
||||
|
||||
assert (
|
||||
gateway_cli._get_restart_drain_timeout()
|
||||
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"read_raw_config",
|
||||
lambda: {"agent": {"restart_drain_timeout": 14}},
|
||||
)
|
||||
assert gateway_cli._get_restart_drain_timeout() == 14.0
|
||||
|
||||
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "9")
|
||||
assert gateway_cli._get_restart_drain_timeout() == 9.0
|
||||
|
||||
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "invalid")
|
||||
assert (
|
||||
gateway_cli._get_restart_drain_timeout()
|
||||
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
)
|
||||
|
||||
def test_launchd_install_repairs_outdated_plist_without_force(self, tmp_path, monkeypatch):
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
plist_path.write_text("<plist>old content</plist>", encoding="utf-8")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue