fix: hermes update kills freshly-restarted gateway service

After restarting a service-managed gateway (systemd/launchd), the
stale-process sweep calls find_gateway_pids() which returns ALL gateway
PIDs via ps aux — including the one just spawned by the service manager.
The sweep kills it, leaving the user with a stopped gateway and a
confusing 'Restart manually' message.

Fix: add _get_service_pids() to query systemd MainPID and launchd PID
for active gateway services, then exclude those PIDs from the sweep.
Also add exclude_pids parameter to find_gateway_pids() and
kill_gateway_processes() so callers can skip known service-managed PIDs.

Adds 9 targeted tests covering:
- _get_service_pids() for systemd, launchd, empty, and zero-PID cases
- find_gateway_pids() exclude_pids filtering
- cmd_update integration: service PID not killed after restart
- cmd_update integration: manual PID killed while service PID preserved
This commit is contained in:
kshitijk4poor 2026-04-06 09:52:22 +05:30 committed by Teknium
parent 9c96f669a1
commit a2a9ad7431
3 changed files with 371 additions and 9 deletions

View file

@ -28,9 +28,101 @@ from hermes_cli.colors import Colors, color
# Process Management (for manual gateway runs)
# =============================================================================
def find_gateway_pids() -> list:
"""Find PIDs of running gateway processes."""
def _get_service_pids() -> set:
"""Return PIDs currently managed by systemd or launchd gateway services.
Used to avoid killing freshly-restarted service processes when sweeping
for stale manual gateway processes after a service restart.
"""
pids: set = set()
# --- systemd (Linux) ---
if is_linux():
try:
result = subprocess.run(
["systemctl", "--user", "list-units", "hermes-gateway*",
"--plain", "--no-legend", "--no-pager"],
capture_output=True, text=True, timeout=5,
)
for line in result.stdout.strip().splitlines():
parts = line.split()
if not parts or not parts[0].endswith(".service"):
continue
svc = parts[0]
try:
show = subprocess.run(
["systemctl", "--user", "show", svc,
"--property=MainPID", "--value"],
capture_output=True, text=True, timeout=5,
)
pid = int(show.stdout.strip())
if pid > 0:
pids.add(pid)
except (ValueError, subprocess.TimeoutExpired):
pass
except (FileNotFoundError, subprocess.TimeoutExpired):
pass
# Also check system scope
try:
result = subprocess.run(
["systemctl", "list-units", "hermes-gateway*",
"--plain", "--no-legend", "--no-pager"],
capture_output=True, text=True, timeout=5,
)
for line in result.stdout.strip().splitlines():
parts = line.split()
if not parts or not parts[0].endswith(".service"):
continue
svc = parts[0]
try:
show = subprocess.run(
["systemctl", "show", svc,
"--property=MainPID", "--value"],
capture_output=True, text=True, timeout=5,
)
pid = int(show.stdout.strip())
if pid > 0:
pids.add(pid)
except (ValueError, subprocess.TimeoutExpired):
pass
except (FileNotFoundError, subprocess.TimeoutExpired):
pass
# --- launchd (macOS) ---
if is_macos():
try:
from hermes_cli.gateway import get_launchd_label
result = subprocess.run(
["launchctl", "list", get_launchd_label()],
capture_output=True, text=True, timeout=5,
)
if result.returncode == 0:
# Output format: "PID\tStatus\tLabel" header then data line
for line in result.stdout.strip().splitlines():
parts = line.split()
if parts:
try:
pid = int(parts[0])
if pid > 0:
pids.add(pid)
except ValueError:
pass
except (FileNotFoundError, subprocess.TimeoutExpired):
pass
return pids
def find_gateway_pids(exclude_pids: set | None = None) -> list:
"""Find PIDs of running gateway processes.
Args:
exclude_pids: PIDs to exclude from the result (e.g. service-managed
PIDs that should not be killed during a stale-process sweep).
"""
pids = []
_exclude = exclude_pids or set()
patterns = [
"hermes_cli.main gateway",
"hermes_cli/main.py gateway",
@ -56,7 +148,7 @@ def find_gateway_pids() -> list:
if any(p in current_cmd for p in patterns):
try:
pid = int(pid_str)
if pid != os.getpid() and pid not in pids:
if pid != os.getpid() and pid not in pids and pid not in _exclude:
pids.append(pid)
except ValueError:
pass
@ -78,7 +170,7 @@ def find_gateway_pids() -> list:
if len(parts) > 1:
try:
pid = int(parts[1])
if pid not in pids:
if pid not in pids and pid not in _exclude:
pids.append(pid)
except ValueError:
continue
@ -89,9 +181,15 @@ def find_gateway_pids() -> list:
return pids
def kill_gateway_processes(force: bool = False) -> int:
"""Kill ALL running gateway processes (across all profiles). Returns count killed."""
pids = find_gateway_pids()
def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None) -> int:
"""Kill any running gateway processes. Returns count killed.
Args:
force: Use SIGKILL instead of SIGTERM.
exclude_pids: PIDs to skip (e.g. service-managed PIDs that were just
restarted and should not be killed).
"""
pids = find_gateway_pids(exclude_pids=exclude_pids)
killed = 0
for pid in pids: