diff --git a/cli-config.yaml.example b/cli-config.yaml.example index a0a2d7d8a1..5807cef7aa 100644 --- a/cli-config.yaml.example +++ b/cli-config.yaml.example @@ -480,6 +480,12 @@ agent: # Fires once per run when inactivity reaches this threshold (seconds). # Set to 0 to disable the warning. # gateway_timeout_warning: 900 + + # Graceful drain timeout for gateway stop/restart (seconds). + # The gateway stops accepting new work, waits for in-flight agents to + # finish, then interrupts anything still running after this timeout. + # 0 = no drain, interrupt immediately. + # restart_drain_timeout: 60 # Enable verbose logging verbose: false diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index dfc06ef7cb..34aacc7a39 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -727,6 +727,7 @@ class BasePlatformAdapter(ABC): # working on a task after --replace or manual restarts. self._background_tasks: set[asyncio.Task] = set() self._expected_cancelled_tasks: set[asyncio.Task] = set() + self._busy_session_handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]] = None # Chats where auto-TTS on voice input is disabled (set by /voice off) self._auto_tts_disabled_chats: set = set() # Chats where typing indicator is paused (e.g. during approval waits). @@ -815,6 +816,10 @@ class BasePlatformAdapter(ABC): an optional response string. """ self._message_handler = handler + + def set_busy_session_handler(self, handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]]) -> None: + """Set an optional handler for messages arriving during active sessions.""" + self._busy_session_handler = handler def set_session_store(self, session_store: Any) -> None: """ @@ -1396,7 +1401,7 @@ class BasePlatformAdapter(ABC): # session lifecycle and its cleanup races with the running task # (see PR #4926). 
cmd = event.get_command() - if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background"): + if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background", "restart"): logger.debug( "[%s] Command '/%s' bypassing active-session guard for %s", self.name, cmd, session_key, @@ -1415,6 +1420,13 @@ class BasePlatformAdapter(ABC): logger.error("[%s] Command '/%s' dispatch failed: %s", self.name, cmd, e, exc_info=True) return + if self._busy_session_handler is not None: + try: + if await self._busy_session_handler(event, session_key): + return + except Exception as e: + logger.error("[%s] Busy-session handler failed: %s", self.name, e, exc_info=True) + # Special case: photo bursts/albums frequently arrive as multiple near- # simultaneous messages. Queue them without interrupting the active run, # then process them immediately after the current task finishes. diff --git a/gateway/run.py b/gateway/run.py index 0dff622ae8..e4caedd9e4 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -186,6 +186,12 @@ if _config_path.exists(): os.environ["HERMES_AGENT_TIMEOUT"] = str(_agent_cfg["gateway_timeout"]) if "gateway_timeout_warning" in _agent_cfg and "HERMES_AGENT_TIMEOUT_WARNING" not in os.environ: os.environ["HERMES_AGENT_TIMEOUT_WARNING"] = str(_agent_cfg["gateway_timeout_warning"]) + if "restart_drain_timeout" in _agent_cfg and "HERMES_RESTART_DRAIN_TIMEOUT" not in os.environ: + os.environ["HERMES_RESTART_DRAIN_TIMEOUT"] = str(_agent_cfg["restart_drain_timeout"]) + _display_cfg = _cfg.get("display", {}) + if _display_cfg and isinstance(_display_cfg, dict): + if "busy_input_mode" in _display_cfg and "HERMES_GATEWAY_BUSY_INPUT_MODE" not in os.environ: + os.environ["HERMES_GATEWAY_BUSY_INPUT_MODE"] = str(_display_cfg["busy_input_mode"]) # Timezone: bridge config.yaml → HERMES_TIMEZONE env var. # HERMES_TIMEZONE from .env takes precedence (already in os.environ). 
@property
def exit_code(self) -> Optional[int]:
    """Exit status recorded by the runner, or ``None`` when none was set."""
    return self._exit_code


def _running_agent_count(self) -> int:
    """Return the number of entries in ``_running_agents``.

    Every entry is counted, including sessions that only hold the
    pending placeholder and have no agent object yet.
    """
    return len(self._running_agents)


def _status_action_label(self) -> str:
    """Return the noun for the stop in progress: ``restart`` or ``shutdown``."""
    return "restart" if self._restart_requested else "shutdown"


def _status_action_gerund(self) -> str:
    """Return the verb form used in user-facing drain notices."""
    return "restarting" if self._restart_requested else "shutting down"


def _queue_during_drain_enabled(self) -> bool:
    """Return True when busy-time messages should be queued instead of refused.

    Queueing only applies to restarts, and only when the busy-input mode
    is ``queue``.
    """
    return self._restart_requested and self._busy_input_mode == "queue"


def _update_runtime_status(self, gateway_state: Optional[str] = None, exit_reason: Optional[str] = None) -> None:
    """Write the gateway state, restart flag and agent count to the status record.

    Args:
        gateway_state: New state name, or ``None`` to leave it untouched.
        exit_reason: Exit reason to record, or ``None`` to leave it untouched.
    """
    try:
        from gateway.status import write_runtime_status
        write_runtime_status(
            gateway_state=gateway_state,
            exit_reason=exit_reason,
            restart_requested=self._restart_requested,
            active_agents=self._running_agent_count(),
        )
    except Exception:
        # Deliberately best-effort: status reporting must never break the
        # start/stop sequence that calls it.
        pass


@staticmethod
def _load_busy_input_mode() -> str:
    """Load gateway drain-time busy-input behavior from config/env.

    The ``HERMES_GATEWAY_BUSY_INPUT_MODE`` environment variable wins; the
    ``display.busy_input_mode`` key of ``config.yaml`` is the fallback.

    Returns:
        ``"queue"`` when that mode is configured, otherwise ``"interrupt"``.
    """
    mode = os.getenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "").strip().lower()
    if not mode:
        try:
            import yaml as _y
            cfg_path = _hermes_home / "config.yaml"
            if cfg_path.exists():
                with open(cfg_path, encoding="utf-8") as _f:
                    cfg = _y.safe_load(_f) or {}
                # An empty ``display:`` section parses to None, not {}.
                display_cfg = cfg.get("display") if isinstance(cfg, dict) else None
                if isinstance(display_cfg, dict):
                    mode = str(display_cfg.get("busy_input_mode", "") or "").strip().lower()
        except Exception:
            # Unreadable config falls back to the default mode.
            pass
    return "queue" if mode == "queue" else "interrupt"


@staticmethod
def _load_restart_drain_timeout() -> float:
    """Load graceful gateway restart/stop drain timeout in seconds.

    The ``HERMES_RESTART_DRAIN_TIMEOUT`` environment variable wins; the
    ``agent.restart_drain_timeout`` key of ``config.yaml`` is the fallback.

    Returns:
        A finite, non-negative number of seconds. ``0`` means "do not
        drain". Missing, unparsable or non-finite values yield 60.0.
    """
    import math

    default = 60.0
    raw = os.getenv("HERMES_RESTART_DRAIN_TIMEOUT", "").strip()
    if not raw:
        try:
            import yaml as _y
            cfg_path = _hermes_home / "config.yaml"
            if cfg_path.exists():
                with open(cfg_path, encoding="utf-8") as _f:
                    cfg = _y.safe_load(_f) or {}
                agent_cfg = cfg.get("agent") if isinstance(cfg, dict) else None
                configured = agent_cfg.get("restart_drain_timeout") if isinstance(agent_cfg, dict) else None
                # Test against None rather than truthiness: 0 is a valid,
                # explicit setting and must not be mistaken for "unset".
                if configured is not None:
                    raw = str(configured).strip()
        except Exception:
            # Unreadable config falls back to the default timeout.
            pass
    if not raw:
        return default
    try:
        value = float(raw)
    except ValueError:
        logger.warning("Invalid restart_drain_timeout '%s', using default 60s", raw)
        return default
    if not math.isfinite(value):
        # inf would drain forever; nan has no usable ordering.
        logger.warning("Invalid restart_drain_timeout '%s', using default 60s", raw)
        return default
    return max(0.0, value)
def _snapshot_running_agents(self) -> Dict[str, Any]:
    """Return a copy of ``_running_agents`` without pending placeholders."""
    return {
        session_key: agent
        for session_key, agent in self._running_agents.items()
        if agent is not _AGENT_PENDING_SENTINEL
    }


def _queue_or_replace_pending_event(self, session_key: str, event: "MessageEvent") -> None:
    """Store ``event`` as the pending message for ``session_key``.

    Consecutive photo messages are merged into the already-pending photo
    event (media appended, captions merged). Anything else replaces
    whatever was pending. Does nothing when the event's platform has no
    adapter.
    """
    adapter = self.adapters.get(event.source.platform)
    if not adapter:
        return
    existing = adapter._pending_messages.get(session_key)
    both_photos = (
        existing
        and getattr(existing, "message_type", None) == MessageType.PHOTO
        and event.message_type == MessageType.PHOTO
    )
    if both_photos:
        existing.media_urls.extend(event.media_urls)
        existing.media_types.extend(event.media_types)
        if event.text:
            existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
        return
    adapter._pending_messages[session_key] = event


async def _handle_active_session_busy_message(self, event: "MessageEvent", session_key: str) -> bool:
    """Answer a message that arrives for a busy session while the gateway drains.

    Returns:
        ``False`` when the gateway is not draining, so the adapter applies
        its normal busy-session handling. ``True`` when the message was
        consumed here (queued or refused).
    """
    if not self._draining:
        return False

    adapter = self.adapters.get(event.source.platform)
    if not adapter:
        return True

    thread_meta = {"thread_id": event.source.thread_id} if event.source.thread_id else None
    if self._queue_during_drain_enabled():
        self._queue_or_replace_pending_event(session_key, event)
        message = f"⏳ Gateway {self._status_action_gerund()} — queued for the next turn after it comes back."
    else:
        message = f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now."

    try:
        await adapter._send_with_retry(
            chat_id=event.source.chat_id,
            content=message,
            reply_to=event.message_id,
            metadata=thread_meta,
        )
    except Exception as e:
        # The notice is a courtesy. If the failure escaped, the adapter would
        # log it and fall through to its normal path, which interrupts the
        # agent that is being drained.
        logger.warning("Failed to send drain notice for session %s: %s", session_key[:20], e)
    return True


async def _drain_active_agents(self, timeout: float) -> tuple[Dict[str, Any], bool]:
    """Wait up to ``timeout`` seconds for running agents to finish.

    Args:
        timeout: Maximum wait in seconds; ``<= 0`` skips the wait.

    Returns:
        ``(snapshot, timed_out)``. ``snapshot`` holds the agents that were
        running when the drain began. ``timed_out`` is True when agents
        are still registered on return.
    """
    snapshot = self._snapshot_running_agents()
    self._update_runtime_status("draining")
    if not self._running_agents:
        return snapshot, False
    if timeout <= 0:
        return snapshot, True

    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    last_count = len(self._running_agents)
    while self._running_agents and loop.time() < deadline:
        await asyncio.sleep(0.1)
        # Rewrite the status record only when the count changes, instead of
        # once per 100 ms poll.
        count = len(self._running_agents)
        if count != last_count:
            self._update_runtime_status("draining")
            last_count = count
    timed_out = bool(self._running_agents)
    self._update_runtime_status("draining")
    return snapshot, timed_out


def _interrupt_running_agents(self, reason: str) -> None:
    """Call ``interrupt(reason)`` on every registered agent, skipping placeholders.

    Failures are logged and do not stop the loop.
    """
    for session_key, agent in list(self._running_agents.items()):
        if agent is _AGENT_PENDING_SENTINEL:
            continue
        try:
            agent.interrupt(reason)
            logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
        except Exception as e:
            logger.debug("Failed interrupting agent during shutdown: %s", e)


def _finalize_shutdown_agents(self, active_agents: Dict[str, Any]) -> None:
    """Run end-of-session cleanup for each agent captured at drain start.

    For every agent, in order: fire the ``on_session_finalize`` plugin
    hook, shut down its memory provider, then close it. Each step is
    best-effort, so one failing agent cannot block the others.
    """
    try:
        from hermes_cli.plugins import invoke_hook as _invoke_hook
    except Exception:
        # Plugin system unavailable; skip the hook but still release resources.
        _invoke_hook = None

    for agent in active_agents.values():
        if _invoke_hook is not None:
            try:
                _invoke_hook(
                    "on_session_finalize",
                    session_id=getattr(agent, "session_id", None),
                    platform="gateway",
                )
            except Exception as e:
                logger.debug("on_session_finalize hook failed during shutdown: %s", e)
        try:
            if hasattr(agent, "shutdown_memory_provider"):
                agent.shutdown_memory_provider()
        except Exception as e:
            logger.debug("Memory provider shutdown failed during shutdown: %s", e)
        # The previous stop() body closed each agent's tool resources (terminal
        # sandboxes, browser daemons, background processes, httpx clients);
        # the call is restored here.
        try:
            if hasattr(agent, "close"):
                agent.close()
        except Exception as e:
            logger.debug("Agent close failed during shutdown: %s", e)


async def _launch_detached_restart_command(self) -> None:
    """Spawn a detached shell that runs ``hermes gateway restart`` once this process exits.

    The shell polls this PID with ``kill -0`` and runs the restart
    command after the PID disappears. Logs an error and returns when the
    hermes binary cannot be resolved.
    """
    import shutil
    import subprocess

    hermes_cmd = _resolve_hermes_bin()
    if not hermes_cmd:
        logger.error("Could not locate hermes binary for detached /restart")
        return

    current_pid = os.getpid()
    cmd = shlex.join(hermes_cmd)
    shell_cmd = (
        f"while kill -0 {current_pid} 2>/dev/null; do sleep 0.2; done; "
        f"{cmd} gateway restart"
    )
    argv = ["bash", "-lc", shell_cmd]
    setsid_bin = shutil.which("setsid")
    if setsid_bin:
        argv.insert(0, setsid_bin)
    subprocess.Popen(
        argv,
        # The watcher outlives this process; do not let it hold our stdin.
        stdin=subprocess.DEVNULL,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        start_new_session=True,
    )


def request_restart(self, *, detached: bool = False, via_service: bool = False) -> bool:
    """Schedule a graceful restart on the running event loop.

    Args:
        detached: Relaunch through a detached watcher process.
        via_service: Exit with the restart status so the service manager
            relaunches the gateway.

    Returns:
        True when a restart was scheduled, False when one was already
        scheduled.

    Raises:
        RuntimeError: When called without a running event loop.
    """
    if self._restart_task_started:
        return False
    # Resolve the loop before touching state, so a call outside the loop
    # cannot leave ``_restart_task_started`` stuck at True.
    loop = asyncio.get_running_loop()
    self._restart_requested = True
    self._restart_detached = detached
    self._restart_via_service = via_service
    self._restart_task_started = True

    async def _run_restart() -> None:
        # Brief delay before the stop sequence begins.
        await asyncio.sleep(0.05)
        await self.stop(restart=True, detached_restart=detached, service_restart=via_service)

    task = loop.create_task(_run_restart())
    self._background_tasks.add(task)
    task.add_done_callback(self._background_tasks.discard)
    return True
async def stop(
    self,
    *,
    restart: bool = False,
    detached_restart: bool = False,
    service_restart: bool = False,
) -> None:
    """Stop the gateway and disconnect all adapters.

    The sequence is: stop accepting work, drain running agents for up to
    the configured drain timeout, interrupt whatever is left, finalize
    the agents, disconnect adapters, cancel background tasks and run the
    global tool cleanup. Concurrent calls share one stop task: later
    callers wait for the first one to finish.

    Args:
        restart: Mark this stop as a restart.
        detached_restart: With ``restart``, relaunch through a detached
            watcher process.
        service_restart: With ``restart``, record exit code 75 on the
            runner so the service manager relaunches it.
    """
    if restart:
        self._restart_requested = True
        self._restart_detached = detached_restart
        self._restart_via_service = service_restart
    if self._stop_task is not None:
        await self._stop_task
        return

    # The task driving this call may itself be in ``_background_tasks`` (the
    # restart driver is). Cancelling it while it awaits the stop task would
    # forward the cancellation into the stop task and mark the whole stop as
    # cancelled, so it is excluded from the cancel loop below.
    caller_task = asyncio.current_task()

    async def _stop_impl() -> None:
        logger.info(
            "Stopping gateway%s...",
            " for restart" if self._restart_requested else "",
        )
        self._running = False
        self._draining = True

        timeout = self._restart_drain_timeout
        active_agents, timed_out = await self._drain_active_agents(timeout)
        if timed_out:
            logger.warning(
                "Gateway drain timed out after %.1fs with %d active agent(s); interrupting remaining work.",
                timeout,
                self._running_agent_count(),
            )
            self._interrupt_running_agents(
                "Gateway restarting" if self._restart_requested else "Gateway shutting down"
            )
            # Give interrupted agents a short grace period to unwind.
            interrupt_deadline = asyncio.get_running_loop().time() + 5.0
            while self._running_agents and asyncio.get_running_loop().time() < interrupt_deadline:
                self._update_runtime_status("draining")
                await asyncio.sleep(0.1)

        if self._restart_requested and self._restart_detached:
            try:
                await self._launch_detached_restart_command()
            except Exception as e:
                logger.error("Failed to launch detached gateway restart: %s", e)

        self._finalize_shutdown_agents(active_agents)

        for platform, adapter in list(self.adapters.items()):
            try:
                await adapter.cancel_background_tasks()
            except Exception as e:
                logger.debug("✗ %s background-task cancel error: %s", platform.value, e)
            try:
                await adapter.disconnect()
                logger.info("✓ %s disconnected", platform.value)
            except Exception as e:
                logger.error("✗ %s disconnect error: %s", platform.value, e)

        protected = {self._stop_task, asyncio.current_task(), caller_task}
        for _task in list(self._background_tasks):
            if _task in protected:
                continue
            _task.cancel()
        self._background_tasks.clear()

        self.adapters.clear()
        self._running_agents.clear()
        self._pending_messages.clear()
        self._pending_approvals.clear()
        self._shutdown_event.set()

        # Global cleanup: kill any remaining tool subprocesses not tied
        # to a specific agent (catch-all for zombie prevention).
        try:
            from tools.process_registry import process_registry
            process_registry.kill_all()
        except Exception:
            pass
        try:
            from tools.terminal_tool import cleanup_all_environments
            cleanup_all_environments()
        except Exception:
            pass
        try:
            from tools.browser_tool import cleanup_all_browsers
            cleanup_all_browsers()
        except Exception:
            pass

        from gateway.status import remove_pid_file
        remove_pid_file()

        if self._restart_requested and self._restart_via_service:
            # 75 is the status the generated service unit treats as
            # "restart me" (RestartForceExitStatus=75).
            self._exit_code = 75
            self._exit_reason = self._exit_reason or "Gateway restart requested"

        self._draining = False
        self._update_runtime_status("stopped", self._exit_reason)
        logger.info("Gateway stopped")

    self._stop_task = asyncio.create_task(_stop_impl())
    await self._stop_task
async def _handle_restart_command(self, event: "MessageEvent") -> str:
    """Handle /restart command - drain active work, then restart the gateway.

    When a stop is already under way, the reply names the action that is
    actually running. A plain shutdown drain is reported as a shutdown,
    because no restart follows it.

    Returns:
        The status text to send back to the user.
    """
    if self._restart_requested or self._draining:
        label = self._status_action_label()
        count = self._running_agent_count()
        if count:
            return f"⏳ Draining {count} active agent(s) before {label}..."
        return f"⏳ Gateway {label} already in progress..."

    # Read the count before scheduling, so the reply reflects the work
    # that the restart will wait for.
    active_agents = self._running_agent_count()
    self.request_restart(detached=True, via_service=False)
    if active_agents:
        return f"⏳ Draining {active_agents} active agent(s) before restart..."
    return "♻ Restarting gateway..."
+ async def _handle_help_command(self, event: MessageEvent) -> str: """Handle /help command - list available commands.""" from hermes_cli.commands import gateway_help_lines @@ -7375,6 +7638,8 @@ class GatewayRunner: await asyncio.sleep(0.05) if session_key: self._running_agents[session_key] = agent_holder[0] + if self._draining: + self._update_runtime_status("draining") tracking_task = asyncio.create_task(track_agent()) @@ -7627,6 +7892,14 @@ class GatewayRunner: except Exception: pass + if self._draining and pending: + logger.info( + "Discarding pending follow-up for session %s during gateway %s", + session_key[:20] if session_key else "?", + self._status_action_label(), + ) + pending = None + if pending: logger.debug("Processing pending message: '%s...'", pending[:40]) @@ -7703,6 +7976,8 @@ class GatewayRunner: del self._running_agents[session_key] if session_key: self._running_agents_ts.pop(session_key, None) + if self._draining: + self._update_runtime_status("draining") # Wait for cancelled tasks for task in [progress_task, interrupt_monitor, tracking_task, _notify_task]: @@ -7900,13 +8175,21 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = runner = GatewayRunner(config) # Set up signal handlers - def signal_handler(): + def shutdown_signal_handler(): asyncio.create_task(runner.stop()) + + def restart_signal_handler(): + runner.request_restart(detached=False, via_service=True) loop = asyncio.get_event_loop() for sig in (signal.SIGINT, signal.SIGTERM): try: - loop.add_signal_handler(sig, signal_handler) + loop.add_signal_handler(sig, shutdown_signal_handler) + except NotImplementedError: + pass + if hasattr(signal, "SIGUSR1"): + try: + loop.add_signal_handler(signal.SIGUSR1, restart_signal_handler) except NotImplementedError: pass @@ -7956,6 +8239,9 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = except Exception: pass + if runner.exit_code is not None: + raise SystemExit(runner.exit_code) + 
return True diff --git a/gateway/status.py b/gateway/status.py index ff91262061..5423461c2f 100644 --- a/gateway/status.py +++ b/gateway/status.py @@ -158,6 +158,8 @@ def _build_runtime_status_record() -> dict[str, Any]: payload.update({ "gateway_state": "starting", "exit_reason": None, + "restart_requested": False, + "active_agents": 0, "platforms": {}, "updated_at": _utc_now_iso(), }) @@ -218,6 +220,8 @@ def write_runtime_status( *, gateway_state: Optional[str] = None, exit_reason: Optional[str] = None, + restart_requested: Optional[bool] = None, + active_agents: Optional[int] = None, platform: Optional[str] = None, platform_state: Optional[str] = None, error_code: Optional[str] = None, @@ -236,6 +240,10 @@ def write_runtime_status( payload["gateway_state"] = gateway_state if exit_reason is not None: payload["exit_reason"] = exit_reason + if restart_requested is not None: + payload["restart_requested"] = bool(restart_requested) + if active_agents is not None: + payload["active_agents"] = max(0, int(active_agents)) if platform is not None: platform_payload = payload["platforms"].get(platform, {}) diff --git a/hermes_cli/commands.py b/hermes_cli/commands.py index 84ec873a37..7cf8f30527 100644 --- a/hermes_cli/commands.py +++ b/hermes_cli/commands.py @@ -140,6 +140,8 @@ COMMAND_REGISTRY: list[CommandDef] = [ CommandDef("commands", "Browse all commands and skills (paginated)", "Info", gateway_only=True, args_hint="[page]"), CommandDef("help", "Show available commands", "Info"), + CommandDef("restart", "Gracefully restart the gateway after draining active runs", "Info", + gateway_only=True), CommandDef("usage", "Show token usage and rate limits for the current session", "Info"), CommandDef("insights", "Show usage insights and analytics", "Info", args_hint="[days]"), diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 5ddf37d082..2cb6a8d62a 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -269,6 +269,11 @@ DEFAULT_CONFIG = { # tools or 
receiving API responses. Only fires when the agent has # been completely idle for this duration. 0 = unlimited. "gateway_timeout": 1800, + # Graceful drain timeout for gateway stop/restart (seconds). + # The gateway stops accepting new work, waits for running agents + # to finish, then interrupts any remaining runs after the timeout. + # 0 = no drain, interrupt immediately. + "restart_drain_timeout": 60, "service_tier": "", # Tool-use enforcement: injects system prompt guidance that tells the # model to actually call tools instead of describing intended actions. diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index 58029c888a..0f5f4d15ff 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -15,7 +15,15 @@ from pathlib import Path PROJECT_ROOT = Path(__file__).parent.parent.resolve() from gateway.status import terminate_pid -from hermes_cli.config import get_env_value, get_hermes_home, save_env_value, is_managed, managed_error +from hermes_cli.config import ( + DEFAULT_CONFIG, + get_env_value, + get_hermes_home, + is_managed, + managed_error, + read_raw_config, + save_env_value, +) # display_hermes_home is imported lazily at call sites to avoid ImportError # when hermes_constants is cached from a pre-update version during `hermes update`. 
from hermes_cli.setup import ( @@ -687,6 +695,7 @@ def generate_systemd_unit(system: bool = False, run_as_user: str | None = None) path_entries.append(resolved_node_dir) common_bin_paths = ["/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"] + restart_timeout = max(60, int(_get_restart_drain_timeout() or 0)) if system: username, group_name, home_dir = _system_service_identity(run_as_user) @@ -725,9 +734,11 @@ Environment="VIRTUAL_ENV={venv_dir}" Environment="HERMES_HOME={hermes_home}" Restart=on-failure RestartSec=30 +RestartForceExitStatus=75 KillMode=mixed KillSignal=SIGTERM -TimeoutStopSec=60 +ExecReload=/bin/kill -USR1 $MAINPID +TimeoutStopSec={restart_timeout} StandardOutput=journal StandardError=journal @@ -755,9 +766,11 @@ Environment="VIRTUAL_ENV={venv_dir}" Environment="HERMES_HOME={hermes_home}" Restart=on-failure RestartSec=30 +RestartForceExitStatus=75 KillMode=mixed KillSignal=SIGTERM -TimeoutStopSec=60 +ExecReload=/bin/kill -USR1 $MAINPID +TimeoutStopSec={restart_timeout} StandardOutput=journal StandardError=journal @@ -860,6 +873,19 @@ def _select_systemd_scope(system: bool = False) -> bool: return get_systemd_unit_path(system=True).exists() and not get_systemd_unit_path(system=False).exists() +def _get_restart_drain_timeout() -> float: + """Return the configured gateway restart drain timeout in seconds.""" + raw = os.getenv("HERMES_RESTART_DRAIN_TIMEOUT", "").strip() + if not raw: + cfg = read_raw_config() + agent_cfg = cfg.get("agent", {}) if isinstance(cfg, dict) else {} + raw = str(agent_cfg.get("restart_drain_timeout", DEFAULT_CONFIG["agent"]["restart_drain_timeout"])) + try: + return max(0.0, float(raw)) + except (TypeError, ValueError): + return float(DEFAULT_CONFIG["agent"]["restart_drain_timeout"]) + + def systemd_install(force: bool = False, system: bool = False, run_as_user: str | None = None): if system: _require_root_for_system_service("install") @@ -945,7 +971,7 @@ def systemd_restart(system: bool = False): if 
system: _require_root_for_system_service("restart") refresh_systemd_unit_if_needed(system=system) - subprocess.run(_systemctl_cmd(system) + ["restart", get_service_name()], check=True, timeout=90) + subprocess.run(_systemctl_cmd(system) + ["reload-or-restart", get_service_name()], check=True, timeout=90) print(f"✓ {_service_scope_label(system).capitalize()} service restarted") @@ -1233,7 +1259,7 @@ def launchd_stop(): _wait_for_gateway_exit(timeout=10.0, force_after=5.0) print("✓ Service stopped") -def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0): +def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float | None = 5.0) -> bool: """Wait for the gateway process (by saved PID) to exit. Uses the PID from the gateway.pid file — not launchd labels — so this @@ -1248,21 +1274,21 @@ def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0): from gateway.status import get_running_pid deadline = time.monotonic() + timeout - force_deadline = time.monotonic() + force_after + force_deadline = (time.monotonic() + force_after) if force_after is not None else None force_sent = False while time.monotonic() < deadline: pid = get_running_pid() if pid is None: - return # Process exited cleanly. + return True # Process exited cleanly. - if not force_sent and time.monotonic() >= force_deadline: + if force_after is not None and not force_sent and time.monotonic() >= force_deadline: # Grace period expired — force-kill the specific PID. try: terminate_pid(pid, force=True) print(f"⚠ Gateway PID {pid} did not exit gracefully; sent SIGKILL") except (ProcessLookupError, PermissionError, OSError): - return # Already gone or we can't touch it. + return True # Already gone or we can't touch it. 
force_sent = True time.sleep(0.3) @@ -1271,15 +1297,27 @@ def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0): remaining_pid = get_running_pid() if remaining_pid is not None: print(f"⚠ Gateway PID {remaining_pid} still running after {timeout}s — restart may fail") + return False + return True def launchd_restart(): label = get_launchd_label() target = f"{_launchd_domain()}/{label}" - # Use kickstart -k so launchd performs an atomic kill+restart. - # A two-step stop/start from inside the gateway's own process tree - # would kill the shell before the start command is reached. + drain_timeout = _get_restart_drain_timeout() + from gateway.status import get_running_pid + try: + pid = get_running_pid() + if pid is not None: + try: + terminate_pid(pid, force=False) + except (ProcessLookupError, PermissionError, OSError): + pid = None + if pid is not None: + exited = _wait_for_gateway_exit(timeout=drain_timeout, force_after=None) + if not exited: + print(f"⚠ Gateway drain timed out after {drain_timeout:.0f}s — forcing launchd restart") subprocess.run(["launchctl", "kickstart", "-k", target], check=True, timeout=90) print("✓ Service restarted") except subprocess.CalledProcessError as e: @@ -1750,6 +1788,8 @@ def _runtime_health_lines() -> list[str]: lines: list[str] = [] gateway_state = state.get("gateway_state") exit_reason = state.get("exit_reason") + active_agents = state.get("active_agents") + restart_requested = state.get("restart_requested") platforms = state.get("platforms", {}) or {} for platform, pdata in platforms.items(): @@ -1759,6 +1799,10 @@ def _runtime_health_lines() -> list[str]: if gateway_state == "startup_failed" and exit_reason: lines.append(f"⚠ Last startup issue: {exit_reason}") + elif gateway_state == "draining": + action = "restart" if restart_requested else "shutdown" + count = int(active_agents or 0) + lines.append(f"⏳ Gateway draining for {action} ({count} active agent(s))") elif gateway_state == "stopped" and exit_reason: 
lines.append(f"⚠ Last shutdown reason: {exit_reason}") diff --git a/tests/gateway/test_gateway_shutdown.py b/tests/gateway/test_gateway_shutdown.py index 439fbfdb05..b6a7f8fa72 100644 --- a/tests/gateway/test_gateway_shutdown.py +++ b/tests/gateway/test_gateway_shutdown.py @@ -37,6 +37,30 @@ def _source(chat_id="123456", chat_type="dm"): ) +def _make_runner() -> GatewayRunner: + runner = object.__new__(GatewayRunner) + runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}) + runner._running = True + runner._shutdown_event = asyncio.Event() + runner._exit_reason = None + runner._exit_code = None + runner._pending_messages = {} + runner._pending_approvals = {} + runner._background_tasks = set() + runner._running_agents = {} + runner._running_agents_ts = {} + runner._draining = False + runner._restart_requested = False + runner._restart_task_started = False + runner._restart_detached = False + runner._restart_via_service = False + runner._restart_drain_timeout = 60.0 + runner._stop_task = None + runner._shutdown_all_gateway_honcho = lambda: None + runner._update_runtime_status = MagicMock() + return runner + + @pytest.mark.asyncio async def test_cancel_background_tasks_cancels_inflight_message_processing(): adapter = StubAdapter() @@ -65,15 +89,10 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing(): @pytest.mark.asyncio async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(): - runner = object.__new__(GatewayRunner) - runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}) - runner._running = True - runner._shutdown_event = asyncio.Event() - runner._exit_reason = None + runner = _make_runner() runner._pending_messages = {"session": "pending text"} runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}} - runner._background_tasks = set() - runner._shutdown_all_gateway_honcho = lambda: None + 
runner._restart_drain_timeout = 0.0 adapter = StubAdapter() release = asyncio.Event() @@ -105,3 +124,49 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks( assert runner._pending_messages == {} assert runner._pending_approvals == {} assert runner._shutdown_event.is_set() is True + + +@pytest.mark.asyncio +async def test_gateway_stop_drains_running_agents_before_disconnect(): + runner = _make_runner() + adapter = StubAdapter() + disconnect_mock = AsyncMock() + adapter.disconnect = disconnect_mock + runner.adapters = {Platform.TELEGRAM: adapter} + + running_agent = MagicMock() + runner._running_agents = {"session": running_agent} + + async def finish_agent(): + await asyncio.sleep(0.05) + runner._running_agents.clear() + + asyncio.create_task(finish_agent()) + + with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"): + await runner.stop() + + running_agent.interrupt.assert_not_called() + disconnect_mock.assert_awaited_once() + assert runner._shutdown_event.is_set() is True + + +@pytest.mark.asyncio +async def test_gateway_stop_interrupts_after_drain_timeout(): + runner = _make_runner() + runner._restart_drain_timeout = 0.05 + + adapter = StubAdapter() + disconnect_mock = AsyncMock() + adapter.disconnect = disconnect_mock + runner.adapters = {Platform.TELEGRAM: adapter} + + running_agent = MagicMock() + runner._running_agents = {"session": running_agent} + + with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"): + await runner.stop() + + running_agent.interrupt.assert_called_once_with("Gateway shutting down") + disconnect_mock.assert_awaited_once() + assert runner._shutdown_event.is_set() is True diff --git a/tests/gateway/test_restart_drain.py b/tests/gateway/test_restart_drain.py new file mode 100644 index 0000000000..2c59f9a976 --- /dev/null +++ b/tests/gateway/test_restart_drain.py @@ -0,0 +1,133 @@ +import asyncio +from unittest.mock import AsyncMock, 
MagicMock + +import pytest + +from gateway.config import GatewayConfig, Platform, PlatformConfig +from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult +from gateway.run import GatewayRunner +from gateway.session import SessionSource, build_session_key + + +class RecordingAdapter(BasePlatformAdapter): + def __init__(self): + super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM) + self.sent: list[str] = [] + + async def connect(self): + return True + + async def disconnect(self): + return None + + async def send(self, chat_id, content, reply_to=None, metadata=None): + self.sent.append(content) + return SendResult(success=True, message_id="1") + + async def send_typing(self, chat_id, metadata=None): + return None + + async def get_chat_info(self, chat_id): + return {"id": chat_id} + + +def _source(chat_id="123456"): + return SessionSource( + platform=Platform.TELEGRAM, + chat_id=chat_id, + chat_type="dm", + ) + + +def _make_runner() -> tuple[GatewayRunner, RecordingAdapter]: + runner = object.__new__(GatewayRunner) + runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}) + runner.adapters = {} + runner._running = True + runner._shutdown_event = asyncio.Event() + runner._exit_reason = None + runner._exit_code = None + runner._running_agents = {} + runner._running_agents_ts = {} + runner._pending_messages = {} + runner._pending_approvals = {} + runner._background_tasks = set() + runner._draining = False + runner._restart_requested = False + runner._restart_task_started = False + runner._restart_detached = False + runner._restart_via_service = False + runner._restart_drain_timeout = 60.0 + runner._stop_task = None + runner._busy_input_mode = "interrupt" + runner._update_prompt_pending = {} + runner._voice_mode = {} + runner._update_runtime_status = MagicMock() + runner._queue_or_replace_pending_event = 
GatewayRunner._queue_or_replace_pending_event.__get__(runner, GatewayRunner) + runner._session_key_for_source = GatewayRunner._session_key_for_source.__get__(runner, GatewayRunner) + runner._handle_active_session_busy_message = GatewayRunner._handle_active_session_busy_message.__get__(runner, GatewayRunner) + runner._handle_restart_command = GatewayRunner._handle_restart_command.__get__(runner, GatewayRunner) + runner._status_action_label = GatewayRunner._status_action_label.__get__(runner, GatewayRunner) + runner._status_action_gerund = GatewayRunner._status_action_gerund.__get__(runner, GatewayRunner) + runner._queue_during_drain_enabled = GatewayRunner._queue_during_drain_enabled.__get__(runner, GatewayRunner) + runner._running_agent_count = GatewayRunner._running_agent_count.__get__(runner, GatewayRunner) + runner.request_restart = MagicMock(return_value=True) + runner._is_user_authorized = lambda _source: True + runner.hooks = MagicMock() + runner.hooks.emit = AsyncMock() + runner.pairing_store = MagicMock() + runner.session_store = MagicMock() + runner.delivery_router = MagicMock() + + adapter = RecordingAdapter() + adapter.set_message_handler(AsyncMock(return_value=None)) + adapter.set_busy_session_handler(runner._handle_active_session_busy_message) + runner.adapters = {Platform.TELEGRAM: adapter} + return runner, adapter + + +@pytest.mark.asyncio +async def test_restart_command_while_busy_requests_drain_without_interrupt(): + runner, _adapter = _make_runner() + event = MessageEvent(text="/restart", message_type=MessageType.TEXT, source=_source(), message_id="m1") + session_key = build_session_key(event.source) + running_agent = MagicMock() + runner._running_agents[session_key] = running_agent + + result = await runner._handle_message(event) + + assert result == "⏳ Draining 1 active agent(s) before restart..." 
+ running_agent.interrupt.assert_not_called() + runner.request_restart.assert_called_once_with(detached=True, via_service=False) + + +@pytest.mark.asyncio +async def test_drain_queue_mode_queues_follow_up_without_interrupt(): + runner, adapter = _make_runner() + runner._draining = True + runner._restart_requested = True + runner._busy_input_mode = "queue" + + event = MessageEvent(text="follow up", message_type=MessageType.TEXT, source=_source(), message_id="m2") + session_key = build_session_key(event.source) + adapter._active_sessions[session_key] = asyncio.Event() + + await adapter.handle_message(event) + + assert session_key in adapter._pending_messages + assert adapter._pending_messages[session_key].text == "follow up" + assert not adapter._active_sessions[session_key].is_set() + assert any("queued for the next turn" in message for message in adapter.sent) + + +@pytest.mark.asyncio +async def test_draining_rejects_new_session_messages(): + runner, _adapter = _make_runner() + runner._draining = True + runner._restart_requested = True + + event = MessageEvent(text="hello", message_type=MessageType.TEXT, source=_source("fresh"), message_id="m3") + + result = await runner._handle_message(event) + + assert result == "⏳ Gateway is restarting and is not accepting new work right now." 
diff --git a/tests/gateway/test_session_boundary_hooks.py b/tests/gateway/test_session_boundary_hooks.py index 31e02980a7..a556624363 100644 --- a/tests/gateway/test_session_boundary_hooks.py +++ b/tests/gateway/test_session_boundary_hooks.py @@ -127,6 +127,16 @@ async def test_shutdown_fires_finalize_for_active_agents(mock_invoke_hook): runner._shutdown_event = MagicMock() runner.adapters = {} runner._exit_reason = "test" + runner._exit_code = None + runner._draining = False + runner._restart_requested = False + runner._restart_task_started = False + runner._restart_detached = False + runner._restart_via_service = False + runner._restart_drain_timeout = 0.0 + runner._stop_task = None + runner._running_agents_ts = {} + runner._update_runtime_status = MagicMock() agent1 = MagicMock() agent1.session_id = "sess-a" diff --git a/tests/gateway/test_session_race_guard.py b/tests/gateway/test_session_race_guard.py index ff21cdef8c..7a4f6f1011 100644 --- a/tests/gateway/test_session_race_guard.py +++ b/tests/gateway/test_session_race_guard.py @@ -41,6 +41,15 @@ def _make_runner(): runner._pending_approvals = {} runner._voice_mode = {} runner._background_tasks = set() + runner._draining = False + runner._restart_requested = False + runner._restart_task_started = False + runner._restart_detached = False + runner._restart_via_service = False + runner._restart_drain_timeout = 0.0 + runner._stop_task = None + runner._exit_code = None + runner._update_runtime_status = MagicMock() runner._is_user_authorized = lambda _source: True runner.hooks = MagicMock() runner.hooks.emit = AsyncMock() diff --git a/tests/hermes_cli/test_gateway_service.py b/tests/hermes_cli/test_gateway_service.py index b32c7fe787..3586564e8e 100644 --- a/tests/hermes_cli/test_gateway_service.py +++ b/tests/hermes_cli/test_gateway_service.py @@ -74,7 +74,7 @@ class TestSystemdServiceRefresh: assert unit_path.read_text(encoding="utf-8") == "new unit\n" assert calls[:2] == [ ["systemctl", "--user", 
"daemon-reload"], - ["systemctl", "--user", "restart", gateway_cli.get_service_name()], + ["systemctl", "--user", "reload-or-restart", gateway_cli.get_service_name()], ] @@ -84,6 +84,8 @@ class TestGeneratedSystemdUnits: assert "ExecStart=" in unit assert "ExecStop=" not in unit + assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit + assert "RestartForceExitStatus=75" in unit assert "TimeoutStopSec=60" in unit def test_user_unit_includes_resolved_node_directory_in_path(self, monkeypatch): @@ -98,6 +100,8 @@ class TestGeneratedSystemdUnits: assert "ExecStart=" in unit assert "ExecStop=" not in unit + assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit + assert "RestartForceExitStatus=75" in unit assert "TimeoutStopSec=60" in unit assert "WantedBy=multi-user.target" in unit @@ -234,6 +238,31 @@ class TestLaunchdServiceRecovery: ["launchctl", "kickstart", target], ] + def test_launchd_restart_drains_running_gateway_before_kickstart(self, monkeypatch): + calls = [] + target = f"{gateway_cli._launchd_domain()}/{gateway_cli.get_launchd_label()}" + + monkeypatch.setattr(gateway_cli, "_get_restart_drain_timeout", lambda: 12.0) + monkeypatch.setattr(gateway_cli, "_wait_for_gateway_exit", lambda timeout, force_after=None: True) + monkeypatch.setattr(gateway_cli, "terminate_pid", lambda pid, force=False: calls.append(("term", pid, force))) + monkeypatch.setattr( + "gateway.status.get_running_pid", + lambda: 321, + ) + + def fake_run(cmd, check=False, **kwargs): + calls.append(cmd) + return SimpleNamespace(returncode=0, stdout="", stderr="") + + monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run) + + gateway_cli.launchd_restart() + + assert calls == [ + ("term", 321, False), + ["launchctl", "kickstart", "-k", target], + ] + def test_launchd_stop_uses_bootout_not_kill(self, monkeypatch): """launchd_stop must bootout the service so KeepAlive doesn't respawn it.""" label = gateway_cli.get_launchd_label()