mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gateway): address restart review feedback
This commit is contained in:
parent
a55c044ca8
commit
ecfae98152
8 changed files with 404 additions and 213 deletions
|
|
@ -673,6 +673,32 @@ class SendResult:
|
||||||
retryable: bool = False # True for transient connection errors — base will retry automatically
|
retryable: bool = False # True for transient connection errors — base will retry automatically
|
||||||
|
|
||||||
|
|
||||||
|
def merge_pending_message_event(
|
||||||
|
pending_messages: Dict[str, MessageEvent],
|
||||||
|
session_key: str,
|
||||||
|
event: MessageEvent,
|
||||||
|
) -> None:
|
||||||
|
"""Store or merge a pending event for a session.
|
||||||
|
|
||||||
|
Photo bursts/albums often arrive as multiple near-simultaneous PHOTO
|
||||||
|
events. Merge those into the existing queued event so the next turn sees
|
||||||
|
the whole burst, while non-photo follow-ups still replace the pending
|
||||||
|
event normally.
|
||||||
|
"""
|
||||||
|
existing = pending_messages.get(session_key)
|
||||||
|
if (
|
||||||
|
existing
|
||||||
|
and getattr(existing, "message_type", None) == MessageType.PHOTO
|
||||||
|
and event.message_type == MessageType.PHOTO
|
||||||
|
):
|
||||||
|
existing.media_urls.extend(event.media_urls)
|
||||||
|
existing.media_types.extend(event.media_types)
|
||||||
|
if event.text:
|
||||||
|
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
|
||||||
|
return
|
||||||
|
pending_messages[session_key] = event
|
||||||
|
|
||||||
|
|
||||||
# Error substrings that indicate a transient *connection* failure worth retrying.
|
# Error substrings that indicate a transient *connection* failure worth retrying.
|
||||||
# "timeout" / "timed out" / "readtimeout" / "writetimeout" are intentionally
|
# "timeout" / "timed out" / "readtimeout" / "writetimeout" are intentionally
|
||||||
# excluded: a read/write timeout on a non-idempotent call (e.g. send_message)
|
# excluded: a read/write timeout on a non-idempotent call (e.g. send_message)
|
||||||
|
|
@ -1432,14 +1458,7 @@ class BasePlatformAdapter(ABC):
|
||||||
# then process them immediately after the current task finishes.
|
# then process them immediately after the current task finishes.
|
||||||
if event.message_type == MessageType.PHOTO:
|
if event.message_type == MessageType.PHOTO:
|
||||||
logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key)
|
logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key)
|
||||||
existing = self._pending_messages.get(session_key)
|
merge_pending_message_event(self._pending_messages, session_key, event)
|
||||||
if existing and existing.message_type == MessageType.PHOTO:
|
|
||||||
existing.media_urls.extend(event.media_urls)
|
|
||||||
existing.media_types.extend(event.media_types)
|
|
||||||
if event.text:
|
|
||||||
existing.text = self._merge_caption(existing.text, event.text)
|
|
||||||
else:
|
|
||||||
self._pending_messages[session_key] = event
|
|
||||||
return # Don't interrupt now - will run after current task completes
|
return # Don't interrupt now - will run after current task completes
|
||||||
|
|
||||||
# Default behavior for non-photo follow-ups: interrupt the running agent
|
# Default behavior for non-photo follow-ups: interrupt the running agent
|
||||||
|
|
|
||||||
20
gateway/restart.py
Normal file
20
gateway/restart.py
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
"""Shared gateway restart constants and parsing helpers."""
|
||||||
|
|
||||||
|
from hermes_cli.config import DEFAULT_CONFIG
|
||||||
|
|
||||||
|
# EX_TEMPFAIL from sysexits.h — used to ask the service manager to restart
|
||||||
|
# the gateway after a graceful drain/reload path completes.
|
||||||
|
GATEWAY_SERVICE_RESTART_EXIT_CODE = 75
|
||||||
|
|
||||||
|
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT = float(
|
||||||
|
DEFAULT_CONFIG["agent"]["restart_drain_timeout"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_restart_drain_timeout(raw: object) -> float:
|
||||||
|
"""Parse a configured drain timeout, falling back to the shared default."""
|
||||||
|
try:
|
||||||
|
value = float(raw) if str(raw or "").strip() else DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
|
return max(0.0, value)
|
||||||
|
|
@ -241,7 +241,17 @@ from gateway.session import (
|
||||||
build_session_key,
|
build_session_key,
|
||||||
)
|
)
|
||||||
from gateway.delivery import DeliveryRouter
|
from gateway.delivery import DeliveryRouter
|
||||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
|
from gateway.platforms.base import (
|
||||||
|
BasePlatformAdapter,
|
||||||
|
MessageEvent,
|
||||||
|
MessageType,
|
||||||
|
merge_pending_message_event,
|
||||||
|
)
|
||||||
|
from gateway.restart import (
|
||||||
|
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||||
|
GATEWAY_SERVICE_RESTART_EXIT_CODE,
|
||||||
|
parse_restart_drain_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _normalize_whatsapp_identifier(value: str) -> str:
|
def _normalize_whatsapp_identifier(value: str) -> str:
|
||||||
|
|
@ -478,7 +488,7 @@ class GatewayRunner:
|
||||||
# blow up on attribute access.
|
# blow up on attribute access.
|
||||||
_running_agents_ts: Dict[str, float] = {}
|
_running_agents_ts: Dict[str, float] = {}
|
||||||
_busy_input_mode: str = "interrupt"
|
_busy_input_mode: str = "interrupt"
|
||||||
_restart_drain_timeout: float = 60.0
|
_restart_drain_timeout: float = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
_exit_code: Optional[int] = None
|
_exit_code: Optional[int] = None
|
||||||
_draining: bool = False
|
_draining: bool = False
|
||||||
_restart_requested: bool = False
|
_restart_requested: bool = False
|
||||||
|
|
@ -486,6 +496,7 @@ class GatewayRunner:
|
||||||
_restart_detached: bool = False
|
_restart_detached: bool = False
|
||||||
_restart_via_service: bool = False
|
_restart_via_service: bool = False
|
||||||
_stop_task: Optional[asyncio.Task] = None
|
_stop_task: Optional[asyncio.Task] = None
|
||||||
|
_session_model_overrides: Dict[str, Dict[str, str]] = {}
|
||||||
|
|
||||||
def __init__(self, config: Optional[GatewayConfig] = None):
|
def __init__(self, config: Optional[GatewayConfig] = None):
|
||||||
self.config = config or load_gateway_config()
|
self.config = config or load_gateway_config()
|
||||||
|
|
@ -1076,12 +1087,17 @@ class GatewayRunner:
|
||||||
raw = str(cfg.get("agent", {}).get("restart_drain_timeout", "") or "").strip()
|
raw = str(cfg.get("agent", {}).get("restart_drain_timeout", "") or "").strip()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
value = parse_restart_drain_timeout(raw)
|
||||||
value = float(raw) if raw else 60.0
|
if raw and value == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT:
|
||||||
except ValueError:
|
try:
|
||||||
logger.warning("Invalid restart_drain_timeout '%s', using default 60s", raw)
|
float(raw)
|
||||||
return 60.0
|
except (TypeError, ValueError):
|
||||||
return max(0.0, value)
|
logger.warning(
|
||||||
|
"Invalid restart_drain_timeout '%s', using default %.0fs",
|
||||||
|
raw,
|
||||||
|
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_background_notifications_mode() -> str:
|
def _load_background_notifications_mode() -> str:
|
||||||
|
|
@ -1178,14 +1194,7 @@ class GatewayRunner:
|
||||||
adapter = self.adapters.get(event.source.platform)
|
adapter = self.adapters.get(event.source.platform)
|
||||||
if not adapter:
|
if not adapter:
|
||||||
return
|
return
|
||||||
existing = adapter._pending_messages.get(session_key)
|
merge_pending_message_event(adapter._pending_messages, session_key, event)
|
||||||
if existing and getattr(existing, "message_type", None) == MessageType.PHOTO and event.message_type == MessageType.PHOTO:
|
|
||||||
existing.media_urls.extend(event.media_urls)
|
|
||||||
existing.media_types.extend(event.media_types)
|
|
||||||
if event.text:
|
|
||||||
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
|
|
||||||
return
|
|
||||||
adapter._pending_messages[session_key] = event
|
|
||||||
|
|
||||||
async def _handle_active_session_busy_message(self, event: MessageEvent, session_key: str) -> bool:
|
async def _handle_active_session_busy_message(self, event: MessageEvent, session_key: str) -> bool:
|
||||||
if not self._draining:
|
if not self._draining:
|
||||||
|
|
@ -1212,20 +1221,32 @@ class GatewayRunner:
|
||||||
|
|
||||||
async def _drain_active_agents(self, timeout: float) -> tuple[Dict[str, Any], bool]:
|
async def _drain_active_agents(self, timeout: float) -> tuple[Dict[str, Any], bool]:
|
||||||
snapshot = self._snapshot_running_agents()
|
snapshot = self._snapshot_running_agents()
|
||||||
|
last_active_count = self._running_agent_count()
|
||||||
|
last_status_at = 0.0
|
||||||
|
|
||||||
|
def _maybe_update_status(force: bool = False) -> None:
|
||||||
|
nonlocal last_active_count, last_status_at
|
||||||
|
now = asyncio.get_running_loop().time()
|
||||||
|
active_count = self._running_agent_count()
|
||||||
|
if force or active_count != last_active_count or (now - last_status_at) >= 1.0:
|
||||||
|
self._update_runtime_status("draining")
|
||||||
|
last_active_count = active_count
|
||||||
|
last_status_at = now
|
||||||
|
|
||||||
if not self._running_agents:
|
if not self._running_agents:
|
||||||
self._update_runtime_status("draining")
|
_maybe_update_status(force=True)
|
||||||
return snapshot, False
|
return snapshot, False
|
||||||
|
|
||||||
self._update_runtime_status("draining")
|
_maybe_update_status(force=True)
|
||||||
if timeout <= 0:
|
if timeout <= 0:
|
||||||
return snapshot, True
|
return snapshot, True
|
||||||
|
|
||||||
deadline = asyncio.get_running_loop().time() + timeout
|
deadline = asyncio.get_running_loop().time() + timeout
|
||||||
while self._running_agents and asyncio.get_running_loop().time() < deadline:
|
while self._running_agents and asyncio.get_running_loop().time() < deadline:
|
||||||
self._update_runtime_status("draining")
|
_maybe_update_status()
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
timed_out = bool(self._running_agents)
|
timed_out = bool(self._running_agents)
|
||||||
self._update_runtime_status("draining")
|
_maybe_update_status(force=True)
|
||||||
return snapshot, timed_out
|
return snapshot, timed_out
|
||||||
|
|
||||||
def _interrupt_running_agents(self, reason: str) -> None:
|
def _interrupt_running_agents(self, reason: str) -> None:
|
||||||
|
|
@ -1841,7 +1862,7 @@ class GatewayRunner:
|
||||||
remove_pid_file()
|
remove_pid_file()
|
||||||
|
|
||||||
if self._restart_requested and self._restart_via_service:
|
if self._restart_requested and self._restart_via_service:
|
||||||
self._exit_code = 75
|
self._exit_code = GATEWAY_SERVICE_RESTART_EXIT_CODE
|
||||||
self._exit_reason = self._exit_reason or "Gateway restart requested"
|
self._exit_reason = self._exit_reason or "Gateway restart requested"
|
||||||
|
|
||||||
self._draining = False
|
self._draining = False
|
||||||
|
|
@ -2338,18 +2359,7 @@ class GatewayRunner:
|
||||||
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
||||||
adapter = self.adapters.get(source.platform)
|
adapter = self.adapters.get(source.platform)
|
||||||
if adapter:
|
if adapter:
|
||||||
# Reuse adapter queue semantics so photo bursts merge cleanly.
|
merge_pending_message_event(adapter._pending_messages, _quick_key, event)
|
||||||
if _quick_key in adapter._pending_messages:
|
|
||||||
existing = adapter._pending_messages[_quick_key]
|
|
||||||
if getattr(existing, "message_type", None) == MessageType.PHOTO:
|
|
||||||
existing.media_urls.extend(event.media_urls)
|
|
||||||
existing.media_types.extend(event.media_types)
|
|
||||||
if event.text:
|
|
||||||
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
|
|
||||||
else:
|
|
||||||
adapter._pending_messages[_quick_key] = event
|
|
||||||
else:
|
|
||||||
adapter._pending_messages[_quick_key] = event
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
running_agent = self._running_agents.get(_quick_key)
|
running_agent = self._running_agents.get(_quick_key)
|
||||||
|
|
@ -3951,7 +3961,7 @@ class GatewayRunner:
|
||||||
# Check for session override
|
# Check for session override
|
||||||
source = event.source
|
source = event.source
|
||||||
session_key = self._session_key_for_source(source)
|
session_key = self._session_key_for_source(source)
|
||||||
override = getattr(self, "_session_model_overrides", {}).get(session_key, {})
|
override = self._session_model_overrides.get(session_key, {})
|
||||||
if override:
|
if override:
|
||||||
current_model = override.get("model", current_model)
|
current_model = override.get("model", current_model)
|
||||||
current_provider = override.get("provider", current_provider)
|
current_provider = override.get("provider", current_provider)
|
||||||
|
|
@ -4033,8 +4043,6 @@ class GatewayRunner:
|
||||||
f"via {result.provider_label or result.target_provider}. "
|
f"via {result.provider_label or result.target_provider}. "
|
||||||
f"Adjust your self-identification accordingly.]"
|
f"Adjust your self-identification accordingly.]"
|
||||||
)
|
)
|
||||||
if not hasattr(_self, "_session_model_overrides"):
|
|
||||||
_self._session_model_overrides = {}
|
|
||||||
_self._session_model_overrides[_session_key] = {
|
_self._session_model_overrides[_session_key] = {
|
||||||
"model": result.new_model,
|
"model": result.new_model,
|
||||||
"provider": result.target_provider,
|
"provider": result.target_provider,
|
||||||
|
|
@ -4148,8 +4156,6 @@ class GatewayRunner:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store session override so next agent creation uses the new model
|
# Store session override so next agent creation uses the new model
|
||||||
if not hasattr(self, "_session_model_overrides"):
|
|
||||||
self._session_model_overrides = {}
|
|
||||||
self._session_model_overrides[session_key] = {
|
self._session_model_overrides[session_key] = {
|
||||||
"model": result.new_model,
|
"model": result.new_model,
|
||||||
"provider": result.target_provider,
|
"provider": result.target_provider,
|
||||||
|
|
@ -6828,7 +6834,7 @@ class GatewayRunner:
|
||||||
subsequent messages. Fields with ``None`` values are skipped so
|
subsequent messages. Fields with ``None`` values are skipped so
|
||||||
partial overrides don't clobber valid config defaults.
|
partial overrides don't clobber valid config defaults.
|
||||||
"""
|
"""
|
||||||
override = getattr(self, "_session_model_overrides", {}).get(session_key)
|
override = self._session_model_overrides.get(session_key)
|
||||||
if not override:
|
if not override:
|
||||||
return model, runtime_kwargs
|
return model, runtime_kwargs
|
||||||
model = override.get("model", model)
|
model = override.get("model", model)
|
||||||
|
|
@ -6840,7 +6846,7 @@ class GatewayRunner:
|
||||||
|
|
||||||
def _is_intentional_model_switch(self, session_key: str, agent_model: str) -> bool:
|
def _is_intentional_model_switch(self, session_key: str, agent_model: str) -> bool:
|
||||||
"""Return True if *agent_model* matches an active /model session override."""
|
"""Return True if *agent_model* matches an active /model session override."""
|
||||||
override = getattr(self, "_session_model_overrides", {}).get(session_key)
|
override = self._session_model_overrides.get(session_key)
|
||||||
return override is not None and override.get("model") == agent_model
|
return override is not None and override.get("model") == agent_model
|
||||||
|
|
||||||
def _evict_cached_agent(self, session_key: str) -> None:
|
def _evict_cached_agent(self, session_key: str) -> None:
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,12 @@ from pathlib import Path
|
||||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||||
|
|
||||||
from gateway.status import terminate_pid
|
from gateway.status import terminate_pid
|
||||||
|
from gateway.restart import (
|
||||||
|
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||||
|
GATEWAY_SERVICE_RESTART_EXIT_CODE,
|
||||||
|
parse_restart_drain_timeout,
|
||||||
|
)
|
||||||
from hermes_cli.config import (
|
from hermes_cli.config import (
|
||||||
DEFAULT_CONFIG,
|
|
||||||
get_env_value,
|
get_env_value,
|
||||||
get_hermes_home,
|
get_hermes_home,
|
||||||
is_managed,
|
is_managed,
|
||||||
|
|
@ -787,7 +791,7 @@ Environment="VIRTUAL_ENV={venv_dir}"
|
||||||
Environment="HERMES_HOME={hermes_home}"
|
Environment="HERMES_HOME={hermes_home}"
|
||||||
Restart=on-failure
|
Restart=on-failure
|
||||||
RestartSec=30
|
RestartSec=30
|
||||||
RestartForceExitStatus=75
|
RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}
|
||||||
KillMode=mixed
|
KillMode=mixed
|
||||||
KillSignal=SIGTERM
|
KillSignal=SIGTERM
|
||||||
ExecReload=/bin/kill -USR1 $MAINPID
|
ExecReload=/bin/kill -USR1 $MAINPID
|
||||||
|
|
@ -819,7 +823,7 @@ Environment="VIRTUAL_ENV={venv_dir}"
|
||||||
Environment="HERMES_HOME={hermes_home}"
|
Environment="HERMES_HOME={hermes_home}"
|
||||||
Restart=on-failure
|
Restart=on-failure
|
||||||
RestartSec=30
|
RestartSec=30
|
||||||
RestartForceExitStatus=75
|
RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}
|
||||||
KillMode=mixed
|
KillMode=mixed
|
||||||
KillSignal=SIGTERM
|
KillSignal=SIGTERM
|
||||||
ExecReload=/bin/kill -USR1 $MAINPID
|
ExecReload=/bin/kill -USR1 $MAINPID
|
||||||
|
|
@ -932,11 +936,12 @@ def _get_restart_drain_timeout() -> float:
|
||||||
if not raw:
|
if not raw:
|
||||||
cfg = read_raw_config()
|
cfg = read_raw_config()
|
||||||
agent_cfg = cfg.get("agent", {}) if isinstance(cfg, dict) else {}
|
agent_cfg = cfg.get("agent", {}) if isinstance(cfg, dict) else {}
|
||||||
raw = str(agent_cfg.get("restart_drain_timeout", DEFAULT_CONFIG["agent"]["restart_drain_timeout"]))
|
raw = str(
|
||||||
try:
|
agent_cfg.get(
|
||||||
return max(0.0, float(raw))
|
"restart_drain_timeout", DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
except (TypeError, ValueError):
|
)
|
||||||
return float(DEFAULT_CONFIG["agent"]["restart_drain_timeout"])
|
)
|
||||||
|
return parse_restart_drain_timeout(raw)
|
||||||
|
|
||||||
|
|
||||||
def systemd_install(force: bool = False, system: bool = False, run_as_user: str | None = None):
|
def systemd_install(force: bool = False, system: bool = False, run_as_user: str | None = None):
|
||||||
|
|
|
||||||
110
tests/gateway/restart_test_helpers.py
Normal file
110
tests/gateway/restart_test_helpers.py
Normal file
|
|
@ -0,0 +1,110 @@
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||||
|
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||||
|
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
|
from gateway.run import GatewayRunner
|
||||||
|
from gateway.session import SessionSource
|
||||||
|
|
||||||
|
|
||||||
|
class RestartTestAdapter(BasePlatformAdapter):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
||||||
|
self.sent: list[str] = []
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||||
|
self.sent.append(content)
|
||||||
|
return SendResult(success=True, message_id="1")
|
||||||
|
|
||||||
|
async def send_typing(self, chat_id, metadata=None):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_chat_info(self, chat_id):
|
||||||
|
return {"id": chat_id}
|
||||||
|
|
||||||
|
|
||||||
|
def make_restart_source(chat_id: str = "123456", chat_type: str = "dm") -> SessionSource:
|
||||||
|
return SessionSource(
|
||||||
|
platform=Platform.TELEGRAM,
|
||||||
|
chat_id=chat_id,
|
||||||
|
chat_type=chat_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_restart_runner(
|
||||||
|
adapter: BasePlatformAdapter | None = None,
|
||||||
|
) -> tuple[GatewayRunner, BasePlatformAdapter]:
|
||||||
|
runner = object.__new__(GatewayRunner)
|
||||||
|
runner.config = GatewayConfig(
|
||||||
|
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||||
|
)
|
||||||
|
runner._running = True
|
||||||
|
runner._shutdown_event = asyncio.Event()
|
||||||
|
runner._exit_reason = None
|
||||||
|
runner._exit_code = None
|
||||||
|
runner._running_agents = {}
|
||||||
|
runner._running_agents_ts = {}
|
||||||
|
runner._pending_messages = {}
|
||||||
|
runner._pending_approvals = {}
|
||||||
|
runner._pending_model_notes = {}
|
||||||
|
runner._background_tasks = set()
|
||||||
|
runner._draining = False
|
||||||
|
runner._restart_requested = False
|
||||||
|
runner._restart_task_started = False
|
||||||
|
runner._restart_detached = False
|
||||||
|
runner._restart_via_service = False
|
||||||
|
runner._restart_drain_timeout = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
|
runner._stop_task = None
|
||||||
|
runner._busy_input_mode = "interrupt"
|
||||||
|
runner._update_prompt_pending = {}
|
||||||
|
runner._voice_mode = {}
|
||||||
|
runner._session_model_overrides = {}
|
||||||
|
runner._shutdown_all_gateway_honcho = lambda: None
|
||||||
|
runner._update_runtime_status = MagicMock()
|
||||||
|
runner._queue_or_replace_pending_event = GatewayRunner._queue_or_replace_pending_event.__get__(
|
||||||
|
runner, GatewayRunner
|
||||||
|
)
|
||||||
|
runner._session_key_for_source = GatewayRunner._session_key_for_source.__get__(
|
||||||
|
runner, GatewayRunner
|
||||||
|
)
|
||||||
|
runner._handle_active_session_busy_message = (
|
||||||
|
GatewayRunner._handle_active_session_busy_message.__get__(runner, GatewayRunner)
|
||||||
|
)
|
||||||
|
runner._handle_restart_command = GatewayRunner._handle_restart_command.__get__(
|
||||||
|
runner, GatewayRunner
|
||||||
|
)
|
||||||
|
runner._status_action_label = GatewayRunner._status_action_label.__get__(
|
||||||
|
runner, GatewayRunner
|
||||||
|
)
|
||||||
|
runner._status_action_gerund = GatewayRunner._status_action_gerund.__get__(
|
||||||
|
runner, GatewayRunner
|
||||||
|
)
|
||||||
|
runner._queue_during_drain_enabled = GatewayRunner._queue_during_drain_enabled.__get__(
|
||||||
|
runner, GatewayRunner
|
||||||
|
)
|
||||||
|
runner._running_agent_count = GatewayRunner._running_agent_count.__get__(
|
||||||
|
runner, GatewayRunner
|
||||||
|
)
|
||||||
|
runner._launch_detached_restart_command = GatewayRunner._launch_detached_restart_command.__get__(
|
||||||
|
runner, GatewayRunner
|
||||||
|
)
|
||||||
|
runner.request_restart = GatewayRunner.request_restart.__get__(runner, GatewayRunner)
|
||||||
|
runner._is_user_authorized = lambda _source: True
|
||||||
|
runner.hooks = MagicMock()
|
||||||
|
runner.hooks.emit = AsyncMock()
|
||||||
|
runner.pairing_store = MagicMock()
|
||||||
|
runner.session_store = MagicMock()
|
||||||
|
runner.delivery_router = MagicMock()
|
||||||
|
|
||||||
|
platform_adapter = adapter or RestartTestAdapter()
|
||||||
|
platform_adapter.set_message_handler(AsyncMock(return_value=None))
|
||||||
|
platform_adapter.set_busy_session_handler(runner._handle_active_session_busy_message)
|
||||||
|
runner.adapters = {Platform.TELEGRAM: platform_adapter}
|
||||||
|
return runner, platform_adapter
|
||||||
|
|
@ -3,67 +3,15 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
from gateway.platforms.base import MessageEvent
|
||||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
from gateway.restart import GATEWAY_SERVICE_RESTART_EXIT_CODE
|
||||||
from gateway.run import GatewayRunner
|
from gateway.session import build_session_key
|
||||||
from gateway.session import SessionSource, build_session_key
|
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
|
||||||
|
|
||||||
|
|
||||||
class StubAdapter(BasePlatformAdapter):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
|
||||||
|
|
||||||
async def connect(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def disconnect(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
|
||||||
return SendResult(success=True, message_id="1")
|
|
||||||
|
|
||||||
async def send_typing(self, chat_id, metadata=None):
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_chat_info(self, chat_id):
|
|
||||||
return {"id": chat_id}
|
|
||||||
|
|
||||||
|
|
||||||
def _source(chat_id="123456", chat_type="dm"):
|
|
||||||
return SessionSource(
|
|
||||||
platform=Platform.TELEGRAM,
|
|
||||||
chat_id=chat_id,
|
|
||||||
chat_type=chat_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_runner() -> GatewayRunner:
|
|
||||||
runner = object.__new__(GatewayRunner)
|
|
||||||
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
|
|
||||||
runner._running = True
|
|
||||||
runner._shutdown_event = asyncio.Event()
|
|
||||||
runner._exit_reason = None
|
|
||||||
runner._exit_code = None
|
|
||||||
runner._pending_messages = {}
|
|
||||||
runner._pending_approvals = {}
|
|
||||||
runner._background_tasks = set()
|
|
||||||
runner._running_agents = {}
|
|
||||||
runner._running_agents_ts = {}
|
|
||||||
runner._draining = False
|
|
||||||
runner._restart_requested = False
|
|
||||||
runner._restart_task_started = False
|
|
||||||
runner._restart_detached = False
|
|
||||||
runner._restart_via_service = False
|
|
||||||
runner._restart_drain_timeout = 60.0
|
|
||||||
runner._stop_task = None
|
|
||||||
runner._shutdown_all_gateway_honcho = lambda: None
|
|
||||||
runner._update_runtime_status = MagicMock()
|
|
||||||
return runner
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
||||||
adapter = StubAdapter()
|
_runner, adapter = make_restart_runner()
|
||||||
release = asyncio.Event()
|
release = asyncio.Event()
|
||||||
|
|
||||||
async def block_forever(_event):
|
async def block_forever(_event):
|
||||||
|
|
@ -71,7 +19,7 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
adapter.set_message_handler(block_forever)
|
adapter.set_message_handler(block_forever)
|
||||||
event = MessageEvent(text="work", source=_source(), message_id="1")
|
event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
|
||||||
|
|
||||||
await adapter.handle_message(event)
|
await adapter.handle_message(event)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
@ -89,12 +37,11 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
|
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
|
||||||
runner = _make_runner()
|
runner, adapter = make_restart_runner()
|
||||||
runner._pending_messages = {"session": "pending text"}
|
runner._pending_messages = {"session": "pending text"}
|
||||||
runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
|
runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
|
||||||
runner._restart_drain_timeout = 0.0
|
runner._restart_drain_timeout = 0.0
|
||||||
|
|
||||||
adapter = StubAdapter()
|
|
||||||
release = asyncio.Event()
|
release = asyncio.Event()
|
||||||
|
|
||||||
async def block_forever(_event):
|
async def block_forever(_event):
|
||||||
|
|
@ -102,7 +49,7 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
adapter.set_message_handler(block_forever)
|
adapter.set_message_handler(block_forever)
|
||||||
event = MessageEvent(text="work", source=_source(), message_id="1")
|
event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
|
||||||
await adapter.handle_message(event)
|
await adapter.handle_message(event)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
@ -112,7 +59,6 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
|
||||||
session_key = build_session_key(event.source)
|
session_key = build_session_key(event.source)
|
||||||
running_agent = MagicMock()
|
running_agent = MagicMock()
|
||||||
runner._running_agents = {session_key: running_agent}
|
runner._running_agents = {session_key: running_agent}
|
||||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
|
||||||
|
|
||||||
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
||||||
await runner.stop()
|
await runner.stop()
|
||||||
|
|
@ -128,11 +74,9 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_gateway_stop_drains_running_agents_before_disconnect():
|
async def test_gateway_stop_drains_running_agents_before_disconnect():
|
||||||
runner = _make_runner()
|
runner, adapter = make_restart_runner()
|
||||||
adapter = StubAdapter()
|
|
||||||
disconnect_mock = AsyncMock()
|
disconnect_mock = AsyncMock()
|
||||||
adapter.disconnect = disconnect_mock
|
adapter.disconnect = disconnect_mock
|
||||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
|
||||||
|
|
||||||
running_agent = MagicMock()
|
running_agent = MagicMock()
|
||||||
runner._running_agents = {"session": running_agent}
|
runner._running_agents = {"session": running_agent}
|
||||||
|
|
@ -153,13 +97,11 @@ async def test_gateway_stop_drains_running_agents_before_disconnect():
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_gateway_stop_interrupts_after_drain_timeout():
|
async def test_gateway_stop_interrupts_after_drain_timeout():
|
||||||
runner = _make_runner()
|
runner, adapter = make_restart_runner()
|
||||||
runner._restart_drain_timeout = 0.05
|
runner._restart_drain_timeout = 0.05
|
||||||
|
|
||||||
adapter = StubAdapter()
|
|
||||||
disconnect_mock = AsyncMock()
|
disconnect_mock = AsyncMock()
|
||||||
adapter.disconnect = disconnect_mock
|
adapter.disconnect = disconnect_mock
|
||||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
|
||||||
|
|
||||||
running_agent = MagicMock()
|
running_agent = MagicMock()
|
||||||
runner._running_agents = {"session": running_agent}
|
runner._running_agents = {"session": running_agent}
|
||||||
|
|
@ -170,3 +112,36 @@ async def test_gateway_stop_interrupts_after_drain_timeout():
|
||||||
running_agent.interrupt.assert_called_once_with("Gateway shutting down")
|
running_agent.interrupt.assert_called_once_with("Gateway shutting down")
|
||||||
disconnect_mock.assert_awaited_once()
|
disconnect_mock.assert_awaited_once()
|
||||||
assert runner._shutdown_event.is_set() is True
|
assert runner._shutdown_event.is_set() is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gateway_stop_service_restart_sets_named_exit_code():
|
||||||
|
runner, adapter = make_restart_runner()
|
||||||
|
adapter.disconnect = AsyncMock()
|
||||||
|
|
||||||
|
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
||||||
|
await runner.stop(restart=True, service_restart=True)
|
||||||
|
|
||||||
|
assert runner._exit_code == GATEWAY_SERVICE_RESTART_EXIT_CODE
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_drain_active_agents_throttles_status_updates():
|
||||||
|
runner, _adapter = make_restart_runner()
|
||||||
|
runner._update_runtime_status = MagicMock()
|
||||||
|
|
||||||
|
runner._running_agents = {"a": MagicMock(), "b": MagicMock()}
|
||||||
|
|
||||||
|
async def finish_agents():
|
||||||
|
await asyncio.sleep(0.12)
|
||||||
|
runner._running_agents.pop("a")
|
||||||
|
await asyncio.sleep(0.12)
|
||||||
|
runner._running_agents.clear()
|
||||||
|
|
||||||
|
task = asyncio.create_task(finish_agents())
|
||||||
|
await runner._drain_active_agents(1.0)
|
||||||
|
await task
|
||||||
|
|
||||||
|
# Start, one count-change update, and final update. Allow one extra update
|
||||||
|
# if the loop observes the zero-agent state before exiting.
|
||||||
|
assert 3 <= runner._update_runtime_status.call_count <= 4
|
||||||
|
|
|
||||||
|
|
@ -1,95 +1,27 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
import gateway.run as gateway_run
|
||||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
from gateway.platforms.base import MessageEvent, MessageType
|
||||||
from gateway.run import GatewayRunner
|
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
from gateway.session import SessionSource, build_session_key
|
from gateway.session import build_session_key
|
||||||
|
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
|
||||||
|
|
||||||
class RecordingAdapter(BasePlatformAdapter):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
|
||||||
self.sent: list[str] = []
|
|
||||||
|
|
||||||
async def connect(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def disconnect(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
|
||||||
self.sent.append(content)
|
|
||||||
return SendResult(success=True, message_id="1")
|
|
||||||
|
|
||||||
async def send_typing(self, chat_id, metadata=None):
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_chat_info(self, chat_id):
|
|
||||||
return {"id": chat_id}
|
|
||||||
|
|
||||||
|
|
||||||
def _source(chat_id="123456"):
|
|
||||||
return SessionSource(
|
|
||||||
platform=Platform.TELEGRAM,
|
|
||||||
chat_id=chat_id,
|
|
||||||
chat_type="dm",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_runner() -> tuple[GatewayRunner, RecordingAdapter]:
|
|
||||||
runner = object.__new__(GatewayRunner)
|
|
||||||
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
|
|
||||||
runner.adapters = {}
|
|
||||||
runner._running = True
|
|
||||||
runner._shutdown_event = asyncio.Event()
|
|
||||||
runner._exit_reason = None
|
|
||||||
runner._exit_code = None
|
|
||||||
runner._running_agents = {}
|
|
||||||
runner._running_agents_ts = {}
|
|
||||||
runner._pending_messages = {}
|
|
||||||
runner._pending_approvals = {}
|
|
||||||
runner._background_tasks = set()
|
|
||||||
runner._draining = False
|
|
||||||
runner._restart_requested = False
|
|
||||||
runner._restart_task_started = False
|
|
||||||
runner._restart_detached = False
|
|
||||||
runner._restart_via_service = False
|
|
||||||
runner._restart_drain_timeout = 60.0
|
|
||||||
runner._stop_task = None
|
|
||||||
runner._busy_input_mode = "interrupt"
|
|
||||||
runner._update_prompt_pending = {}
|
|
||||||
runner._voice_mode = {}
|
|
||||||
runner._update_runtime_status = MagicMock()
|
|
||||||
runner._queue_or_replace_pending_event = GatewayRunner._queue_or_replace_pending_event.__get__(runner, GatewayRunner)
|
|
||||||
runner._session_key_for_source = GatewayRunner._session_key_for_source.__get__(runner, GatewayRunner)
|
|
||||||
runner._handle_active_session_busy_message = GatewayRunner._handle_active_session_busy_message.__get__(runner, GatewayRunner)
|
|
||||||
runner._handle_restart_command = GatewayRunner._handle_restart_command.__get__(runner, GatewayRunner)
|
|
||||||
runner._status_action_label = GatewayRunner._status_action_label.__get__(runner, GatewayRunner)
|
|
||||||
runner._status_action_gerund = GatewayRunner._status_action_gerund.__get__(runner, GatewayRunner)
|
|
||||||
runner._queue_during_drain_enabled = GatewayRunner._queue_during_drain_enabled.__get__(runner, GatewayRunner)
|
|
||||||
runner._running_agent_count = GatewayRunner._running_agent_count.__get__(runner, GatewayRunner)
|
|
||||||
runner.request_restart = MagicMock(return_value=True)
|
|
||||||
runner._is_user_authorized = lambda _source: True
|
|
||||||
runner.hooks = MagicMock()
|
|
||||||
runner.hooks.emit = AsyncMock()
|
|
||||||
runner.pairing_store = MagicMock()
|
|
||||||
runner.session_store = MagicMock()
|
|
||||||
runner.delivery_router = MagicMock()
|
|
||||||
|
|
||||||
adapter = RecordingAdapter()
|
|
||||||
adapter.set_message_handler(AsyncMock(return_value=None))
|
|
||||||
adapter.set_busy_session_handler(runner._handle_active_session_busy_message)
|
|
||||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
|
||||||
return runner, adapter
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_restart_command_while_busy_requests_drain_without_interrupt():
|
async def test_restart_command_while_busy_requests_drain_without_interrupt():
|
||||||
runner, _adapter = _make_runner()
|
runner, _adapter = make_restart_runner()
|
||||||
event = MessageEvent(text="/restart", message_type=MessageType.TEXT, source=_source(), message_id="m1")
|
runner.request_restart = MagicMock(return_value=True)
|
||||||
|
event = MessageEvent(
|
||||||
|
text="/restart",
|
||||||
|
message_type=MessageType.TEXT,
|
||||||
|
source=make_restart_source(),
|
||||||
|
message_id="m1",
|
||||||
|
)
|
||||||
session_key = build_session_key(event.source)
|
session_key = build_session_key(event.source)
|
||||||
running_agent = MagicMock()
|
running_agent = MagicMock()
|
||||||
runner._running_agents[session_key] = running_agent
|
runner._running_agents[session_key] = running_agent
|
||||||
|
|
@ -103,12 +35,17 @@ async def test_restart_command_while_busy_requests_drain_without_interrupt():
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_drain_queue_mode_queues_follow_up_without_interrupt():
|
async def test_drain_queue_mode_queues_follow_up_without_interrupt():
|
||||||
runner, adapter = _make_runner()
|
runner, adapter = make_restart_runner()
|
||||||
runner._draining = True
|
runner._draining = True
|
||||||
runner._restart_requested = True
|
runner._restart_requested = True
|
||||||
runner._busy_input_mode = "queue"
|
runner._busy_input_mode = "queue"
|
||||||
|
|
||||||
event = MessageEvent(text="follow up", message_type=MessageType.TEXT, source=_source(), message_id="m2")
|
event = MessageEvent(
|
||||||
|
text="follow up",
|
||||||
|
message_type=MessageType.TEXT,
|
||||||
|
source=make_restart_source(),
|
||||||
|
message_id="m2",
|
||||||
|
)
|
||||||
session_key = build_session_key(event.source)
|
session_key = build_session_key(event.source)
|
||||||
adapter._active_sessions[session_key] = asyncio.Event()
|
adapter._active_sessions[session_key] = asyncio.Event()
|
||||||
|
|
||||||
|
|
@ -122,12 +59,102 @@ async def test_drain_queue_mode_queues_follow_up_without_interrupt():
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_draining_rejects_new_session_messages():
|
async def test_draining_rejects_new_session_messages():
|
||||||
runner, _adapter = _make_runner()
|
runner, _adapter = make_restart_runner()
|
||||||
runner._draining = True
|
runner._draining = True
|
||||||
runner._restart_requested = True
|
runner._restart_requested = True
|
||||||
|
|
||||||
event = MessageEvent(text="hello", message_type=MessageType.TEXT, source=_source("fresh"), message_id="m3")
|
event = MessageEvent(
|
||||||
|
text="hello",
|
||||||
|
message_type=MessageType.TEXT,
|
||||||
|
source=make_restart_source("fresh"),
|
||||||
|
message_id="m3",
|
||||||
|
)
|
||||||
|
|
||||||
result = await runner._handle_message(event)
|
result = await runner._handle_message(event)
|
||||||
|
|
||||||
assert result == "⏳ Gateway is restarting and is not accepting new work right now."
|
assert result == "⏳ Gateway is restarting and is not accepting new work right now."
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_busy_input_mode_prefers_env_then_config_then_default(tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||||
|
monkeypatch.delenv("HERMES_GATEWAY_BUSY_INPUT_MODE", raising=False)
|
||||||
|
|
||||||
|
assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt"
|
||||||
|
|
||||||
|
(tmp_path / "config.yaml").write_text(
|
||||||
|
"display:\n busy_input_mode: queue\n", encoding="utf-8"
|
||||||
|
)
|
||||||
|
assert gateway_run.GatewayRunner._load_busy_input_mode() == "queue"
|
||||||
|
|
||||||
|
monkeypatch.setenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "interrupt")
|
||||||
|
assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt"
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_restart_drain_timeout_prefers_env_then_config_then_default(
|
||||||
|
tmp_path, monkeypatch, caplog
|
||||||
|
):
|
||||||
|
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||||
|
monkeypatch.delenv("HERMES_RESTART_DRAIN_TIMEOUT", raising=False)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
gateway_run.GatewayRunner._load_restart_drain_timeout()
|
||||||
|
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
|
)
|
||||||
|
|
||||||
|
(tmp_path / "config.yaml").write_text(
|
||||||
|
"agent:\n restart_drain_timeout: 12\n", encoding="utf-8"
|
||||||
|
)
|
||||||
|
assert gateway_run.GatewayRunner._load_restart_drain_timeout() == 12.0
|
||||||
|
|
||||||
|
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "7")
|
||||||
|
assert gateway_run.GatewayRunner._load_restart_drain_timeout() == 7.0
|
||||||
|
|
||||||
|
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "invalid")
|
||||||
|
assert (
|
||||||
|
gateway_run.GatewayRunner._load_restart_drain_timeout()
|
||||||
|
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
|
)
|
||||||
|
assert "Invalid restart_drain_timeout" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_restart_is_idempotent():
|
||||||
|
runner, _adapter = make_restart_runner()
|
||||||
|
runner.stop = AsyncMock()
|
||||||
|
|
||||||
|
assert runner.request_restart(detached=True, via_service=False) is True
|
||||||
|
first_task = next(iter(runner._background_tasks))
|
||||||
|
assert runner.request_restart(detached=True, via_service=False) is False
|
||||||
|
|
||||||
|
await first_task
|
||||||
|
|
||||||
|
runner.stop.assert_awaited_once_with(
|
||||||
|
restart=True, detached_restart=True, service_restart=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_launch_detached_restart_command_uses_setsid(monkeypatch):
|
||||||
|
runner, _adapter = make_restart_runner()
|
||||||
|
popen_calls = []
|
||||||
|
|
||||||
|
monkeypatch.setattr(gateway_run, "_resolve_hermes_bin", lambda: ["/usr/bin/hermes"])
|
||||||
|
monkeypatch.setattr(gateway_run.os, "getpid", lambda: 321)
|
||||||
|
monkeypatch.setattr(shutil, "which", lambda cmd: "/usr/bin/setsid" if cmd == "setsid" else None)
|
||||||
|
|
||||||
|
def fake_popen(cmd, **kwargs):
|
||||||
|
popen_calls.append((cmd, kwargs))
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
monkeypatch.setattr(subprocess, "Popen", fake_popen)
|
||||||
|
|
||||||
|
await runner._launch_detached_restart_command()
|
||||||
|
|
||||||
|
assert len(popen_calls) == 1
|
||||||
|
cmd, kwargs = popen_calls[0]
|
||||||
|
assert cmd[:2] == ["/usr/bin/setsid", "bash"]
|
||||||
|
assert "gateway restart" in cmd[-1]
|
||||||
|
assert "kill -0 321" in cmd[-1]
|
||||||
|
assert kwargs["start_new_session"] is True
|
||||||
|
assert kwargs["stdout"] is subprocess.DEVNULL
|
||||||
|
assert kwargs["stderr"] is subprocess.DEVNULL
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,10 @@ from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import hermes_cli.gateway as gateway_cli
|
import hermes_cli.gateway as gateway_cli
|
||||||
|
from gateway.restart import (
|
||||||
|
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||||
|
GATEWAY_SERVICE_RESTART_EXIT_CODE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSystemdServiceRefresh:
|
class TestSystemdServiceRefresh:
|
||||||
|
|
@ -85,7 +89,7 @@ class TestGeneratedSystemdUnits:
|
||||||
assert "ExecStart=" in unit
|
assert "ExecStart=" in unit
|
||||||
assert "ExecStop=" not in unit
|
assert "ExecStop=" not in unit
|
||||||
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
|
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
|
||||||
assert "RestartForceExitStatus=75" in unit
|
assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit
|
||||||
assert "TimeoutStopSec=60" in unit
|
assert "TimeoutStopSec=60" in unit
|
||||||
|
|
||||||
def test_user_unit_includes_resolved_node_directory_in_path(self, monkeypatch):
|
def test_user_unit_includes_resolved_node_directory_in_path(self, monkeypatch):
|
||||||
|
|
@ -101,7 +105,7 @@ class TestGeneratedSystemdUnits:
|
||||||
assert "ExecStart=" in unit
|
assert "ExecStart=" in unit
|
||||||
assert "ExecStop=" not in unit
|
assert "ExecStop=" not in unit
|
||||||
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
|
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
|
||||||
assert "RestartForceExitStatus=75" in unit
|
assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit
|
||||||
assert "TimeoutStopSec=60" in unit
|
assert "TimeoutStopSec=60" in unit
|
||||||
assert "WantedBy=multi-user.target" in unit
|
assert "WantedBy=multi-user.target" in unit
|
||||||
|
|
||||||
|
|
@ -161,6 +165,31 @@ class TestGatewayStopCleanup:
|
||||||
|
|
||||||
|
|
||||||
class TestLaunchdServiceRecovery:
|
class TestLaunchdServiceRecovery:
|
||||||
|
def test_get_restart_drain_timeout_prefers_env_then_config_then_default(self, monkeypatch):
|
||||||
|
monkeypatch.delenv("HERMES_RESTART_DRAIN_TIMEOUT", raising=False)
|
||||||
|
monkeypatch.setattr(gateway_cli, "read_raw_config", lambda: {})
|
||||||
|
|
||||||
|
assert (
|
||||||
|
gateway_cli._get_restart_drain_timeout()
|
||||||
|
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
gateway_cli,
|
||||||
|
"read_raw_config",
|
||||||
|
lambda: {"agent": {"restart_drain_timeout": 14}},
|
||||||
|
)
|
||||||
|
assert gateway_cli._get_restart_drain_timeout() == 14.0
|
||||||
|
|
||||||
|
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "9")
|
||||||
|
assert gateway_cli._get_restart_drain_timeout() == 9.0
|
||||||
|
|
||||||
|
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "invalid")
|
||||||
|
assert (
|
||||||
|
gateway_cli._get_restart_drain_timeout()
|
||||||
|
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
|
)
|
||||||
|
|
||||||
def test_launchd_install_repairs_outdated_plist_without_force(self, tmp_path, monkeypatch):
|
def test_launchd_install_repairs_outdated_plist_without_force(self, tmp_path, monkeypatch):
|
||||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||||
plist_path.write_text("<plist>old content</plist>", encoding="utf-8")
|
plist_path.write_text("<plist>old content</plist>", encoding="utf-8")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue