diff --git a/gateway/hooks.py b/gateway/hooks.py index c50394b20..6a3250e24 100644 --- a/gateway/hooks.py +++ b/gateway/hooks.py @@ -3,11 +3,12 @@ Event Hook System A lightweight event-driven system that fires handlers at key lifecycle points. Hooks are discovered from ~/.hermes/hooks/ directories, each containing: - - HOOK.yaml (metadata: name, description, events list) + - HOOK.yaml (metadata: name, description, events list, optional startup_readiness) - handler.py (Python handler with async def handle(event_type, context)) Events: - gateway:startup -- Gateway process starts + - gateway:shutdown -- Gateway process is shutting down - session:start -- New session created (first message of a new session) - session:end -- Session ends (user ran /new or /reset) - session:reset -- Session reset completed (new session entry created) @@ -31,6 +32,26 @@ from hermes_cli.config import get_hermes_home HOOKS_DIR = get_hermes_home() / "hooks" +def _normalize_startup_readiness(hook_name: str, manifest: dict[str, Any]) -> Optional[dict[str, Any]]: + """Validate and normalize optional startup readiness metadata.""" + readiness = manifest.get("startup_readiness") + if readiness is None: + return None + if not isinstance(readiness, dict): + print(f"[hooks] Ignoring startup_readiness for {hook_name}: expected mapping", flush=True) + return None + + check_id = str(readiness.get("id", "")).strip() + if not check_id: + print(f"[hooks] Ignoring startup_readiness for {hook_name}: missing id", flush=True) + return None + + return { + "id": check_id, + "required": bool(readiness.get("required", True)), + } + + class HookRegistry: """ Discovers, loads, and fires event hooks. @@ -62,6 +83,7 @@ class HookRegistry: "description": "Run ~/.hermes/BOOT.md on gateway startup", "events": ["gateway:startup"], "path": "(builtin)", + "startup_readiness": None, }) except Exception as e: print(f"[hooks] Could not load built-in boot-md hook: {e}", flush=True) @@ -102,6 +124,7 @@ class HookRegistry: if not events: print(f"[hooks] Skipping {hook_name}: no events declared", flush=True) continue + startup_readiness = _normalize_startup_readiness(hook_name, manifest) # Dynamically load the handler module spec = importlib.util.spec_from_file_location( @@ -128,6 +151,7 @@ class HookRegistry: "description": manifest.get("description", ""), "events": events, "path": str(hook_dir), + "startup_readiness": startup_readiness, }) print(f"[hooks] Loaded hook '{hook_name}' for events: {events}", flush=True) diff --git a/gateway/run.py b/gateway/run.py index 5c3e5f13c..4f6b9da40 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1540,7 +1540,7 @@ class GatewayRunner: pass try: from gateway.status import write_runtime_status - write_runtime_status(gateway_state="starting", exit_reason=None) + write_runtime_status(gateway_state="starting", exit_reason=None, startup_checks={}) except Exception: pass @@ -1582,8 +1582,23 @@ class GatewayRunner: "or configure platform allowlists (e.g., TELEGRAM_ALLOWED_USERS=your_id)." ) + # Discover plugins before hooks so plugin-owned hook bundles can + # participate in this same startup cycle. + try: + from hermes_cli.plugins import discover_plugins + + discover_plugins() + except Exception as e: + logger.warning("Plugin discovery during gateway startup failed: %s", e) + # Discover and load event hooks self.hooks.discover_and_load() + try: + from gateway.status import reset_startup_checks + + reset_startup_checks(self.hooks.loaded_hooks) + except Exception as e: + logger.warning("Startup readiness initialization failed: %s", e) # Recover background processes from checkpoint (crash recovery) try: @@ -2104,6 +2119,11 @@ class GatewayRunner: logger.error("Failed to launch detached gateway restart: %s", e) self._finalize_shutdown_agents(active_agents) + await self.hooks.emit("gateway:shutdown", { + "restart": self._restart_requested, + "service_restart": self._restart_via_service, + "detached_restart": self._restart_detached, + }) for platform, adapter in list(self.adapters.items()): try: diff --git a/gateway/status.py b/gateway/status.py index becf9e8cb..eb4995485 100644 --- a/gateway/status.py +++ b/gateway/status.py @@ -27,6 +27,7 @@ _RUNTIME_STATUS_FILE = "gateway_state.json" _LOCKS_DIRNAME = "gateway-locks" _IS_WINDOWS = sys.platform == "win32" _UNSET = object() +_VALID_STARTUP_CHECK_STATES = {"pending", "ready", "failed"} def _get_pid_path() -> Path: @@ -162,11 +163,39 @@ def _build_runtime_status_record() -> dict[str, Any]: "restart_requested": False, "active_agents": 0, "platforms": {}, + "startup_checks": {}, "updated_at": _utc_now_iso(), }) return payload +def _normalize_startup_check_entries( + startup_checks: Optional[dict[str, Any]], +) -> dict[str, dict[str, Any]]: + """Normalize persisted startup readiness entries.""" + if not isinstance(startup_checks, dict): + return {} + + now = _utc_now_iso() + normalized: dict[str, dict[str, Any]] = {} + for raw_id, raw_payload in startup_checks.items(): + check_id = str(raw_id).strip() + if not check_id: + continue + payload = raw_payload if isinstance(raw_payload, dict) else {} + state = str(payload.get("state", "pending")).strip().lower() + if state not in _VALID_STARTUP_CHECK_STATES: + state = "pending" + normalized[check_id] = { + "state": state, + "required": bool(payload.get("required", True)), + "source": payload.get("source"), + "detail": payload.get("detail"), + "updated_at": payload.get("updated_at") or now, + } + return normalized + + def _read_json_file(path: Path) -> Optional[dict[str, Any]]: if not path.exists(): return None @@ -223,6 +252,7 @@ def write_runtime_status( exit_reason: Any = _UNSET, restart_requested: Any = _UNSET, active_agents: Any = _UNSET, + startup_checks: Any = _UNSET, platform: Any = _UNSET, platform_state: Any = _UNSET, error_code: Any = _UNSET, @@ -245,6 +275,8 @@ def write_runtime_status( payload["restart_requested"] = bool(restart_requested) if active_agents is not _UNSET: payload["active_agents"] = max(0, int(active_agents)) + if startup_checks is not _UNSET: + payload["startup_checks"] = _normalize_startup_check_entries(startup_checks) if platform is not _UNSET: platform_payload = payload["platforms"].get(platform, {}) @@ -262,7 +294,109 @@ def write_runtime_status( def read_runtime_status() -> Optional[dict[str, Any]]: """Read the persisted gateway runtime health/status information.""" - return _read_json_file(_get_runtime_status_path()) + payload = _read_json_file(_get_runtime_status_path()) + if payload is None: + return None + payload.setdefault("platforms", {}) + payload["startup_checks"] = _normalize_startup_check_entries(payload.get("startup_checks")) + return payload + + +def reset_startup_checks(checks: Optional[list[dict[str, Any]]] = None) -> dict[str, dict[str, Any]]: + """Replace persisted startup readiness checks for the current run.""" + normalized: dict[str, dict[str, Any]] = {} + now = _utc_now_iso() + + for hook in checks or []: + if not isinstance(hook, dict): + continue + readiness = hook.get("startup_readiness") + if not isinstance(readiness, dict): + continue + check_id = str(readiness.get("id", "")).strip() + if not check_id: + continue + normalized[check_id] = { + "state": "pending", + "required": bool(readiness.get("required", True)), + "source": hook.get("name"), + "detail": None, + "updated_at": now, + } + + write_runtime_status(startup_checks=normalized) + return normalized + + +def update_startup_check( + check_id: str, + state: str, + *, + detail: Any = _UNSET, + required: Any = _UNSET, + source: Any = _UNSET, +) -> dict[str, Any]: + """Update a single startup readiness check in the runtime status file.""" + normalized_id = str(check_id).strip() + if not normalized_id: + raise ValueError("startup readiness check id is required") + + normalized_state = str(state).strip().lower() + if normalized_state not in _VALID_STARTUP_CHECK_STATES: + raise ValueError(f"invalid startup readiness state: {state}") + + path = _get_runtime_status_path() + payload = _read_json_file(path) or _build_runtime_status_record() + checks = _normalize_startup_check_entries(payload.get("startup_checks")) + existing = checks.get(normalized_id, {}) + now = _utc_now_iso() + + checks[normalized_id] = { + "state": normalized_state, + "required": bool(existing.get("required", True) if required is _UNSET else required), + "source": existing.get("source") if source is _UNSET else source, + "detail": existing.get("detail") if detail is _UNSET else detail, + "updated_at": now, + } + + payload["startup_checks"] = checks + payload.setdefault("platforms", {}) + payload.setdefault("kind", _GATEWAY_KIND) + payload["pid"] = os.getpid() + payload["start_time"] = _get_process_start_time(os.getpid()) + payload["updated_at"] = now + _write_json_file(path, payload) + return checks[normalized_id] + + +def mark_startup_check_pending( + check_id: str, + *, + detail: Any = _UNSET, + required: Any = _UNSET, + source: Any = _UNSET, +) -> dict[str, Any]: + return update_startup_check(check_id, "pending", detail=detail, required=required, source=source) + + +def mark_startup_check_ready( + check_id: str, + *, + detail: Any = _UNSET, + required: Any = _UNSET, + source: Any = _UNSET, +) -> dict[str, Any]: + return update_startup_check(check_id, "ready", detail=detail, required=required, source=source) + + +def mark_startup_check_failed( + check_id: str, + *, + detail: Any = _UNSET, + required: Any = _UNSET, + source: Any = _UNSET, +) -> dict[str, Any]: + return update_startup_check(check_id, "failed", detail=detail, required=required, source=source) def remove_pid_file() -> None: diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index fe7bb9bd8..61278bed7 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -10,6 +10,7 @@ import shutil import signal import subprocess import sys +import time from pathlib import Path PROJECT_ROOT = Path(__file__).parent.parent.resolve() @@ -37,6 +38,10 @@ from hermes_cli.setup import ( from hermes_cli.colors import Colors, color +_SERVICE_READINESS_TIMEOUT = 30.0 +_SERVICE_READINESS_POLL_INTERVAL = 0.2 + + # ============================================================================= # Process Management (for manual gateway runs) # ============================================================================= @@ -1100,12 +1105,123 @@ def systemd_uninstall(system: bool = False): print(f"✓ {_service_scope_label(system).capitalize()} service uninstalled") +def _describe_startup_check(check_id: str, check: dict) -> str: + source = check.get("source") + detail = check.get("detail") + label = f"{check_id} ({source})" if source and source != check_id else check_id + return f"{label}: {detail}" if detail else label + + +def _classify_startup_checks(state: dict | None) -> tuple[list[str], list[str], list[str]]: + checks = (state or {}).get("startup_checks") or {} + pending_required: list[str] = [] + failed_required: list[str] = [] + optional_warnings: list[str] = [] + + if not isinstance(checks, dict): + return pending_required, failed_required, optional_warnings + + for check_id, raw_check in checks.items(): + check = raw_check if isinstance(raw_check, dict) else {} + label = _describe_startup_check(str(check_id), check) + check_state = str(check.get("state", "pending")).strip().lower() + required = bool(check.get("required", True)) + + if check_state == "ready": + continue + if required: + if check_state == "failed": + failed_required.append(label) + else: + pending_required.append(label) + else: + prefix = "failed" if check_state == "failed" else "pending" + optional_warnings.append(f"{prefix}: {label}") + + return pending_required, failed_required, optional_warnings + + +def _wait_for_service_readiness( + *, + action: str, + previous_pid: int | None = None, + timeout: float = _SERVICE_READINESS_TIMEOUT, + poll_interval: float = _SERVICE_READINESS_POLL_INTERVAL, +) -> list[str]: + from gateway.status import get_running_pid, read_runtime_status + + deadline = time.monotonic() + timeout + last_pending: list[str] = [] + + while time.monotonic() < deadline: + live_pid = get_running_pid() + if live_pid is None or (previous_pid is not None and live_pid == previous_pid): + time.sleep(poll_interval) + continue + + runtime = read_runtime_status() or {} + try: + runtime_pid = int(runtime.get("pid")) + except (TypeError, ValueError): + runtime_pid = None + if runtime_pid != live_pid: + time.sleep(poll_interval) + continue + + gateway_state = runtime.get("gateway_state") + pending_required, failed_required, optional_warnings = _classify_startup_checks(runtime) + last_pending = pending_required + + if gateway_state == "startup_failed": + reason = runtime.get("exit_reason") or f"gateway {action} failed during startup" + raise RuntimeError(reason) + if failed_required: + raise RuntimeError( + "required startup checks failed: " + "; ".join(failed_required) + ) + if gateway_state == "running" and not pending_required: + return optional_warnings + + time.sleep(poll_interval) + + if last_pending: + raise RuntimeError( + "timed out waiting for required startup checks: " + "; ".join(last_pending) + ) + if previous_pid is not None: + raise RuntimeError( + f"timed out waiting for gateway {action}; previous process is still active or no new runtime became ready" + ) + raise RuntimeError(f"timed out waiting for gateway {action} readiness") + + +def _await_service_ready_or_exit( + *, + action: str, + previous_pid: int | None = None, + timeout: float = _SERVICE_READINESS_TIMEOUT, +) -> None: + try: + optional_warnings = _wait_for_service_readiness( + action=action, + previous_pid=previous_pid, + timeout=timeout, + ) + except RuntimeError as exc: + print_error(f" Gateway {action} did not become ready: {exc}") + raise SystemExit(1) from exc + + for warning in optional_warnings: + print_warning(f" Optional startup check {warning}") + + def systemd_start(system: bool = False): system = _select_systemd_scope(system) if system: _require_root_for_system_service("start") refresh_systemd_unit_if_needed(system=system) _run_systemctl(["start", get_service_name()], system=system, check=True, timeout=30) + _await_service_ready_or_exit(action="start") print(f"✓ {_service_scope_label(system).capitalize()} service started") @@ -1128,9 +1244,11 @@ def systemd_restart(system: bool = False): pid = get_running_pid() if pid is not None and _request_gateway_self_restart(pid): - print(f"✓ {_service_scope_label(system).capitalize()} service restart requested") + _await_service_ready_or_exit(action="restart", previous_pid=pid) + print(f"✓ {_service_scope_label(system).capitalize()} service restarted") return _run_systemctl(["reload-or-restart", get_service_name()], system=system, check=True, timeout=90) + _await_service_ready_or_exit(action="restart", previous_pid=pid) print(f"✓ {_service_scope_label(system).capitalize()} service restarted") @@ -1389,6 +1507,7 @@ def launchd_start(): plist_path.write_text(generate_launchd_plist(), encoding="utf-8") subprocess.run(["launchctl", "bootstrap", _launchd_domain(), str(plist_path)], check=True, timeout=30) subprocess.run(["launchctl", "kickstart", f"{_launchd_domain()}/{label}"], check=True, timeout=30) + _await_service_ready_or_exit(action="start") print("✓ Service started") return @@ -1401,6 +1520,7 @@ def launchd_start(): print("↻ launchd job was unloaded; reloading service definition") subprocess.run(["launchctl", "bootstrap", _launchd_domain(), str(plist_path)], check=True, timeout=30) subprocess.run(["launchctl", "kickstart", f"{_launchd_domain()}/{label}"], check=True, timeout=30) + _await_service_ready_or_exit(action="start") print("✓ Service started") def launchd_stop(): @@ -1471,7 +1591,8 @@ def launchd_restart(): try: pid = get_running_pid() if pid is not None and _request_gateway_self_restart(pid): - print("✓ Service restart requested") + _await_service_ready_or_exit(action="restart", previous_pid=pid) + print("✓ Service restarted") return if pid is not None: try: @@ -1483,6 +1604,7 @@ def launchd_restart(): if not exited: print(f"⚠ Gateway drain timed out after {drain_timeout:.0f}s — forcing launchd restart") subprocess.run(["launchctl", "kickstart", "-k", target], check=True, timeout=90) + _await_service_ready_or_exit(action="restart", previous_pid=pid) print("✓ Service restarted") except subprocess.CalledProcessError as e: if e.returncode not in (3, 113): @@ -1492,6 +1614,7 @@ def launchd_restart(): plist_path = get_launchd_plist_path() subprocess.run(["launchctl", "bootstrap", _launchd_domain(), str(plist_path)], check=True, timeout=30) subprocess.run(["launchctl", "kickstart", target], check=True, timeout=30) + _await_service_ready_or_exit(action="restart", previous_pid=pid) print("✓ Service restarted") def launchd_status(deep: bool = False): diff --git a/tests/gateway/test_gateway_shutdown.py b/tests/gateway/test_gateway_shutdown.py index 4dc9919bc..5a0be040d 100644 --- a/tests/gateway/test_gateway_shutdown.py +++ b/tests/gateway/test_gateway_shutdown.py @@ -125,6 +125,25 @@ async def test_gateway_stop_service_restart_sets_named_exit_code(): assert runner._exit_code == GATEWAY_SERVICE_RESTART_EXIT_CODE +@pytest.mark.asyncio +async def test_gateway_stop_emits_shutdown_hook_after_drain(monkeypatch): + runner, adapter = make_restart_runner() + adapter.disconnect = AsyncMock() + runner.hooks.emit = AsyncMock() + + with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"): + await runner.stop(restart=True, service_restart=True) + + runner.hooks.emit.assert_awaited_once_with( + "gateway:shutdown", + { + "restart": True, + "service_restart": True, + "detached_restart": False, + }, + ) + + @pytest.mark.asyncio async def test_drain_active_agents_throttles_status_updates(): runner, _adapter = make_restart_runner() diff --git a/tests/gateway/test_hooks.py b/tests/gateway/test_hooks.py index 1301aebae..49d80afd7 100644 --- a/tests/gateway/test_hooks.py +++ b/tests/gateway/test_hooks.py @@ -9,7 +9,7 @@ import pytest from gateway.hooks import HookRegistry -def _create_hook(hooks_dir, hook_name, events, handler_code): +def _create_hook(hooks_dir, hook_name, events, handler_code, *, manifest_extra=""): """Helper to create a hook directory with HOOK.yaml and handler.py.""" hook_dir = hooks_dir / hook_name hook_dir.mkdir(parents=True) @@ -17,6 +17,7 @@ def _create_hook(hooks_dir, hook_name, events, handler_code): f"name: {hook_name}\n" f"description: Test hook\n" f"events: {events}\n" + f"{manifest_extra}" ) (hook_dir / "handler.py").write_text(handler_code) return hook_dir @@ -112,6 +113,24 @@ class TestDiscoverAndLoad: assert len(reg.loaded_hooks) == 2 + def test_preserves_optional_startup_readiness_metadata(self, tmp_path): + _create_hook( + tmp_path, + "ready-hook", + '["gateway:startup"]', + "def handle(e, c): pass\n", + manifest_extra="startup_readiness:\n id: beam-runtime\n required: false\n", + ) + + reg = HookRegistry() + with patch("gateway.hooks.HOOKS_DIR", tmp_path), _patch_no_builtins(reg): + reg.discover_and_load() + + assert reg.loaded_hooks[0]["startup_readiness"] == { + "id": "beam-runtime", + "required": False, + } + class TestEmit: @pytest.mark.asyncio diff --git a/tests/gateway/test_runner_startup_failures.py b/tests/gateway/test_runner_startup_failures.py index 77bd25ae2..f2dcbc861 100644 --- a/tests/gateway/test_runner_startup_failures.py +++ b/tests/gateway/test_runner_startup_failures.py @@ -132,6 +132,68 @@ async def test_runner_records_connected_platform_state_on_success(monkeypatch, t assert state["platforms"]["discord"]["error_message"] is None +@pytest.mark.asyncio +async def test_runner_discovers_plugins_before_loading_hooks(monkeypatch, tmp_path): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig(enabled=True, token="***") + }, + sessions_dir=tmp_path / "sessions", + ) + runner = GatewayRunner(config) + order: list[str] = [] + + monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _SuccessfulAdapter()) + monkeypatch.setattr("hermes_cli.plugins.discover_plugins", lambda: order.append("plugins")) + monkeypatch.setattr(runner.hooks, "discover_and_load", lambda: order.append("hooks")) + monkeypatch.setattr(runner.hooks, "emit", AsyncMock()) + + ok = await runner.start() + + assert ok is True + assert order == ["plugins", "hooks"] + + +@pytest.mark.asyncio +async def test_runner_initializes_startup_checks_before_gateway_startup_emit(monkeypatch, tmp_path): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig(enabled=True, token="***") + }, + sessions_dir=tmp_path / "sessions", + ) + runner = GatewayRunner(config) + + runner.hooks._loaded_hooks = [ + { + "name": "beam-runtime", + "events": ["gateway:startup"], + "path": str(tmp_path / "hook"), + "startup_readiness": { + "id": "beam-runtime", + "required": True, + }, + } + ] + monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _SuccessfulAdapter()) + monkeypatch.setattr("hermes_cli.plugins.discover_plugins", lambda: None) + monkeypatch.setattr(runner.hooks, "discover_and_load", lambda: None) + + async def _assert_checks(event_type, context): + state = read_runtime_status() + assert event_type == "gateway:startup" + assert state["startup_checks"]["beam-runtime"]["state"] == "pending" + assert state["startup_checks"]["beam-runtime"]["required"] is True + + monkeypatch.setattr(runner.hooks, "emit", _assert_checks) + + ok = await runner.start() + + assert ok is True + + @pytest.mark.asyncio async def test_start_gateway_verbosity_imports_redacting_formatter(monkeypatch, tmp_path): """Verbosity != None must not crash with NameError on RedactingFormatter (#8044).""" diff --git a/tests/gateway/test_status.py b/tests/gateway/test_status.py index 4b9675e72..d80c473af 100644 --- a/tests/gateway/test_status.py +++ b/tests/gateway/test_status.py @@ -132,6 +132,72 @@ class TestGatewayRuntimeStatus: assert payload["platforms"]["discord"]["error_code"] is None assert payload["platforms"]["discord"]["error_message"] is None + def test_reset_startup_checks_replaces_previous_run_entries(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + status.write_runtime_status( + gateway_state="running", + startup_checks={ + "old-check": { + "state": "ready", + "required": True, + "source": "old-hook", + "detail": None, + } + }, + ) + + status.reset_startup_checks([ + { + "name": "new-hook", + "startup_readiness": { + "id": "new-check", + "required": False, + }, + } + ]) + + payload = status.read_runtime_status() + assert set(payload["startup_checks"]) == {"new-check"} + assert payload["startup_checks"]["new-check"]["state"] == "pending" + assert payload["startup_checks"]["new-check"]["required"] is False + assert payload["startup_checks"]["new-check"]["source"] == "new-hook" + + def test_mark_startup_check_ready_persists_detail(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + status.reset_startup_checks([ + { + "name": "beam", + "startup_readiness": { + "id": "beam-runtime", + "required": True, + }, + } + ]) + + status.mark_startup_check_ready("beam-runtime", detail="ready for RPC") + + payload = status.read_runtime_status() + assert payload["startup_checks"]["beam-runtime"]["state"] == "ready" + assert payload["startup_checks"]["beam-runtime"]["detail"] == "ready for RPC" + + def test_mark_startup_check_failed_creates_missing_entry(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + status.mark_startup_check_failed( + "late-hook", + detail="startup hook crashed", + required=False, + source="late-hook", + ) + + payload = status.read_runtime_status() + assert payload["startup_checks"]["late-hook"]["state"] == "failed" + assert payload["startup_checks"]["late-hook"]["required"] is False + assert payload["startup_checks"]["late-hook"]["source"] == "late-hook" + assert payload["startup_checks"]["late-hook"]["detail"] == "startup hook crashed" + class TestTerminatePid: def test_force_uses_taskkill_on_windows(self, monkeypatch): diff --git a/tests/hermes_cli/test_gateway_service.py b/tests/hermes_cli/test_gateway_service.py index ec35aa997..52ed2709a 100644 --- a/tests/hermes_cli/test_gateway_service.py +++ b/tests/hermes_cli/test_gateway_service.py @@ -6,12 +6,21 @@ from pathlib import Path from types import SimpleNamespace import hermes_cli.gateway as gateway_cli +import pytest from gateway.restart import ( DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT, GATEWAY_SERVICE_RESTART_EXIT_CODE, ) +_REAL_AWAIT_SERVICE_READY = gateway_cli._await_service_ready_or_exit + + +@pytest.fixture(autouse=True) +def _stub_service_readiness(monkeypatch): + monkeypatch.setattr(gateway_cli, "_await_service_ready_or_exit", lambda **kwargs: None) + + class TestSystemdServiceRefresh: def test_systemd_install_repairs_outdated_unit_without_force(self, tmp_path, monkeypatch): unit_path = tmp_path / "hermes-gateway.service" @@ -82,6 +91,30 @@ class TestSystemdServiceRefresh: ["systemctl", "--user", "reload-or-restart", gateway_cli.get_service_name()], ] + def test_systemd_start_waits_for_readiness_before_reporting_success(self, monkeypatch): + calls = [] + + monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False) + monkeypatch.setattr(gateway_cli, "refresh_systemd_unit_if_needed", lambda system=False: calls.append(("refresh", system))) + monkeypatch.setattr( + gateway_cli, + "_run_systemctl", + lambda cmd, system=False, check=True, timeout=30, **kwargs: calls.append((tuple(cmd), system, timeout)), + ) + monkeypatch.setattr( + gateway_cli, + "_await_service_ready_or_exit", + lambda **kwargs: calls.append(("ready", kwargs)), + ) + + gateway_cli.systemd_start() + + assert calls == [ + ("refresh", False), + (("start", gateway_cli.get_service_name()), False, 30), + ("ready", {"action": "start"}), + ] + class TestGeneratedSystemdUnits: def test_user_unit_avoids_recursive_execstop_and_uses_extended_stop_timeout(self): @@ -268,6 +301,32 @@ class TestLaunchdServiceRecovery: ["launchctl", "kickstart", target], ] + def test_launchd_start_waits_for_readiness_before_reporting_success(self, tmp_path, monkeypatch): + plist_path = tmp_path / "ai.hermes.gateway.plist" + plist_path.write_text(gateway_cli.generate_launchd_plist(), encoding="utf-8") + label = gateway_cli.get_launchd_label() + calls = [] + + monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path) + monkeypatch.setattr(gateway_cli, "refresh_launchd_plist_if_needed", lambda: None) + monkeypatch.setattr( + gateway_cli.subprocess, + "run", + lambda cmd, check=False, **kwargs: calls.append(cmd) or SimpleNamespace(returncode=0, stdout="", stderr=""), + ) + monkeypatch.setattr( + gateway_cli, + "_await_service_ready_or_exit", + lambda **kwargs: calls.append(("ready", kwargs)), + ) + + gateway_cli.launchd_start() + + assert calls == [ + ["launchctl", "kickstart", f"{gateway_cli._launchd_domain()}/{label}"], + ("ready", {"action": "start"}), + ] + def test_launchd_restart_drains_running_gateway_before_kickstart(self, monkeypatch): calls = [] target = f"{gateway_cli._launchd_domain()}/{gateway_cli.get_launchd_label()}" @@ -315,7 +374,7 @@ class TestLaunchdServiceRecovery: gateway_cli.launchd_restart() assert calls == [("self", 321)] - assert "restart requested" in capsys.readouterr().out.lower() + assert "service restarted" in capsys.readouterr().out.lower() def test_launchd_stop_uses_bootout_not_kill(self, monkeypatch): """launchd_stop must bootout the service so KeepAlive doesn't respawn it.""" @@ -393,6 +452,109 @@ class TestLaunchdServiceRecovery: assert "not loaded" in output.lower() +class TestGatewayServiceReadiness: + def test_wait_for_service_readiness_accepts_running_gateway_without_checks(self, monkeypatch): + monkeypatch.setattr("gateway.status.get_running_pid", lambda: 123) + monkeypatch.setattr( + "gateway.status.read_runtime_status", + lambda: {"pid": 123, "gateway_state": "running", "startup_checks": {}}, + ) + + warnings = gateway_cli._wait_for_service_readiness(action="start", timeout=0.1, poll_interval=0.0) + + assert warnings == [] + + def test_wait_for_service_readiness_ignores_stale_runtime_state_until_pid_matches(self, monkeypatch): + runtime_states = iter( + [ + {"pid": 999, "gateway_state": "running", "startup_checks": {}}, + {"pid": 123, "gateway_state": "running", "startup_checks": {}}, + ] + ) + + monkeypatch.setattr("gateway.status.get_running_pid", lambda: 123) + monkeypatch.setattr("gateway.status.read_runtime_status", lambda: next(runtime_states)) + + warnings = gateway_cli._wait_for_service_readiness(action="start", timeout=0.1, poll_interval=0.0) + + assert warnings == [] + + def test_wait_for_service_readiness_returns_optional_pending_warnings(self, monkeypatch): + monkeypatch.setattr("gateway.status.get_running_pid", lambda: 123) + monkeypatch.setattr( + "gateway.status.read_runtime_status", + lambda: { + "pid": 123, + "gateway_state": "running", + "startup_checks": { + "optional-check": { + "state": "pending", + "required": False, + "source": "test-hook", + "detail": "still warming", + } + }, + }, + ) + + warnings = gateway_cli._wait_for_service_readiness(action="start", timeout=0.1, poll_interval=0.0) + + assert warnings == ["pending: optional-check (test-hook): still warming"] + + def test_wait_for_service_readiness_fails_when_required_check_fails(self, monkeypatch): + monkeypatch.setattr("gateway.status.get_running_pid", lambda: 123) + monkeypatch.setattr( + "gateway.status.read_runtime_status", + lambda: { + "pid": 123, + "gateway_state": "running", + "startup_checks": { + "beam-runtime": { + "state": "failed", + "required": True, + "source": "beam", + "detail": "RPC boot failed", + } + }, + }, + ) + + with pytest.raises(RuntimeError, match=r"required startup checks failed: beam-runtime \(beam\): RPC boot failed"): + gateway_cli._wait_for_service_readiness(action="start", timeout=0.1, poll_interval=0.0) + + def test_wait_for_service_readiness_times_out_on_pending_required_check(self, monkeypatch): + monkeypatch.setattr("gateway.status.get_running_pid", lambda: 123) + monkeypatch.setattr( + "gateway.status.read_runtime_status", + lambda: { + "pid": 123, + "gateway_state": "running", + "startup_checks": { + "beam-runtime": { + "state": "pending", + "required": True, + "source": "beam", + "detail": "waiting for runtime", + } + }, + }, + ) + + with pytest.raises(RuntimeError, match=r"timed out waiting for required startup checks: beam-runtime \(beam\): waiting for runtime"): + gateway_cli._wait_for_service_readiness(action="start", timeout=0.01, poll_interval=0.0) + + def test_await_service_ready_or_exit_raises_system_exit_when_not_ready(self, monkeypatch): + monkeypatch.setattr(gateway_cli, "_await_service_ready_or_exit", _REAL_AWAIT_SERVICE_READY) + monkeypatch.setattr( + gateway_cli, + "_wait_for_service_readiness", + lambda **kwargs: (_ for _ in ()).throw(RuntimeError("not ready")), + ) + + with pytest.raises(SystemExit, match="1"): + gateway_cli._await_service_ready_or_exit(action="start") + + class TestGatewayServiceDetection: def test_supports_systemd_services_requires_systemctl_binary(self, monkeypatch): monkeypatch.setattr(gateway_cli, "is_linux", lambda: True) @@ -475,7 +637,7 @@ class TestGatewaySystemServiceRouting: gateway_cli.systemd_restart() assert calls == [("refresh", False), ("self", 654)] - assert "restart requested" in capsys.readouterr().out.lower() + assert "service restarted" in capsys.readouterr().out.lower() def test_gateway_install_passes_system_flags(self, monkeypatch): monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)