fix(gateway): wait for systemd restart readiness

This commit is contained in:
helix4u 2026-05-05 21:55:58 -06:00 committed by Teknium
parent 3cdbf334d5
commit d797755a1c
4 changed files with 587 additions and 78 deletions

View file

@ -10,6 +10,8 @@ Uses discord.py library for:
""" """
import asyncio import asyncio
import hashlib
import json
import logging import logging
import os import os
import struct import struct
@ -24,6 +26,9 @@ logger = logging.getLogger(__name__)
VALID_THREAD_AUTO_ARCHIVE_MINUTES = {60, 1440, 4320, 10080} VALID_THREAD_AUTO_ARCHIVE_MINUTES = {60, 1440, 4320, 10080}
_DISCORD_COMMAND_SYNC_POLICIES = {"safe", "bulk", "off"} _DISCORD_COMMAND_SYNC_POLICIES = {"safe", "bulk", "off"}
_DISCORD_COMMAND_SYNC_STATE_FILE = "discord_command_sync_state.json"
_DISCORD_COMMAND_SYNC_MUTATION_INTERVAL_SECONDS = 4.5
_DISCORD_COMMAND_SYNC_MAX_RATE_LIMIT_SLEEP_SECONDS = 30.0
try: try:
import discord import discord
@ -45,6 +50,7 @@ from gateway.config import Platform, PlatformConfig
import re import re
from gateway.platforms.helpers import MessageDeduplicator, ThreadParticipationTracker from gateway.platforms.helpers import MessageDeduplicator, ThreadParticipationTracker
from utils import atomic_json_write
from gateway.platforms.base import ( from gateway.platforms.base import (
BasePlatformAdapter, BasePlatformAdapter,
MessageEvent, MessageEvent,
@ -825,6 +831,128 @@ class DiscordAdapter(BasePlatformAdapter):
logger.info("[%s] Disconnected", self.name) logger.info("[%s] Disconnected", self.name)
def _command_sync_state_path(self) -> _Path:
from hermes_constants import get_hermes_home
return get_hermes_home() / _DISCORD_COMMAND_SYNC_STATE_FILE
def _read_command_sync_state(self) -> dict:
try:
path = self._command_sync_state_path()
if not path.exists():
return {}
data = json.loads(path.read_text(encoding="utf-8"))
except Exception:
return {}
return data if isinstance(data, dict) else {}
def _write_command_sync_state(self, state: dict) -> None:
atomic_json_write(
self._command_sync_state_path(),
state,
indent=None,
separators=(",", ":"),
)
def _command_sync_state_key(self, app_id: Any) -> str:
return str(app_id or "unknown")
def _desired_command_sync_fingerprint(self) -> str:
tree = self._client.tree if self._client else None
desired = []
if tree is not None:
desired = [
self._canonicalize_app_command_payload(command.to_dict(tree))
for command in tree.get_commands()
]
desired.sort(key=lambda item: (item.get("type", 1), item.get("name", "")))
payload = json.dumps(desired, sort_keys=True, separators=(",", ":"))
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
def _command_sync_skip_reason(self, app_id: Any, fingerprint: str) -> Optional[str]:
entry = self._read_command_sync_state().get(self._command_sync_state_key(app_id))
if not isinstance(entry, dict):
return None
now = time.time()
retry_after_until = float(entry.get("retry_after_until") or 0)
if retry_after_until > now:
remaining = max(1, int(retry_after_until - now))
return f"Discord asked us to wait before syncing slash commands; retry in {remaining}s"
if entry.get("fingerprint") == fingerprint and entry.get("last_success_at"):
return "same slash-command fingerprint already synced"
return None
def _record_command_sync_attempt(self, app_id: Any, fingerprint: str) -> None:
state = self._read_command_sync_state()
state[self._command_sync_state_key(app_id)] = {
**(
state.get(self._command_sync_state_key(app_id))
if isinstance(state.get(self._command_sync_state_key(app_id)), dict)
else {}
),
"fingerprint": fingerprint,
"last_attempt_at": time.time(),
}
self._write_command_sync_state(state)
def _record_command_sync_rate_limit(self, app_id: Any, fingerprint: str, retry_after: float) -> None:
retry_after = max(1.0, float(retry_after))
state = self._read_command_sync_state()
state[self._command_sync_state_key(app_id)] = {
**(
state.get(self._command_sync_state_key(app_id))
if isinstance(state.get(self._command_sync_state_key(app_id)), dict)
else {}
),
"fingerprint": fingerprint,
"last_attempt_at": time.time(),
"retry_after_until": time.time() + retry_after,
"retry_after": retry_after,
}
self._write_command_sync_state(state)
def _record_command_sync_success(self, app_id: Any, fingerprint: str, summary: dict) -> None:
state = self._read_command_sync_state()
state[self._command_sync_state_key(app_id)] = {
"fingerprint": fingerprint,
"last_attempt_at": time.time(),
"last_success_at": time.time(),
"summary": summary,
}
self._write_command_sync_state(state)
@staticmethod
def _extract_discord_retry_after(exc: BaseException) -> Optional[float]:
value = getattr(exc, "retry_after", None)
if value is not None:
try:
return max(1.0, float(value))
except (TypeError, ValueError):
return None
response = getattr(exc, "response", None)
headers = getattr(response, "headers", None)
if headers:
for key in ("Retry-After", "X-RateLimit-Reset-After"):
try:
raw = headers.get(key)
except Exception:
raw = None
if raw is None:
continue
try:
return max(1.0, float(raw))
except (TypeError, ValueError):
continue
return None
def _command_sync_mutation_interval_seconds(self) -> float:
return _DISCORD_COMMAND_SYNC_MUTATION_INTERVAL_SECONDS
async def _sleep_between_command_sync_mutations(self) -> None:
interval = self._command_sync_mutation_interval_seconds()
if interval > 0:
await asyncio.sleep(interval)
async def _run_post_connect_initialization(self) -> None: async def _run_post_connect_initialization(self) -> None:
"""Finish non-critical startup work after Discord is connected.""" """Finish non-critical startup work after Discord is connected."""
if not self._client: if not self._client:
@ -840,14 +968,42 @@ class DiscordAdapter(BasePlatformAdapter):
logger.info("[%s] Synced %d slash command(s) via bulk tree sync", self.name, len(synced)) logger.info("[%s] Synced %d slash command(s) via bulk tree sync", self.name, len(synced))
return return
# Discord's per-app command-management bucket is ~5 writes / 20 s, app_id = getattr(self._client, "application_id", None) or getattr(getattr(self._client, "user", None), "id", None)
# so a mass-prune-plus-upsert reconcile (e.g. 77 orphans + 30 fingerprint = self._desired_command_sync_fingerprint()
# desired = 107 writes) takes several minutes of forced waits. skip_reason = self._command_sync_skip_reason(app_id, fingerprint)
# A flat 30 s budget blew up reliably under bucket pressure and if skip_reason:
# left slash commands broken for ~60 min until the bucket fully logger.info("[%s] Skipping Discord slash command sync: %s", self.name, skip_reason)
# recovered. Use a wide ceiling; the cap still guards against a return
# true hang. (#16713) self._record_command_sync_attempt(app_id, fingerprint)
summary = await asyncio.wait_for(self._safe_sync_slash_commands(), timeout=600)
http = getattr(self._client, "http", None)
has_ratelimit_timeout = http is not None and hasattr(http, "max_ratelimit_timeout")
previous_ratelimit_timeout = getattr(http, "max_ratelimit_timeout", None) if has_ratelimit_timeout else None
if has_ratelimit_timeout:
http.max_ratelimit_timeout = _DISCORD_COMMAND_SYNC_MAX_RATE_LIMIT_SLEEP_SECONDS
try:
# Discord's per-app command-management bucket is small, and
# discord.py can otherwise sit inside one long retry sleep
# before surfacing the 429. Keep the whole sync bounded and
# persist Discord's retry-after when it refuses the batch.
summary = await asyncio.wait_for(self._safe_sync_slash_commands(), timeout=600)
except Exception as e:
retry_after = self._extract_discord_retry_after(e)
if retry_after is not None:
self._record_command_sync_rate_limit(app_id, fingerprint, retry_after)
logger.warning(
"[%s] Discord rate-limited slash command sync; retrying after %.0fs",
self.name,
retry_after,
)
return
raise
finally:
if has_ratelimit_timeout:
http.max_ratelimit_timeout = previous_ratelimit_timeout
self._record_command_sync_success(app_id, fingerprint, summary)
logger.info( logger.info(
"[%s] Safely reconciled %d slash command(s): unchanged=%d updated=%d recreated=%d created=%d deleted=%d", "[%s] Safely reconciled %d slash command(s): unchanged=%d updated=%d recreated=%d created=%d deleted=%d",
self.name, self.name,
@ -1009,11 +1165,20 @@ class DiscordAdapter(BasePlatformAdapter):
created = 0 created = 0
deleted = 0 deleted = 0
http = self._client.http http = self._client.http
mutation_count = 0
async def mutate(call, *args):
nonlocal mutation_count
if mutation_count:
await self._sleep_between_command_sync_mutations()
result = await call(*args)
mutation_count += 1
return result
for key, desired in desired_by_key.items(): for key, desired in desired_by_key.items():
current = existing_by_key.pop(key, None) current = existing_by_key.pop(key, None)
if current is None: if current is None:
await http.upsert_global_command(app_id, desired) await mutate(http.upsert_global_command, app_id, desired)
created += 1 created += 1
continue continue
@ -1025,16 +1190,16 @@ class DiscordAdapter(BasePlatformAdapter):
continue continue
if self._patchable_app_command_payload(current_existing_payload) == self._patchable_app_command_payload(desired): if self._patchable_app_command_payload(current_existing_payload) == self._patchable_app_command_payload(desired):
await http.delete_global_command(app_id, current.id) await mutate(http.delete_global_command, app_id, current.id)
await http.upsert_global_command(app_id, desired) await mutate(http.upsert_global_command, app_id, desired)
recreated += 1 recreated += 1
continue continue
await http.edit_global_command(app_id, current.id, desired) await mutate(http.edit_global_command, app_id, current.id, desired)
updated += 1 updated += 1
for current in existing_by_key.values(): for current in existing_by_key.values():
await http.delete_global_command(app_id, current.id) await mutate(http.delete_global_command, app_id, current.id)
deleted += 1 deleted += 1
return { return {

View file

@ -505,6 +505,7 @@ def _read_systemd_unit_properties(
"SubState", "SubState",
"Result", "Result",
"ExecMainStatus", "ExecMainStatus",
"MainPID",
), ),
) -> dict[str, str]: ) -> dict[str, str]:
"""Return selected ``systemctl show`` properties for the gateway unit.""" """Return selected ``systemctl show`` properties for the gateway unit."""
@ -538,6 +539,41 @@ def _read_systemd_unit_properties(
return parsed return parsed
def _systemd_main_pid_from_props(props: dict[str, str]) -> int | None:
try:
pid = int(props.get("MainPID", "0") or "0")
except (TypeError, ValueError):
return None
return pid if pid > 0 else None
def _systemd_main_pid(system: bool = False) -> int | None:
return _systemd_main_pid_from_props(_read_systemd_unit_properties(system=system))
def _read_gateway_runtime_status() -> dict | None:
try:
from gateway.status import read_runtime_status
state = read_runtime_status()
except Exception:
return None
return state if isinstance(state, dict) else None
def _gateway_runtime_status_for_pid(pid: int | None) -> dict | None:
if not pid:
return None
state = _read_gateway_runtime_status()
if not state:
return None
try:
state_pid = int(state.get("pid", 0) or 0)
except (TypeError, ValueError):
return None
return state if state_pid == pid else None
def _wait_for_systemd_service_restart( def _wait_for_systemd_service_restart(
*, *,
system: bool = False, system: bool = False,
@ -550,6 +586,7 @@ def _wait_for_systemd_service_restart(
svc = get_service_name() svc = get_service_name()
scope_label = _service_scope_label(system).capitalize() scope_label = _service_scope_label(system).capitalize()
deadline = time.time() + timeout deadline = time.time() + timeout
printed_runtime_wait = False
while time.time() < deadline: while time.time() < deadline:
props = _read_systemd_unit_properties(system=system) props = _read_systemd_unit_properties(system=system)
@ -562,19 +599,32 @@ def _wait_for_systemd_service_restart(
new_pid = get_running_pid() new_pid = get_running_pid()
except Exception: except Exception:
new_pid = None new_pid = None
if not new_pid:
new_pid = _systemd_main_pid_from_props(props)
if active_state == "active": if active_state == "active":
if new_pid and (previous_pid is None or new_pid != previous_pid): if new_pid and (previous_pid is None or new_pid != previous_pid):
print(f"{scope_label} service restarted (PID {new_pid})") runtime_state = _gateway_runtime_status_for_pid(new_pid)
return True gateway_state = (runtime_state or {}).get("gateway_state")
if previous_pid is None: if gateway_state == "running":
print(f"{scope_label} service restarted") print(f"{scope_label} service restarted (PID {new_pid})")
return True return True
if gateway_state == "startup_failed":
reason = (runtime_state or {}).get("exit_reason") or "startup failed"
print(f"{scope_label} service process restarted (PID {new_pid}), but gateway startup failed: {reason}")
return False
if not printed_runtime_wait:
print(f"{scope_label} service process started (PID {new_pid}); waiting for gateway runtime...")
printed_runtime_wait = True
if active_state == "activating" and sub_state == "auto-restart": if active_state == "activating" and sub_state == "auto-restart":
time.sleep(1) time.sleep(1)
continue continue
if _systemd_unit_is_start_limited(props):
_print_systemd_start_limit_wait(system=system)
return False
time.sleep(2) time.sleep(2)
print( print(
@ -585,6 +635,46 @@ def _wait_for_systemd_service_restart(
return False return False
def _systemd_unit_is_start_limited(props: dict[str, str]) -> bool:
result = props.get("Result", "").lower()
sub_state = props.get("SubState", "").lower()
return result == "start-limit-hit" or sub_state == "start-limit-hit"
def _systemd_error_indicates_start_limit(exc: subprocess.CalledProcessError) -> bool:
parts: list[str] = []
for attr in ("stderr", "stdout", "output"):
value = getattr(exc, attr, None)
if not value:
continue
if isinstance(value, bytes):
value = value.decode(errors="replace")
parts.append(str(value))
text = "\n".join(parts).lower()
return (
"start-limit-hit" in text
or "start request repeated too quickly" in text
or "start-limit" in text
)
def _systemd_service_is_start_limited(system: bool = False) -> bool:
return _systemd_unit_is_start_limited(_read_systemd_unit_properties(system=system))
def _print_systemd_start_limit_wait(system: bool = False) -> None:
svc = get_service_name()
scope_label = _service_scope_label(system).capitalize()
scope_flag = " --system" if system else ""
systemctl_prefix = "systemctl " if system else "systemctl --user "
journal_prefix = "journalctl " if system else "journalctl --user "
print(f"{scope_label} service is temporarily rate-limited by systemd.")
print(" systemd is refusing another immediate start after repeated exits.")
print(f" Wait for the start-limit window to expire, then run: {'sudo ' if system else ''}hermes gateway restart{scope_flag}")
print(f" Or clear the failed state manually: {systemctl_prefix}reset-failed {svc}")
print(f" Check logs: {journal_prefix}-u {svc} -l --since '5 min ago'")
def _recover_pending_systemd_restart(system: bool = False, previous_pid: int | None = None) -> bool: def _recover_pending_systemd_restart(system: bool = False, previous_pid: int | None = None) -> bool:
"""Recover a planned service restart that is stuck in systemd state.""" """Recover a planned service restart that is stuck in systemd state."""
props = _read_systemd_unit_properties(system=system) props = _read_systemd_unit_properties(system=system)
@ -2135,41 +2225,52 @@ def systemd_restart(system: bool = False):
refresh_systemd_unit_if_needed(system=system) refresh_systemd_unit_if_needed(system=system)
from gateway.status import get_running_pid from gateway.status import get_running_pid
pid = get_running_pid() pid = get_running_pid() or _systemd_main_pid(system=system)
if pid is not None and _request_gateway_self_restart(pid): if pid is not None:
import time
scope_label = _service_scope_label(system).capitalize() scope_label = _service_scope_label(system).capitalize()
svc = get_service_name() svc = get_service_name()
drain_timeout = _get_restart_drain_timeout()
# Phase 1: wait for old process to exit (drain + shutdown) print(f"{scope_label} service restarting gracefully (PID {pid})...")
print(f"{scope_label} service draining active work...") if _graceful_restart_via_sigusr1(pid, drain_timeout + 5):
deadline = time.time() + 90 # The gateway exits with code 75 for a planned service restart.
while time.time() < deadline: # RestartSec can otherwise delay the relaunch even though the
try: # operator asked for an immediate restart, so kick the unit once
os.kill(pid, 0) # the old PID has exited and then wait for the replacement PID.
time.sleep(1) _run_systemctl(
except (ProcessLookupError, PermissionError): ["reset-failed", svc],
break # old process is gone system=system,
else: check=False,
print(f"⚠ Old process (PID {pid}) still alive after 90s") timeout=30,
)
_run_systemctl(
["restart", svc],
system=system,
check=False,
timeout=90,
)
if _wait_for_systemd_service_restart(system=system, previous_pid=pid):
return
if _systemd_service_is_start_limited(system=system):
return
# The gateway exits with code 75 for a planned service restart. print(
# systemd can sit in the RestartSec window or even wedge itself into a f"⚠ Graceful restart did not complete within {int(drain_timeout + 5)}s; "
# failed/rate-limited state if the operator asks for another restart in "forcing a service restart..."
# the middle of that handoff. Clear any stale failed state and kick the )
# unit immediately so `hermes gateway restart` behaves idempotently.
_run_systemctl( _run_systemctl(
["reset-failed", svc], ["reset-failed", svc],
system=system, system=system,
check=False, check=False,
timeout=30, timeout=30,
) )
_run_systemctl( try:
["start", svc], _run_systemctl(["restart", svc], system=system, check=True, timeout=90)
system=system, except subprocess.CalledProcessError as exc:
check=False, if _systemd_error_indicates_start_limit(exc) or _systemd_service_is_start_limited(system=system):
timeout=90, _print_systemd_start_limit_wait(system=system)
) return
raise
_wait_for_systemd_service_restart(system=system, previous_pid=pid) _wait_for_systemd_service_restart(system=system, previous_pid=pid)
return return
@ -2182,8 +2283,14 @@ def systemd_restart(system: bool = False):
check=False, check=False,
timeout=30, timeout=30,
) )
_run_systemctl(["reload-or-restart", get_service_name()], system=system, check=True, timeout=90) try:
print(f"{_service_scope_label(system).capitalize()} service restarted") _run_systemctl(["restart", get_service_name()], system=system, check=True, timeout=90)
except subprocess.CalledProcessError as exc:
if _systemd_error_indicates_start_limit(exc) or _systemd_service_is_start_limited(system=system):
_print_systemd_start_limit_wait(system=system)
return
raise
_wait_for_systemd_service_restart(system=system, previous_pid=pid)
@ -2255,6 +2362,10 @@ def systemd_status(deep: bool = False, system: bool = False, full: bool = False)
result_code = unit_props.get("Result", "") result_code = unit_props.get("Result", "")
if active_state == "activating" and sub_state == "auto-restart": if active_state == "activating" and sub_state == "auto-restart":
print(" ⏳ Restart pending: systemd is waiting to relaunch the gateway") print(" ⏳ Restart pending: systemd is waiting to relaunch the gateway")
elif _systemd_unit_is_start_limited(unit_props):
print(" ⏳ Restart pending: systemd is temporarily rate-limiting starts")
print(f" Run after the start-limit window expires: {'sudo ' if system else ''}hermes gateway restart{scope_flag}")
print(f" Or clear it manually: systemctl {'--user ' if not system else ''}reset-failed {get_service_name()}")
elif active_state == "failed" and exec_main_status == str(GATEWAY_SERVICE_RESTART_EXIT_CODE): elif active_state == "failed" and exec_main_status == str(GATEWAY_SERVICE_RESTART_EXIT_CODE):
print(" ⚠ Planned restart is stuck in systemd failed state (exit 75)") print(" ⚠ Planned restart is stuck in systemd failed state (exit 75)")
print(f" Run: systemctl {'--user ' if not system else ''}reset-failed {get_service_name()} && {'sudo ' if system else ''}hermes gateway start{scope_flag}") print(f" Run: systemctl {'--user ' if not system else ''}reset-failed {get_service_name()} && {'sudo ' if system else ''}hermes gateway start{scope_flag}")

View file

@ -1,4 +1,5 @@
import asyncio import asyncio
import json
import sys import sys
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
@ -70,6 +71,15 @@ import gateway.platforms.discord as discord_platform # noqa: E402
from gateway.platforms.discord import DiscordAdapter # noqa: E402 from gateway.platforms.discord import DiscordAdapter # noqa: E402
@pytest.fixture(autouse=True)
def _speed_up_command_sync_mutation_pacing(monkeypatch):
monkeypatch.setattr(
DiscordAdapter,
"_command_sync_mutation_interval_seconds",
lambda self: 0.0,
)
class FakeTree: class FakeTree:
def __init__(self): def __init__(self):
self.sync = AsyncMock(return_value=[]) self.sync = AsyncMock(return_value=[])
@ -536,6 +546,136 @@ async def test_post_connect_initialization_skips_sync_when_policy_off(monkeypatc
fake_tree.sync.assert_not_called() fake_tree.sync.assert_not_called()
@pytest.mark.asyncio
async def test_post_connect_initialization_skips_same_fingerprint_after_success(tmp_path, monkeypatch):
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="test-token"))
monkeypatch.setattr("hermes_constants.get_hermes_home", lambda: tmp_path)
class _DesiredCommand:
def to_dict(self, tree):
return {
"name": "status",
"description": "Show Hermes status",
"type": 1,
"options": [],
}
fake_tree = SimpleNamespace(
get_commands=lambda: [_DesiredCommand()],
fetch_commands=AsyncMock(return_value=[]),
)
fake_http = SimpleNamespace(
upsert_global_command=AsyncMock(),
edit_global_command=AsyncMock(),
delete_global_command=AsyncMock(),
)
adapter._client = SimpleNamespace(
tree=fake_tree,
http=fake_http,
application_id=999,
user=SimpleNamespace(id=999),
)
await adapter._run_post_connect_initialization()
await adapter._run_post_connect_initialization()
fake_tree.fetch_commands.assert_awaited_once()
fake_http.upsert_global_command.assert_awaited_once()
@pytest.mark.asyncio
async def test_post_connect_initialization_respects_discord_retry_after(tmp_path, monkeypatch):
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="test-token"))
monkeypatch.setattr("hermes_constants.get_hermes_home", lambda: tmp_path)
class _DesiredCommand:
def to_dict(self, tree):
return {
"name": "status",
"description": "Show Hermes status",
"type": 1,
"options": [],
}
adapter._client = SimpleNamespace(
tree=SimpleNamespace(get_commands=lambda: [_DesiredCommand()]),
application_id=999,
user=SimpleNamespace(id=999),
)
class _DiscordRateLimit(RuntimeError):
retry_after = 123.0
sync = AsyncMock(side_effect=_DiscordRateLimit("discord rate limited"))
monkeypatch.setattr(adapter, "_safe_sync_slash_commands", sync)
await adapter._run_post_connect_initialization()
await adapter._run_post_connect_initialization()
sync.assert_awaited_once()
state = json.loads((tmp_path / discord_platform._DISCORD_COMMAND_SYNC_STATE_FILE).read_text())
entry = state["999"]
assert entry["retry_after"] == 123.0
assert entry["retry_after_until"] > entry["last_attempt_at"]
@pytest.mark.asyncio
async def test_safe_sync_slash_commands_paces_mutation_writes(monkeypatch):
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="test-token"))
monkeypatch.setattr(
DiscordAdapter,
"_command_sync_mutation_interval_seconds",
lambda self: 1.25,
)
sleeps = []
async def fake_sleep(delay):
sleeps.append(delay)
monkeypatch.setattr(discord_platform.asyncio, "sleep", fake_sleep)
class _DesiredCommand:
def __init__(self, payload):
self._payload = payload
def to_dict(self, tree):
assert tree is not None
return dict(self._payload)
desired_one = {
"name": "status",
"description": "Show Hermes status",
"type": 1,
"options": [],
}
desired_two = {
"name": "debug",
"description": "Generate a debug report",
"type": 1,
"options": [],
}
fake_tree = SimpleNamespace(
get_commands=lambda: [_DesiredCommand(desired_one), _DesiredCommand(desired_two)],
fetch_commands=AsyncMock(return_value=[]),
)
fake_http = SimpleNamespace(
upsert_global_command=AsyncMock(),
edit_global_command=AsyncMock(),
delete_global_command=AsyncMock(),
)
adapter._client = SimpleNamespace(
tree=fake_tree,
http=fake_http,
application_id=999,
user=SimpleNamespace(id=999),
)
summary = await adapter._safe_sync_slash_commands()
assert summary["created"] == 2
assert fake_http.upsert_global_command.await_count == 2
assert sleeps == [1.25]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_safe_sync_reads_permission_attrs_from_existing_command(): async def test_safe_sync_reads_permission_attrs_from_existing_command():
"""Regression: AppCommand.to_dict() in discord.py does NOT include """Regression: AppCommand.to_dict() in discord.py does NOT include

View file

@ -2,6 +2,7 @@
import os import os
import pwd import pwd
import subprocess
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
@ -90,6 +91,13 @@ class TestSystemdServiceRefresh:
monkeypatch.setattr(gateway_cli, "generate_systemd_unit", lambda system=False, run_as_user=None: "new unit\n") monkeypatch.setattr(gateway_cli, "generate_systemd_unit", lambda system=False, run_as_user=None: "new unit\n")
calls = [] calls = []
monkeypatch.setattr("gateway.status.get_running_pid", lambda: None)
monkeypatch.setattr(gateway_cli, "_recover_pending_systemd_restart", lambda system=False, previous_pid=None: False)
monkeypatch.setattr(
gateway_cli,
"_wait_for_systemd_service_restart",
lambda system=False, previous_pid=None: calls.append(("wait", system, previous_pid)) or True,
)
def fake_run(cmd, check=True, **kwargs): def fake_run(cmd, check=True, **kwargs):
calls.append(cmd) calls.append(cmd)
@ -100,11 +108,12 @@ class TestSystemdServiceRefresh:
gateway_cli.systemd_restart() gateway_cli.systemd_restart()
assert unit_path.read_text(encoding="utf-8") == "new unit\n" assert unit_path.read_text(encoding="utf-8") == "new unit\n"
assert calls[:4] == [ assert calls[:5] == [
["systemctl", "--user", "daemon-reload"], ["systemctl", "--user", "daemon-reload"],
["systemctl", "--user", "show", gateway_cli.get_service_name(), "--no-pager", "--property", "ActiveState,SubState,Result,ExecMainStatus"], ["systemctl", "--user", "show", gateway_cli.get_service_name(), "--no-pager", "--property", "ActiveState,SubState,Result,ExecMainStatus,MainPID"],
["systemctl", "--user", "reset-failed", gateway_cli.get_service_name()], ["systemctl", "--user", "reset-failed", gateway_cli.get_service_name()],
["systemctl", "--user", "reload-or-restart", gateway_cli.get_service_name()], ["systemctl", "--user", "restart", gateway_cli.get_service_name()],
("wait", False, None),
] ]
def test_systemd_stop_marks_running_gateway_as_planned_stop(self, monkeypatch): def test_systemd_stop_marks_running_gateway_as_planned_stop(self, monkeypatch):
@ -611,62 +620,141 @@ class TestGatewayServiceDetection:
assert gateway_cli._is_service_running() is False assert gateway_cli._is_service_running() is False
class TestGatewaySystemServiceRouting: class TestGatewaySystemServiceRouting:
def test_systemd_restart_self_requests_graceful_restart_and_waits(self, monkeypatch, capsys): def test_systemd_restart_gracefully_restarts_running_service_and_waits(self, monkeypatch, capsys):
calls = [] calls = []
monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False) monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False)
monkeypatch.setattr(gateway_cli, "_require_service_installed", lambda action, system=False: None) monkeypatch.setattr(gateway_cli, "_require_service_installed", lambda action, system=False: None)
monkeypatch.setattr(gateway_cli, "refresh_systemd_unit_if_needed", lambda system=False: calls.append(("refresh", system))) monkeypatch.setattr(gateway_cli, "refresh_systemd_unit_if_needed", lambda system=False: calls.append(("refresh", system)))
monkeypatch.setattr(gateway_cli, "_get_restart_drain_timeout", lambda: 12.0)
monkeypatch.setattr( monkeypatch.setattr(
"gateway.status.get_running_pid", "gateway.status.get_running_pid",
lambda: 654, lambda: 654,
) )
monkeypatch.setattr( monkeypatch.setattr(
gateway_cli, gateway_cli,
"_request_gateway_self_restart", "_graceful_restart_via_sigusr1",
lambda pid: calls.append(("self", pid)) or True, lambda pid, timeout: calls.append(("graceful", pid, timeout)) or True,
) )
# Simulate: old process dies immediately, new process becomes active # Simulate systemctl reset-failed/restart followed by an active unit.
kill_call_count = [0] # A plain start does not break systemd's auto-restart timer once the
def fake_kill(pid, sig): # old gateway has exited with the planned restart code.
kill_call_count[0] += 1
if kill_call_count[0] >= 2: # first call checks, second = dead
raise ProcessLookupError()
monkeypatch.setattr(os, "kill", fake_kill)
# Simulate systemctl reset-failed/start followed by an active unit
new_pid = [None]
def fake_subprocess_run(cmd, **kwargs): def fake_subprocess_run(cmd, **kwargs):
if "reset-failed" in cmd: if "reset-failed" in cmd:
calls.append(("reset-failed", cmd)) calls.append(("reset-failed", cmd))
return SimpleNamespace(stdout="", returncode=0) return SimpleNamespace(stdout="", returncode=0)
if "start" in cmd: if "restart" in cmd:
calls.append(("start", cmd)) calls.append(("restart", cmd))
return SimpleNamespace(stdout="", returncode=0) return SimpleNamespace(stdout="", returncode=0)
if "show" in cmd:
new_pid[0] = 999
return SimpleNamespace(
stdout="ActiveState=active\nSubState=running\nResult=success\nExecMainStatus=0\n",
returncode=0,
)
raise AssertionError(f"Unexpected systemctl call: {cmd}") raise AssertionError(f"Unexpected systemctl call: {cmd}")
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_subprocess_run) monkeypatch.setattr(gateway_cli.subprocess, "run", fake_subprocess_run)
# get_running_pid returns new PID after restart monkeypatch.setattr(
pid_calls = [0] gateway_cli,
def fake_get_pid(): "_wait_for_systemd_service_restart",
pid_calls[0] += 1 lambda system=False, previous_pid=None: calls.append(("wait", system, previous_pid)) or True,
return 999 if pid_calls[0] > 1 else 654 )
monkeypatch.setattr("gateway.status.get_running_pid", fake_get_pid)
gateway_cli.systemd_restart() gateway_cli.systemd_restart()
assert ("self", 654) in calls assert ("graceful", 654, 17.0) in calls
assert any(call[0] == "reset-failed" for call in calls) assert any(call[0] == "reset-failed" for call in calls)
assert any(call[0] == "start" for call in calls) assert any(call[0] == "restart" for call in calls)
assert ("wait", False, 654) in calls
out = capsys.readouterr().out.lower() out = capsys.readouterr().out.lower()
assert "restarted" in out assert "restarting gracefully" in out
def test_systemd_restart_uses_systemd_main_pid_when_pid_file_is_missing(self, monkeypatch, capsys):
calls = []
monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False)
monkeypatch.setattr(gateway_cli, "_require_service_installed", lambda action, system=False: None)
monkeypatch.setattr(gateway_cli, "refresh_systemd_unit_if_needed", lambda system=False: None)
monkeypatch.setattr(gateway_cli, "_get_restart_drain_timeout", lambda: 10.0)
monkeypatch.setattr("gateway.status.get_running_pid", lambda: None)
monkeypatch.setattr(
gateway_cli,
"_read_systemd_unit_properties",
lambda system=False: {
"ActiveState": "active",
"SubState": "running",
"Result": "success",
"ExecMainStatus": "0",
"MainPID": "777",
},
)
monkeypatch.setattr(
gateway_cli,
"_graceful_restart_via_sigusr1",
lambda pid, timeout: calls.append(("graceful", pid, timeout)) or True,
)
monkeypatch.setattr(gateway_cli, "_run_systemctl", lambda args, **kwargs: calls.append(args) or SimpleNamespace(stdout="", returncode=0))
monkeypatch.setattr(
gateway_cli,
"_wait_for_systemd_service_restart",
lambda system=False, previous_pid=None: calls.append(("wait", system, previous_pid)) or True,
)
gateway_cli.systemd_restart()
assert ("graceful", 777, 15.0) in calls
assert ("wait", False, 777) in calls
assert "restarting gracefully (pid 777)" in capsys.readouterr().out.lower()
def test_wait_for_systemd_restart_waits_for_runtime_running(self, monkeypatch, capsys):
monkeypatch.setattr(
gateway_cli,
"_read_systemd_unit_properties",
lambda system=False: {
"ActiveState": "active",
"SubState": "running",
"Result": "success",
"ExecMainStatus": "0",
"MainPID": "999",
},
)
monkeypatch.setattr("gateway.status.get_running_pid", lambda: None)
monkeypatch.setattr(
gateway_cli,
"_gateway_runtime_status_for_pid",
lambda pid: {"pid": pid, "gateway_state": "running"},
)
assert gateway_cli._wait_for_systemd_service_restart(previous_pid=777, timeout=0.1) is True
assert "restarted (pid 999)" in capsys.readouterr().out.lower()
def test_systemd_restart_reports_start_limit_hit(self, monkeypatch, capsys):
calls = []
monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False)
monkeypatch.setattr(gateway_cli, "_require_service_installed", lambda action, system=False: None)
monkeypatch.setattr(gateway_cli, "refresh_systemd_unit_if_needed", lambda system=False: None)
monkeypatch.setattr("gateway.status.get_running_pid", lambda: None)
monkeypatch.setattr(gateway_cli, "_recover_pending_systemd_restart", lambda system=False, previous_pid=None: False)
def fake_run_systemctl(args, **kwargs):
calls.append(args)
if args[0] == "show":
return SimpleNamespace(stdout="ActiveState=inactive\nSubState=dead\nResult=success\nExecMainStatus=0\nMainPID=0\n", stderr="", returncode=0)
if args[0] == "reset-failed":
return SimpleNamespace(stdout="", stderr="", returncode=0)
if args[0] == "restart":
raise subprocess.CalledProcessError(
1,
["systemctl", "--user", *args],
stderr="Job failed. See result 'start-limit-hit'.",
)
raise AssertionError(f"Unexpected args: {args}")
monkeypatch.setattr(gateway_cli, "_run_systemctl", fake_run_systemctl)
gateway_cli.systemd_restart()
assert ["restart", gateway_cli.get_service_name()] in calls
out = capsys.readouterr().out.lower()
assert "rate-limited by systemd" in out
assert "reset-failed" in out
def test_systemd_restart_recovers_failed_planned_restart(self, monkeypatch, capsys): def test_systemd_restart_recovers_failed_planned_restart(self, monkeypatch, capsys):
monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False) monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False)
@ -711,6 +799,11 @@ class TestGatewaySystemServiceRouting:
"gateway.status.get_running_pid", "gateway.status.get_running_pid",
lambda: 999 if started["value"] else None, lambda: 999 if started["value"] else None,
) )
monkeypatch.setattr(
gateway_cli,
"_gateway_runtime_status_for_pid",
lambda pid: {"pid": pid, "gateway_state": "running"},
)
gateway_cli.systemd_restart() gateway_cli.systemd_restart()