diff --git a/tests/tools/test_notify_on_complete.py b/tests/tools/test_notify_on_complete.py index 888721906..8cf17bfbf 100644 --- a/tests/tools/test_notify_on_complete.py +++ b/tests/tools/test_notify_on_complete.py @@ -197,6 +197,26 @@ class TestCheckpointNotify: s = registry.get("proc_live") assert s.notify_on_complete is True + def test_recover_requeues_notify_watchers(self, registry, tmp_path): + checkpoint = tmp_path / "procs.json" + checkpoint.write_text(json.dumps([{ + "session_id": "proc_live", + "command": "sleep 999", + "pid": os.getpid(), + "task_id": "t1", + "session_key": "sk1", + "watcher_platform": "telegram", + "watcher_chat_id": "123", + "watcher_thread_id": "42", + "watcher_interval": 5, + "notify_on_complete": True, + }])) + with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint): + recovered = registry.recover_from_checkpoint() + assert recovered == 1 + assert len(registry.pending_watchers) == 1 + assert registry.pending_watchers[0]["notify_on_complete"] is True + def test_recover_defaults_false(self, registry, tmp_path): """Old checkpoint entries without the field default to False.""" checkpoint = tmp_path / "procs.json" diff --git a/tests/tools/test_process_registry.py b/tests/tools/test_process_registry.py index e6cfa40e7..44e3a1bd3 100644 --- a/tests/tools/test_process_registry.py +++ b/tests/tools/test_process_registry.py @@ -2,6 +2,9 @@ import json import os +import signal +import subprocess +import sys import time import pytest from pathlib import Path @@ -45,6 +48,23 @@ def _make_session( return s +def _spawn_python_sleep(seconds: float) -> subprocess.Popen: + """Spawn a portable short-lived Python sleep process.""" + return subprocess.Popen( + [sys.executable, "-c", f"import time; time.sleep({seconds})"], + ) + + +def _wait_until(predicate, timeout: float = 5.0, interval: float = 0.05) -> bool: + """Poll a predicate until it returns truthy or the timeout elapses.""" + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if predicate(): + return True + time.sleep(interval) + return False + + # ========================================================================= # Get / Poll # ========================================================================= @@ -349,6 +369,88 @@ class TestCheckpoint: assert recovered == 1 assert len(registry.pending_watchers) == 0 + def test_recovery_keeps_live_checkpoint_entries(self, registry, tmp_path): + checkpoint = tmp_path / "procs.json" + checkpoint.write_text(json.dumps([{ + "session_id": "proc_live", + "command": "sleep 999", + "pid": os.getpid(), + "task_id": "t1", + "session_key": "sk1", + }])) + + with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint): + recovered = registry.recover_from_checkpoint() + assert recovered == 1 + assert registry.get("proc_live") is not None + + data = json.loads(checkpoint.read_text()) + assert len(data) == 1 + assert data[0]["session_id"] == "proc_live" + assert data[0]["pid"] == os.getpid() + assert data != [] + + def test_recovery_skips_explicit_sandbox_backed_entries(self, registry, tmp_path): + checkpoint = tmp_path / "procs.json" + original = [{ + "session_id": "proc_remote", + "command": "sleep 999", + "pid": os.getpid(), + "task_id": "t1", + "pid_scope": "sandbox", + }] + checkpoint.write_text(json.dumps(original)) + + with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint): + recovered = registry.recover_from_checkpoint() + assert recovered == 0 + assert registry.get("proc_remote") is None + + data = json.loads(checkpoint.read_text()) + assert data == [] + + def test_detached_recovered_process_eventually_exits(self, registry, tmp_path): + proc = _spawn_python_sleep(0.4) + checkpoint = tmp_path / "procs.json" + checkpoint.write_text(json.dumps([{ + "session_id": "proc_live", + "command": "python -c 'import time; time.sleep(0.4)'", + "pid": proc.pid, + "task_id": "t1", + "session_key": "sk1", + }])) + + try: + with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint): + recovered = registry.recover_from_checkpoint() + assert recovered == 1 + + session = registry.get("proc_live") + assert session is not None + assert session.detached is True + + proc.wait(timeout=5) + + assert _wait_until( + lambda: registry.get("proc_live") is not None + and registry.get("proc_live").exited, + timeout=5, + ) + + poll_result = registry.poll("proc_live") + assert poll_result["status"] == "exited" + + wait_result = registry.wait("proc_live", timeout=1) + assert wait_result["status"] == "exited" + finally: + if proc.poll() is None: + proc.terminate() + try: + proc.wait(timeout=5) + except Exception: + proc.kill() + proc.wait(timeout=5) + # ========================================================================= # Kill process @@ -365,6 +467,27 @@ class TestKillProcess: result = registry.kill_process(s.id) assert result["status"] == "already_exited" + def test_kill_detached_session_uses_host_pid(self, registry): + s = _make_session(sid="proc_detached", command="sleep 999") + s.pid = 424242 + s.detached = True + registry._running[s.id] = s + + calls = [] + + def fake_kill(pid, sig): + calls.append((pid, sig)) + + try: + with patch("tools.process_registry.os.kill", side_effect=fake_kill): + result = registry.kill_process(s.id) + + assert result["status"] == "killed" + assert (424242, 0) in calls + assert (424242, signal.SIGTERM) in calls + finally: + registry._running.pop(s.id, None) + # ========================================================================= # Tool handler diff --git a/tools/process_registry.py b/tools/process_registry.py index 948f073ab..b935f49c3 100644 --- a/tools/process_registry.py +++ b/tools/process_registry.py @@ -76,6 +76,7 @@ class ProcessSession: output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS) max_output_chars: int = MAX_OUTPUT_CHARS detached: bool = False # True if recovered from crash (no pipe) + pid_scope: str = "host" # "host" for local/PTY PIDs, "sandbox" for env-local PIDs # Watcher/notification metadata (persisted for crash recovery) watcher_platform: str = "" watcher_chat_id: str = "" @@ -127,6 +128,48 @@ class ProcessRegistry: lines.pop(0) return "\n".join(lines) + @staticmethod + def _is_host_pid_alive(pid: Optional[int]) -> bool: + """Best-effort liveness check for host-visible PIDs.""" + if not pid: + return False + try: + os.kill(pid, 0) + return True + except (ProcessLookupError, PermissionError): + return False + + def _refresh_detached_session(self, session: Optional[ProcessSession]) -> Optional[ProcessSession]: + """Update recovered host-PID sessions when the underlying process has exited.""" + if session is None or session.exited or not session.detached or session.pid_scope != "host": + return session + + if self._is_host_pid_alive(session.pid): + return session + + with session._lock: + if session.exited: + return session + session.exited = True + # Recovered sessions no longer have a waitable handle, so the real + # exit code is unavailable once the original process object is gone. + session.exit_code = None + + self._move_to_finished(session) + return session + + @staticmethod + def _terminate_host_pid(pid: int) -> None: + """Terminate a host-visible PID without requiring the original process handle.""" + if _IS_WINDOWS: + os.kill(pid, signal.SIGTERM) + return + + try: + os.killpg(os.getpgid(pid), signal.SIGTERM) + except (OSError, ProcessLookupError, PermissionError): + os.kill(pid, signal.SIGTERM) + # ----- Spawn ----- def spawn_local( @@ -269,6 +312,7 @@ class ProcessRegistry: cwd=cwd, started_at=time.time(), env_ref=env, + pid_scope="sandbox", ) # Run the command in the sandbox with output capture @@ -439,7 +483,8 @@ class ProcessRegistry: def get(self, session_id: str) -> Optional[ProcessSession]: """Get a session by ID (running or finished).""" with self._lock: - return self._running.get(session_id) or self._finished.get(session_id) + session = self._running.get(session_id) or self._finished.get(session_id) + return self._refresh_detached_session(session) def poll(self, session_id: str) -> dict: """Check status and get new output for a background process.""" @@ -531,6 +576,7 @@ class ProcessRegistry: deadline = time.monotonic() + effective_timeout while time.monotonic() < deadline: + session = self._refresh_detached_session(session) if session.exited: result = { "status": "exited", @@ -596,6 +642,25 @@ class ProcessRegistry: elif session.env_ref and session.pid: # Non-local -- kill inside sandbox session.env_ref.execute(f"kill {session.pid} 2>/dev/null", timeout=5) + elif session.detached and session.pid_scope == "host" and session.pid: + if not self._is_host_pid_alive(session.pid): + with session._lock: + session.exited = True + session.exit_code = None + self._move_to_finished(session) + return { + "status": "already_exited", + "exit_code": session.exit_code, + } + self._terminate_host_pid(session.pid) + else: + return { + "status": "error", + "error": ( + "Recovered process cannot be killed after restart because " + "its original runtime handle is no longer available" + ), + } session.exited = True session.exit_code = -15 # SIGTERM self._move_to_finished(session) @@ -640,6 +705,8 @@ class ProcessRegistry: with self._lock: all_sessions = list(self._running.values()) + list(self._finished.values()) + all_sessions = [self._refresh_detached_session(s) for s in all_sessions] + if task_id: all_sessions = [s for s in all_sessions if s.task_id == task_id] @@ -666,6 +733,12 @@ class ProcessRegistry: def has_active_processes(self, task_id: str) -> bool: """Check if there are active (running) processes for a task_id.""" + with self._lock: + sessions = list(self._running.values()) + + for session in sessions: + self._refresh_detached_session(session) + with self._lock: return any( s.task_id == task_id and not s.exited @@ -674,6 +747,12 @@ class ProcessRegistry: def has_active_for_session(self, session_key: str) -> bool: """Check if there are active processes for a gateway session key.""" + with self._lock: + sessions = list(self._running.values()) + + for session in sessions: + self._refresh_detached_session(session) + with self._lock: return any( s.session_key == session_key and not s.exited @@ -727,6 +806,7 @@ class ProcessRegistry: "session_id": s.id, "command": s.command, "pid": s.pid, + "pid_scope": s.pid_scope, "cwd": s.cwd, "started_at": s.started_at, "task_id": s.task_id, @@ -764,13 +844,21 @@ class ProcessRegistry: if not pid: continue + pid_scope = entry.get("pid_scope", "host") + if pid_scope != "host": + # Sandbox-backed processes keep only in-sandbox PIDs in the + # checkpoint, which are not meaningful to the restarted host + # process once the original environment handle is gone. + logger.info( + "Skipping recovery for non-host process: %s (pid=%s, scope=%s)", + entry.get("command", "unknown")[:60], + pid, + pid_scope, + ) + continue + # Check if PID is still alive - alive = False - try: - os.kill(pid, 0) - alive = True - except (ProcessLookupError, PermissionError): - pass + alive = self._is_host_pid_alive(pid) if alive: session = ProcessSession( @@ -779,6 +867,7 @@ class ProcessRegistry: task_id=entry.get("task_id", ""), session_key=entry.get("session_key", ""), pid=pid, + pid_scope=pid_scope, cwd=entry.get("cwd"), started_at=entry.get("started_at", time.time()), detached=True, # Can't read output, but can report status + kill @@ -802,14 +891,10 @@ class ProcessRegistry: "platform": session.watcher_platform, "chat_id": session.watcher_chat_id, "thread_id": session.watcher_thread_id, + "notify_on_complete": session.notify_on_complete, }) - # Clear the checkpoint (will be rewritten as processes finish) - try: - from utils import atomic_json_write - atomic_json_write(CHECKPOINT_PATH, []) - except Exception as e: - logger.debug("Could not clear checkpoint file: %s", e, exc_info=True) + self._write_checkpoint() return recovered