diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index e30c4478ef..f0ee06f8ca 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -10,6 +10,8 @@ Uses discord.py library for: """ import asyncio +import hashlib +import json import logging import os import struct @@ -24,6 +26,9 @@ logger = logging.getLogger(__name__) VALID_THREAD_AUTO_ARCHIVE_MINUTES = {60, 1440, 4320, 10080} _DISCORD_COMMAND_SYNC_POLICIES = {"safe", "bulk", "off"} +_DISCORD_COMMAND_SYNC_STATE_FILE = "discord_command_sync_state.json" +_DISCORD_COMMAND_SYNC_MUTATION_INTERVAL_SECONDS = 4.5 +_DISCORD_COMMAND_SYNC_MAX_RATE_LIMIT_SLEEP_SECONDS = 30.0 try: import discord @@ -45,6 +50,7 @@ from gateway.config import Platform, PlatformConfig import re from gateway.platforms.helpers import MessageDeduplicator, ThreadParticipationTracker +from utils import atomic_json_write from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -825,6 +831,128 @@ class DiscordAdapter(BasePlatformAdapter): logger.info("[%s] Disconnected", self.name) + def _command_sync_state_path(self) -> _Path: + from hermes_constants import get_hermes_home + + return get_hermes_home() / _DISCORD_COMMAND_SYNC_STATE_FILE + + def _read_command_sync_state(self) -> dict: + try: + path = self._command_sync_state_path() + if not path.exists(): + return {} + data = json.loads(path.read_text(encoding="utf-8")) + except Exception: + return {} + return data if isinstance(data, dict) else {} + + def _write_command_sync_state(self, state: dict) -> None: + atomic_json_write( + self._command_sync_state_path(), + state, + indent=None, + separators=(",", ":"), + ) + + def _command_sync_state_key(self, app_id: Any) -> str: + return str(app_id or "unknown") + + def _desired_command_sync_fingerprint(self) -> str: + tree = self._client.tree if self._client else None + desired = [] + if tree is not None: + desired = [ + self._canonicalize_app_command_payload(command.to_dict(tree)) + for command in tree.get_commands() + ] + desired.sort(key=lambda item: (item.get("type", 1), item.get("name", ""))) + payload = json.dumps(desired, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + def _command_sync_skip_reason(self, app_id: Any, fingerprint: str) -> Optional[str]: + entry = self._read_command_sync_state().get(self._command_sync_state_key(app_id)) + if not isinstance(entry, dict): + return None + now = time.time() + retry_after_until = float(entry.get("retry_after_until") or 0) + if retry_after_until > now: + remaining = max(1, int(retry_after_until - now)) + return f"Discord asked us to wait before syncing slash commands; retry in {remaining}s" + if entry.get("fingerprint") == fingerprint and entry.get("last_success_at"): + return "same slash-command fingerprint already synced" + return None + + def _record_command_sync_attempt(self, app_id: Any, fingerprint: str) -> None: + state = self._read_command_sync_state() + state[self._command_sync_state_key(app_id)] = { + **( + state.get(self._command_sync_state_key(app_id)) + if isinstance(state.get(self._command_sync_state_key(app_id)), dict) + else {} + ), + "fingerprint": fingerprint, + "last_attempt_at": time.time(), + } + self._write_command_sync_state(state) + + def _record_command_sync_rate_limit(self, app_id: Any, fingerprint: str, retry_after: float) -> None: + retry_after = max(1.0, float(retry_after)) + state = self._read_command_sync_state() + state[self._command_sync_state_key(app_id)] = { + **( + state.get(self._command_sync_state_key(app_id)) + if isinstance(state.get(self._command_sync_state_key(app_id)), dict) + else {} + ), + "fingerprint": fingerprint, + "last_attempt_at": time.time(), + "retry_after_until": time.time() + retry_after, + "retry_after": retry_after, + } + self._write_command_sync_state(state) + + def _record_command_sync_success(self, app_id: Any, fingerprint: str, summary: dict) -> None: + state = self._read_command_sync_state() + state[self._command_sync_state_key(app_id)] = { + "fingerprint": fingerprint, + "last_attempt_at": time.time(), + "last_success_at": time.time(), + "summary": summary, + } + self._write_command_sync_state(state) + + @staticmethod + def _extract_discord_retry_after(exc: BaseException) -> Optional[float]: + value = getattr(exc, "retry_after", None) + if value is not None: + try: + return max(1.0, float(value)) + except (TypeError, ValueError): + return None + response = getattr(exc, "response", None) + headers = getattr(response, "headers", None) + if headers: + for key in ("Retry-After", "X-RateLimit-Reset-After"): + try: + raw = headers.get(key) + except Exception: + raw = None + if raw is None: + continue + try: + return max(1.0, float(raw)) + except (TypeError, ValueError): + continue + return None + + def _command_sync_mutation_interval_seconds(self) -> float: + return _DISCORD_COMMAND_SYNC_MUTATION_INTERVAL_SECONDS + + async def _sleep_between_command_sync_mutations(self) -> None: + interval = self._command_sync_mutation_interval_seconds() + if interval > 0: + await asyncio.sleep(interval) + async def _run_post_connect_initialization(self) -> None: """Finish non-critical startup work after Discord is connected.""" if not self._client: @@ -840,14 +968,42 @@ class DiscordAdapter(BasePlatformAdapter): logger.info("[%s] Synced %d slash command(s) via bulk tree sync", self.name, len(synced)) return - # Discord's per-app command-management bucket is ~5 writes / 20 s, - # so a mass-prune-plus-upsert reconcile (e.g. 77 orphans + 30 - # desired = 107 writes) takes several minutes of forced waits. - # A flat 30 s budget blew up reliably under bucket pressure and - # left slash commands broken for ~60 min until the bucket fully - # recovered. Use a wide ceiling; the cap still guards against a - # true hang. (#16713) - summary = await asyncio.wait_for(self._safe_sync_slash_commands(), timeout=600) + app_id = getattr(self._client, "application_id", None) or getattr(getattr(self._client, "user", None), "id", None) + fingerprint = self._desired_command_sync_fingerprint() + skip_reason = self._command_sync_skip_reason(app_id, fingerprint) + if skip_reason: + logger.info("[%s] Skipping Discord slash command sync: %s", self.name, skip_reason) + return + self._record_command_sync_attempt(app_id, fingerprint) + + http = getattr(self._client, "http", None) + has_ratelimit_timeout = http is not None and hasattr(http, "max_ratelimit_timeout") + previous_ratelimit_timeout = getattr(http, "max_ratelimit_timeout", None) if has_ratelimit_timeout else None + if has_ratelimit_timeout: + http.max_ratelimit_timeout = _DISCORD_COMMAND_SYNC_MAX_RATE_LIMIT_SLEEP_SECONDS + + try: + # Discord's per-app command-management bucket is small, and + # discord.py can otherwise sit inside one long retry sleep + # before surfacing the 429. Keep the whole sync bounded and + # persist Discord's retry-after when it refuses the batch. + summary = await asyncio.wait_for(self._safe_sync_slash_commands(), timeout=600) + except Exception as e: + retry_after = self._extract_discord_retry_after(e) + if retry_after is not None: + self._record_command_sync_rate_limit(app_id, fingerprint, retry_after) + logger.warning( + "[%s] Discord rate-limited slash command sync; retrying after %.0fs", + self.name, + retry_after, + ) + return + raise + finally: + if has_ratelimit_timeout: + http.max_ratelimit_timeout = previous_ratelimit_timeout + + self._record_command_sync_success(app_id, fingerprint, summary) logger.info( "[%s] Safely reconciled %d slash command(s): unchanged=%d updated=%d recreated=%d created=%d deleted=%d", self.name, @@ -1009,11 +1165,20 @@ class DiscordAdapter(BasePlatformAdapter): created = 0 deleted = 0 http = self._client.http + mutation_count = 0 + + async def mutate(call, *args): + nonlocal mutation_count + if mutation_count: + await self._sleep_between_command_sync_mutations() + result = await call(*args) + mutation_count += 1 + return result for key, desired in desired_by_key.items(): current = existing_by_key.pop(key, None) if current is None: - await http.upsert_global_command(app_id, desired) + await mutate(http.upsert_global_command, app_id, desired) created += 1 continue @@ -1025,16 +1190,16 @@ class DiscordAdapter(BasePlatformAdapter): continue if self._patchable_app_command_payload(current_existing_payload) == self._patchable_app_command_payload(desired): - await http.delete_global_command(app_id, current.id) - await http.upsert_global_command(app_id, desired) + await mutate(http.delete_global_command, app_id, current.id) + await mutate(http.upsert_global_command, app_id, desired) recreated += 1 continue - await http.edit_global_command(app_id, current.id, desired) + await mutate(http.edit_global_command, app_id, current.id, desired) updated += 1 for current in existing_by_key.values(): - await http.delete_global_command(app_id, current.id) + await mutate(http.delete_global_command, app_id, current.id) deleted += 1 return { diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index 547e8e03c0..232f8dac80 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -505,6 +505,7 @@ def _read_systemd_unit_properties( "SubState", "Result", "ExecMainStatus", + "MainPID", ), ) -> dict[str, str]: """Return selected ``systemctl show`` properties for the gateway unit.""" @@ -538,6 +539,41 @@ def _read_systemd_unit_properties( return parsed +def _systemd_main_pid_from_props(props: dict[str, str]) -> int | None: + try: + pid = int(props.get("MainPID", "0") or "0") + except (TypeError, ValueError): + return None + return pid if pid > 0 else None + + +def _systemd_main_pid(system: bool = False) -> int | None: + return _systemd_main_pid_from_props(_read_systemd_unit_properties(system=system)) + + +def _read_gateway_runtime_status() -> dict | None: + try: + from gateway.status import read_runtime_status + + state = read_runtime_status() + except Exception: + return None + return state if isinstance(state, dict) else None + + +def _gateway_runtime_status_for_pid(pid: int | None) -> dict | None: + if not pid: + return None + state = _read_gateway_runtime_status() + if not state: + return None + try: + state_pid = int(state.get("pid", 0) or 0) + except (TypeError, ValueError): + return None + return state if state_pid == pid else None + + def _wait_for_systemd_service_restart( *, system: bool = False, @@ -550,6 +586,7 @@ def _wait_for_systemd_service_restart( svc = get_service_name() scope_label = _service_scope_label(system).capitalize() deadline = time.time() + timeout + printed_runtime_wait = False while time.time() < deadline: props = _read_systemd_unit_properties(system=system) @@ -562,19 +599,32 @@ def _wait_for_systemd_service_restart( new_pid = get_running_pid() except Exception: new_pid = None + if not new_pid: + new_pid = _systemd_main_pid_from_props(props) if active_state == "active": if new_pid and (previous_pid is None or new_pid != previous_pid): - print(f"✓ {scope_label} service restarted (PID {new_pid})") - return True - if previous_pid is None: - print(f"✓ {scope_label} service restarted") - return True + runtime_state = _gateway_runtime_status_for_pid(new_pid) + gateway_state = (runtime_state or {}).get("gateway_state") + if gateway_state == "running": + print(f"✓ {scope_label} service restarted (PID {new_pid})") + return True + if gateway_state == "startup_failed": + reason = (runtime_state or {}).get("exit_reason") or "startup failed" + print(f"⚠ {scope_label} service process restarted (PID {new_pid}), but gateway startup failed: {reason}") + return False + if not printed_runtime_wait: + print(f"⏳ {scope_label} service process started (PID {new_pid}); waiting for gateway runtime...") + printed_runtime_wait = True if active_state == "activating" and sub_state == "auto-restart": time.sleep(1) continue + if _systemd_unit_is_start_limited(props): + _print_systemd_start_limit_wait(system=system) + return False + time.sleep(2) print( @@ -585,6 +635,46 @@ def _wait_for_systemd_service_restart( return False +def _systemd_unit_is_start_limited(props: dict[str, str]) -> bool: + result = props.get("Result", "").lower() + sub_state = props.get("SubState", "").lower() + return result == "start-limit-hit" or sub_state == "start-limit-hit" + + +def _systemd_error_indicates_start_limit(exc: subprocess.CalledProcessError) -> bool: + parts: list[str] = [] + for attr in ("stderr", "stdout", "output"): + value = getattr(exc, attr, None) + if not value: + continue + if isinstance(value, bytes): + value = value.decode(errors="replace") + parts.append(str(value)) + text = "\n".join(parts).lower() + return ( + "start-limit-hit" in text + or "start request repeated too quickly" in text + or "start-limit" in text + ) + + +def _systemd_service_is_start_limited(system: bool = False) -> bool: + return _systemd_unit_is_start_limited(_read_systemd_unit_properties(system=system)) + + +def _print_systemd_start_limit_wait(system: bool = False) -> None: + svc = get_service_name() + scope_label = _service_scope_label(system).capitalize() + scope_flag = " --system" if system else "" + systemctl_prefix = "systemctl " if system else "systemctl --user " + journal_prefix = "journalctl " if system else "journalctl --user " + print(f"⏳ {scope_label} service is temporarily rate-limited by systemd.") + print(" systemd is refusing another immediate start after repeated exits.") + print(f" Wait for the start-limit window to expire, then run: {'sudo ' if system else ''}hermes gateway restart{scope_flag}") + print(f" Or clear the failed state manually: {systemctl_prefix}reset-failed {svc}") + print(f" Check logs: {journal_prefix}-u {svc} -l --since '5 min ago'") + + def _recover_pending_systemd_restart(system: bool = False, previous_pid: int | None = None) -> bool: """Recover a planned service restart that is stuck in systemd state.""" props = _read_systemd_unit_properties(system=system) @@ -2135,41 +2225,52 @@ def systemd_restart(system: bool = False): refresh_systemd_unit_if_needed(system=system) from gateway.status import get_running_pid - pid = get_running_pid() - if pid is not None and _request_gateway_self_restart(pid): - import time + pid = get_running_pid() or _systemd_main_pid(system=system) + if pid is not None: scope_label = _service_scope_label(system).capitalize() svc = get_service_name() + drain_timeout = _get_restart_drain_timeout() - # Phase 1: wait for old process to exit (drain + shutdown) - print(f"⏳ {scope_label} service draining active work...") - deadline = time.time() + 90 - while time.time() < deadline: - try: - os.kill(pid, 0) - time.sleep(1) - except (ProcessLookupError, PermissionError): - break # old process is gone - else: - print(f"⚠ Old process (PID {pid}) still alive after 90s") + print(f"⏳ {scope_label} service restarting gracefully (PID {pid})...") + if _graceful_restart_via_sigusr1(pid, drain_timeout + 5): + # The gateway exits with code 75 for a planned service restart. + # RestartSec can otherwise delay the relaunch even though the + # operator asked for an immediate restart, so kick the unit once + # the old PID has exited and then wait for the replacement PID. + _run_systemctl( + ["reset-failed", svc], + system=system, + check=False, + timeout=30, + ) + _run_systemctl( + ["restart", svc], + system=system, + check=False, + timeout=90, + ) + if _wait_for_systemd_service_restart(system=system, previous_pid=pid): + return + if _systemd_service_is_start_limited(system=system): + return - # The gateway exits with code 75 for a planned service restart. - # systemd can sit in the RestartSec window or even wedge itself into a - # failed/rate-limited state if the operator asks for another restart in - # the middle of that handoff. Clear any stale failed state and kick the - # unit immediately so `hermes gateway restart` behaves idempotently. + print( + f"⚠ Graceful restart did not complete within {int(drain_timeout + 5)}s; " + "forcing a service restart..." + ) _run_systemctl( ["reset-failed", svc], system=system, check=False, timeout=30, ) - _run_systemctl( - ["start", svc], - system=system, - check=False, - timeout=90, - ) + try: + _run_systemctl(["restart", svc], system=system, check=True, timeout=90) + except subprocess.CalledProcessError as exc: + if _systemd_error_indicates_start_limit(exc) or _systemd_service_is_start_limited(system=system): + _print_systemd_start_limit_wait(system=system) + return + raise _wait_for_systemd_service_restart(system=system, previous_pid=pid) return @@ -2182,8 +2283,14 @@ def systemd_restart(system: bool = False): check=False, timeout=30, ) - _run_systemctl(["reload-or-restart", get_service_name()], system=system, check=True, timeout=90) - print(f"✓ {_service_scope_label(system).capitalize()} service restarted") + try: + _run_systemctl(["restart", get_service_name()], system=system, check=True, timeout=90) + except subprocess.CalledProcessError as exc: + if _systemd_error_indicates_start_limit(exc) or _systemd_service_is_start_limited(system=system): + _print_systemd_start_limit_wait(system=system) + return + raise + _wait_for_systemd_service_restart(system=system, previous_pid=pid) @@ -2255,6 +2362,10 @@ def systemd_status(deep: bool = False, system: bool = False, full: bool = False) result_code = unit_props.get("Result", "") if active_state == "activating" and sub_state == "auto-restart": print(" ⏳ Restart pending: systemd is waiting to relaunch the gateway") + elif _systemd_unit_is_start_limited(unit_props): + print(" ⏳ Restart pending: systemd is temporarily rate-limiting starts") + print(f" Run after the start-limit window expires: {'sudo ' if system else ''}hermes gateway restart{scope_flag}") + print(f" Or clear it manually: systemctl {'--user ' if not system else ''}reset-failed {get_service_name()}") elif active_state == "failed" and exec_main_status == str(GATEWAY_SERVICE_RESTART_EXIT_CODE): print(" ⚠ Planned restart is stuck in systemd failed state (exit 75)") print(f" Run: systemctl {'--user ' if not system else ''}reset-failed {get_service_name()} && {'sudo ' if system else ''}hermes gateway start{scope_flag}") diff --git a/tests/gateway/test_discord_connect.py b/tests/gateway/test_discord_connect.py index dd49e78e18..57b3791a05 100644 --- a/tests/gateway/test_discord_connect.py +++ b/tests/gateway/test_discord_connect.py @@ -1,4 +1,5 @@ import asyncio +import json import sys from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock @@ -70,6 +71,15 @@ import gateway.platforms.discord as discord_platform # noqa: E402 from gateway.platforms.discord import DiscordAdapter # noqa: E402 +@pytest.fixture(autouse=True) +def _speed_up_command_sync_mutation_pacing(monkeypatch): + monkeypatch.setattr( + DiscordAdapter, + "_command_sync_mutation_interval_seconds", + lambda self: 0.0, + ) + + class FakeTree: def __init__(self): self.sync = AsyncMock(return_value=[]) @@ -536,6 +546,136 @@ async def test_post_connect_initialization_skips_sync_when_policy_off(monkeypatc fake_tree.sync.assert_not_called() +@pytest.mark.asyncio +async def test_post_connect_initialization_skips_same_fingerprint_after_success(tmp_path, monkeypatch): + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="test-token")) + monkeypatch.setattr("hermes_constants.get_hermes_home", lambda: tmp_path) + + class _DesiredCommand: + def to_dict(self, tree): + return { + "name": "status", + "description": "Show Hermes status", + "type": 1, + "options": [], + } + + fake_tree = SimpleNamespace( + get_commands=lambda: [_DesiredCommand()], + fetch_commands=AsyncMock(return_value=[]), + ) + fake_http = SimpleNamespace( + upsert_global_command=AsyncMock(), + edit_global_command=AsyncMock(), + delete_global_command=AsyncMock(), + ) + adapter._client = SimpleNamespace( + tree=fake_tree, + http=fake_http, + application_id=999, + user=SimpleNamespace(id=999), + ) + + await adapter._run_post_connect_initialization() + await adapter._run_post_connect_initialization() + + fake_tree.fetch_commands.assert_awaited_once() + fake_http.upsert_global_command.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_post_connect_initialization_respects_discord_retry_after(tmp_path, monkeypatch): + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="test-token")) + monkeypatch.setattr("hermes_constants.get_hermes_home", lambda: tmp_path) + + class _DesiredCommand: + def to_dict(self, tree): + return { + "name": "status", + "description": "Show Hermes status", + "type": 1, + "options": [], + } + + adapter._client = SimpleNamespace( + tree=SimpleNamespace(get_commands=lambda: [_DesiredCommand()]), + application_id=999, + user=SimpleNamespace(id=999), + ) + class _DiscordRateLimit(RuntimeError): + retry_after = 123.0 + + sync = AsyncMock(side_effect=_DiscordRateLimit("discord rate limited")) + monkeypatch.setattr(adapter, "_safe_sync_slash_commands", sync) + + await adapter._run_post_connect_initialization() + await adapter._run_post_connect_initialization() + + sync.assert_awaited_once() + state = json.loads((tmp_path / discord_platform._DISCORD_COMMAND_SYNC_STATE_FILE).read_text()) + entry = state["999"] + assert entry["retry_after"] == 123.0 + assert entry["retry_after_until"] > entry["last_attempt_at"] + + +@pytest.mark.asyncio +async def test_safe_sync_slash_commands_paces_mutation_writes(monkeypatch): + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="test-token")) + monkeypatch.setattr( + DiscordAdapter, + "_command_sync_mutation_interval_seconds", + lambda self: 1.25, + ) + sleeps = [] + + async def fake_sleep(delay): + sleeps.append(delay) + + monkeypatch.setattr(discord_platform.asyncio, "sleep", fake_sleep) + + class _DesiredCommand: + def __init__(self, payload): + self._payload = payload + + def to_dict(self, tree): + assert tree is not None + return dict(self._payload) + + desired_one = { + "name": "status", + "description": "Show Hermes status", + "type": 1, + "options": [], + } + desired_two = { + "name": "debug", + "description": "Generate a debug report", + "type": 1, + "options": [], + } + fake_tree = SimpleNamespace( + get_commands=lambda: [_DesiredCommand(desired_one), _DesiredCommand(desired_two)], + fetch_commands=AsyncMock(return_value=[]), + ) + fake_http = SimpleNamespace( + upsert_global_command=AsyncMock(), + edit_global_command=AsyncMock(), + delete_global_command=AsyncMock(), + ) + adapter._client = SimpleNamespace( + tree=fake_tree, + http=fake_http, + application_id=999, + user=SimpleNamespace(id=999), + ) + + summary = await adapter._safe_sync_slash_commands() + + assert summary["created"] == 2 + assert fake_http.upsert_global_command.await_count == 2 + assert sleeps == [1.25] + + @pytest.mark.asyncio async def test_safe_sync_reads_permission_attrs_from_existing_command(): """Regression: AppCommand.to_dict() in discord.py does NOT include diff --git a/tests/hermes_cli/test_gateway_service.py b/tests/hermes_cli/test_gateway_service.py index b3d9036207..15968f798e 100644 --- a/tests/hermes_cli/test_gateway_service.py +++ b/tests/hermes_cli/test_gateway_service.py @@ -2,6 +2,7 @@ import os import pwd +import subprocess from pathlib import Path from types import SimpleNamespace @@ -90,6 +91,13 @@ class TestSystemdServiceRefresh: monkeypatch.setattr(gateway_cli, "generate_systemd_unit", lambda system=False, run_as_user=None: "new unit\n") calls = [] + monkeypatch.setattr("gateway.status.get_running_pid", lambda: None) + monkeypatch.setattr(gateway_cli, "_recover_pending_systemd_restart", lambda system=False, previous_pid=None: False) + monkeypatch.setattr( + gateway_cli, + "_wait_for_systemd_service_restart", + lambda system=False, previous_pid=None: calls.append(("wait", system, previous_pid)) or True, + ) def fake_run(cmd, check=True, **kwargs): calls.append(cmd) @@ -100,11 +108,12 @@ class TestSystemdServiceRefresh: gateway_cli.systemd_restart() assert unit_path.read_text(encoding="utf-8") == "new unit\n" - assert calls[:4] == [ + assert calls[:5] == [ ["systemctl", "--user", "daemon-reload"], - ["systemctl", "--user", "show", gateway_cli.get_service_name(), "--no-pager", "--property", "ActiveState,SubState,Result,ExecMainStatus"], + ["systemctl", "--user", "show", gateway_cli.get_service_name(), "--no-pager", "--property", "ActiveState,SubState,Result,ExecMainStatus,MainPID"], ["systemctl", "--user", "reset-failed", gateway_cli.get_service_name()], - ["systemctl", "--user", "reload-or-restart", gateway_cli.get_service_name()], + ["systemctl", "--user", "restart", gateway_cli.get_service_name()], + ("wait", False, None), ] def test_systemd_stop_marks_running_gateway_as_planned_stop(self, monkeypatch): @@ -611,62 +620,141 @@ class TestGatewayServiceDetection: assert gateway_cli._is_service_running() is False class TestGatewaySystemServiceRouting: - def test_systemd_restart_self_requests_graceful_restart_and_waits(self, monkeypatch, capsys): + def test_systemd_restart_gracefully_restarts_running_service_and_waits(self, monkeypatch, capsys): calls = [] monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False) monkeypatch.setattr(gateway_cli, "_require_service_installed", lambda action, system=False: None) monkeypatch.setattr(gateway_cli, "refresh_systemd_unit_if_needed", lambda system=False: calls.append(("refresh", system))) + monkeypatch.setattr(gateway_cli, "_get_restart_drain_timeout", lambda: 12.0) monkeypatch.setattr( "gateway.status.get_running_pid", lambda: 654, ) monkeypatch.setattr( gateway_cli, - "_request_gateway_self_restart", - lambda pid: calls.append(("self", pid)) or True, + "_graceful_restart_via_sigusr1", + lambda pid, timeout: calls.append(("graceful", pid, timeout)) or True, ) - # Simulate: old process dies immediately, new process becomes active - kill_call_count = [0] - def fake_kill(pid, sig): - kill_call_count[0] += 1 - if kill_call_count[0] >= 2: # first call checks, second = dead - raise ProcessLookupError() - monkeypatch.setattr(os, "kill", fake_kill) - - # Simulate systemctl reset-failed/start followed by an active unit - new_pid = [None] + # Simulate systemctl reset-failed/restart followed by an active unit. + # A plain start does not break systemd's auto-restart timer once the + # old gateway has exited with the planned restart code. def fake_subprocess_run(cmd, **kwargs): if "reset-failed" in cmd: calls.append(("reset-failed", cmd)) return SimpleNamespace(stdout="", returncode=0) - if "start" in cmd: - calls.append(("start", cmd)) + if "restart" in cmd: + calls.append(("restart", cmd)) return SimpleNamespace(stdout="", returncode=0) - if "show" in cmd: - new_pid[0] = 999 - return SimpleNamespace( - stdout="ActiveState=active\nSubState=running\nResult=success\nExecMainStatus=0\n", - returncode=0, - ) raise AssertionError(f"Unexpected systemctl call: {cmd}") monkeypatch.setattr(gateway_cli.subprocess, "run", fake_subprocess_run) - # get_running_pid returns new PID after restart - pid_calls = [0] - def fake_get_pid(): - pid_calls[0] += 1 - return 999 if pid_calls[0] > 1 else 654 - monkeypatch.setattr("gateway.status.get_running_pid", fake_get_pid) + monkeypatch.setattr( + gateway_cli, + "_wait_for_systemd_service_restart", + lambda system=False, previous_pid=None: calls.append(("wait", system, previous_pid)) or True, + ) gateway_cli.systemd_restart() - assert ("self", 654) in calls + assert ("graceful", 654, 17.0) in calls assert any(call[0] == "reset-failed" for call in calls) - assert any(call[0] == "start" for call in calls) + assert any(call[0] == "restart" for call in calls) + assert ("wait", False, 654) in calls out = capsys.readouterr().out.lower() - assert "restarted" in out + assert "restarting gracefully" in out + + def test_systemd_restart_uses_systemd_main_pid_when_pid_file_is_missing(self, monkeypatch, capsys): + calls = [] + + monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False) + monkeypatch.setattr(gateway_cli, "_require_service_installed", lambda action, system=False: None) + monkeypatch.setattr(gateway_cli, "refresh_systemd_unit_if_needed", lambda system=False: None) + monkeypatch.setattr(gateway_cli, "_get_restart_drain_timeout", lambda: 10.0) + monkeypatch.setattr("gateway.status.get_running_pid", lambda: None) + monkeypatch.setattr( + gateway_cli, + "_read_systemd_unit_properties", + lambda system=False: { + "ActiveState": "active", + "SubState": "running", + "Result": "success", + "ExecMainStatus": "0", + "MainPID": "777", + }, + ) + monkeypatch.setattr( + gateway_cli, + "_graceful_restart_via_sigusr1", + lambda pid, timeout: calls.append(("graceful", pid, timeout)) or True, + ) + monkeypatch.setattr(gateway_cli, "_run_systemctl", lambda args, **kwargs: calls.append(args) or SimpleNamespace(stdout="", returncode=0)) + monkeypatch.setattr( + gateway_cli, + "_wait_for_systemd_service_restart", + lambda system=False, previous_pid=None: calls.append(("wait", system, previous_pid)) or True, + ) + + gateway_cli.systemd_restart() + + assert ("graceful", 777, 15.0) in calls + assert ("wait", False, 777) in calls + assert "restarting gracefully (pid 777)" in capsys.readouterr().out.lower() + + def test_wait_for_systemd_restart_waits_for_runtime_running(self, monkeypatch, capsys): + monkeypatch.setattr( + gateway_cli, + "_read_systemd_unit_properties", + lambda system=False: { + "ActiveState": "active", + "SubState": "running", + "Result": "success", + "ExecMainStatus": "0", + "MainPID": "999", + }, + ) + monkeypatch.setattr("gateway.status.get_running_pid", lambda: None) + monkeypatch.setattr( + gateway_cli, + "_gateway_runtime_status_for_pid", + lambda pid: {"pid": pid, "gateway_state": "running"}, + ) + + assert gateway_cli._wait_for_systemd_service_restart(previous_pid=777, timeout=0.1) is True + assert "restarted (pid 999)" in capsys.readouterr().out.lower() + + def test_systemd_restart_reports_start_limit_hit(self, monkeypatch, capsys): + calls = [] + + monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False) + monkeypatch.setattr(gateway_cli, "_require_service_installed", lambda action, system=False: None) + monkeypatch.setattr(gateway_cli, "refresh_systemd_unit_if_needed", lambda system=False: None) + monkeypatch.setattr("gateway.status.get_running_pid", lambda: None) + monkeypatch.setattr(gateway_cli, "_recover_pending_systemd_restart", lambda system=False, previous_pid=None: False) + + def fake_run_systemctl(args, **kwargs): + calls.append(args) + if args[0] == "show": + return SimpleNamespace(stdout="ActiveState=inactive\nSubState=dead\nResult=success\nExecMainStatus=0\nMainPID=0\n", stderr="", returncode=0) + if args[0] == "reset-failed": + return SimpleNamespace(stdout="", stderr="", returncode=0) + if args[0] == "restart": + raise subprocess.CalledProcessError( + 1, + ["systemctl", "--user", *args], + stderr="Job failed. See result 'start-limit-hit'.", + ) + raise AssertionError(f"Unexpected args: {args}") + + monkeypatch.setattr(gateway_cli, "_run_systemctl", fake_run_systemctl) + + gateway_cli.systemd_restart() + + assert ["restart", gateway_cli.get_service_name()] in calls + out = capsys.readouterr().out.lower() + assert "rate-limited by systemd" in out + assert "reset-failed" in out def test_systemd_restart_recovers_failed_planned_restart(self, monkeypatch, capsys): monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False) @@ -711,6 +799,11 @@ class TestGatewaySystemServiceRouting: "gateway.status.get_running_pid", lambda: 999 if started["value"] else None, ) + monkeypatch.setattr( + gateway_cli, + "_gateway_runtime_status_for_pid", + lambda pid: {"pid": pid, "gateway_state": "running"}, + ) gateway_cli.systemd_restart()