diff --git a/tools/environments/local.py b/tools/environments/local.py index 1641d12a4d..5859d6e409 100644 --- a/tools/environments/local.py +++ b/tools/environments/local.py @@ -247,10 +247,10 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment): def _read_temp_files(self, *paths: str) -> list[str]: results = [] for path in paths: - try: + if os.path.exists(path): with open(path) as f: results.append(f.read()) - except OSError: + else: results.append("") return results @@ -262,15 +262,13 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment): ["pkill", "-P", str(self._shell_pid)], capture_output=True, timeout=5, ) - except (subprocess.TimeoutExpired, OSError, FileNotFoundError): + except (subprocess.TimeoutExpired, FileNotFoundError): pass def _cleanup_temp_files(self): for f in glob.glob(f"{self._temp_prefix}-*"): - try: + if os.path.exists(f): os.remove(f) - except OSError: - pass def _execute_oneshot(self, command: str, cwd: str = "", *, timeout: int | None = None, @@ -286,106 +284,87 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment): else: effective_stdin = stdin_data - try: - user_shell = _find_bash() - fenced_cmd = ( - f"printf '{_OUTPUT_FENCE}';" - f" {exec_command};" - f" __hermes_rc=$?;" - f" printf '{_OUTPUT_FENCE}';" - f" exit $__hermes_rc" - ) - run_env = _make_run_env(self.env) + user_shell = _find_bash() + fenced_cmd = ( + f"printf '{_OUTPUT_FENCE}';" + f" {exec_command};" + f" __hermes_rc=$?;" + f" printf '{_OUTPUT_FENCE}';" + f" exit $__hermes_rc" + ) + run_env = _make_run_env(self.env) - proc = subprocess.Popen( - [user_shell, "-lic", fenced_cmd], - text=True, - cwd=work_dir, - env=run_env, - encoding="utf-8", - errors="replace", - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL, - preexec_fn=None if _IS_WINDOWS else os.setsid, - ) - - if effective_stdin is not None: - def _write_stdin(): - try: - proc.stdin.write(effective_stdin) - proc.stdin.close() - except (BrokenPipeError, OSError): - pass - threading.Thread(target=_write_stdin, daemon=True).start() - - _output_chunks: list[str] = [] - - def _drain_stdout(): - try: - for line in proc.stdout: - _output_chunks.append(line) - except ValueError: - pass - finally: - try: - proc.stdout.close() - except Exception: - pass - - reader = threading.Thread(target=_drain_stdout, daemon=True) - reader.start() - deadline = time.monotonic() + effective_timeout - - while proc.poll() is None: - if is_interrupted(): - try: - if _IS_WINDOWS: - proc.terminate() - else: - pgid = os.getpgid(proc.pid) - os.killpg(pgid, signal.SIGTERM) - try: - proc.wait(timeout=1.0) - except subprocess.TimeoutExpired: - os.killpg(pgid, signal.SIGKILL) - except (ProcessLookupError, PermissionError): - proc.kill() - reader.join(timeout=2) - return { - "output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]", - "returncode": 130, - } - if time.monotonic() > deadline: - try: - if _IS_WINDOWS: - proc.terminate() - else: - os.killpg(os.getpgid(proc.pid), signal.SIGTERM) - except (ProcessLookupError, PermissionError): - proc.kill() - reader.join(timeout=2) - return self._timeout_result(effective_timeout) - time.sleep(0.2) - - reader.join(timeout=5) - output = _extract_fenced_output("".join(_output_chunks)) - return {"output": output, "returncode": proc.returncode} - - except Exception as e: - return {"output": f"Execution error: {str(e)}", "returncode": 1} - - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - if self.persistent: - return self._execute_persistent( - command, cwd, timeout=timeout, stdin_data=stdin_data, - ) - return self._execute_oneshot( - command, cwd, timeout=timeout, stdin_data=stdin_data, + proc = subprocess.Popen( + [user_shell, "-lic", fenced_cmd], + text=True, + cwd=work_dir, + env=run_env, + encoding="utf-8", + errors="replace", + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL, + preexec_fn=None if _IS_WINDOWS else os.setsid, ) - def cleanup(self): - if self.persistent: - self._cleanup_persistent_shell() + if effective_stdin is not None: + def _write_stdin(): + try: + proc.stdin.write(effective_stdin) + proc.stdin.close() + except (BrokenPipeError, OSError): + pass + threading.Thread(target=_write_stdin, daemon=True).start() + + _output_chunks: list[str] = [] + + def _drain_stdout(): + try: + for line in proc.stdout: + _output_chunks.append(line) + except ValueError: + pass + finally: + try: + proc.stdout.close() + except Exception: + pass + + reader = threading.Thread(target=_drain_stdout, daemon=True) + reader.start() + deadline = time.monotonic() + effective_timeout + + while proc.poll() is None: + if is_interrupted(): + try: + if _IS_WINDOWS: + proc.terminate() + else: + pgid = os.getpgid(proc.pid) + os.killpg(pgid, signal.SIGTERM) + try: + proc.wait(timeout=1.0) + except subprocess.TimeoutExpired: + os.killpg(pgid, signal.SIGKILL) + except (ProcessLookupError, PermissionError): + proc.kill() + reader.join(timeout=2) + return { + "output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]", + "returncode": 130, + } + if time.monotonic() > deadline: + try: + if _IS_WINDOWS: + proc.terminate() + else: + os.killpg(os.getpgid(proc.pid), signal.SIGTERM) + except (ProcessLookupError, PermissionError): + proc.kill() + reader.join(timeout=2) + return self._timeout_result(effective_timeout) + time.sleep(0.2) + + reader.join(timeout=5) + output = _extract_fenced_output("".join(_output_chunks)) + return {"output": output, "returncode": proc.returncode} diff --git a/tools/environments/persistent_shell.py b/tools/environments/persistent_shell.py index dd560a93b0..df1a78ef91 100644 --- a/tools/environments/persistent_shell.py +++ b/tools/environments/persistent_shell.py @@ -17,9 +17,11 @@ class PersistentShellMixin: """Mixin that adds persistent shell capability to any BaseEnvironment. Subclasses must implement ``_spawn_shell_process()``, ``_read_temp_files()``, - ``_kill_shell_children()``, and ``_execute_oneshot()`` (stdin fallback). + ``_kill_shell_children()``, ``_execute_oneshot()``, and ``_cleanup_temp_files()``. """ + persistent: bool + @abstractmethod def _spawn_shell_process(self) -> subprocess.Popen: ... @@ -43,15 +45,16 @@ class PersistentShellMixin: def _temp_prefix(self) -> str: return f"/tmp/hermes-persistent-{self._session_id}" + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + def _init_persistent_shell(self): self._shell_lock = threading.Lock() - self._session_id: str = "" self._shell_proc: subprocess.Popen | None = None self._shell_alive: bool = False self._shell_pid: int | None = None - self._start_persistent_shell() - def _start_persistent_shell(self): self._session_id = uuid.uuid4().hex[:12] p = self._temp_prefix self._pshell_stdout = f"{p}-stdout" @@ -98,6 +101,52 @@ class PersistentShellMixin: if reported_cwd: self.cwd = reported_cwd + def _cleanup_persistent_shell(self): + if self._shell_proc is None: + return + + if self._session_id: + self._cleanup_temp_files() + + try: + self._shell_proc.stdin.close() + except Exception: + pass + try: + self._shell_proc.terminate() + self._shell_proc.wait(timeout=3) + except subprocess.TimeoutExpired: + self._shell_proc.kill() + + self._shell_alive = False + self._shell_proc = None + + if hasattr(self, "_drain_thread") and self._drain_thread.is_alive(): + self._drain_thread.join(timeout=1.0) + + # ------------------------------------------------------------------ + # execute() / cleanup() — shared dispatcher, subclasses inherit + # ------------------------------------------------------------------ + + def execute(self, command: str, cwd: str = "", *, + timeout: int | None = None, + stdin_data: str | None = None) -> dict: + if self.persistent: + return self._execute_persistent( + command, cwd, timeout=timeout, stdin_data=stdin_data, + ) + return self._execute_oneshot( + command, cwd, timeout=timeout, stdin_data=stdin_data, + ) + + def cleanup(self): + if self.persistent: + self._cleanup_persistent_shell() + + # ------------------------------------------------------------------ + # Shell I/O + # ------------------------------------------------------------------ + def _drain_shell_output(self): try: for _ in self._shell_proc.stdout: @@ -130,12 +179,16 @@ class PersistentShellMixin: exit_code = 1 return output, exit_code, cwd.strip() + # ------------------------------------------------------------------ + # Execution + # ------------------------------------------------------------------ + def _execute_persistent(self, command: str, cwd: str, *, timeout: int | None = None, stdin_data: str | None = None) -> dict: if not self._shell_alive: logger.info("Persistent shell died, restarting...") - self._start_persistent_shell() + self._init_persistent_shell() exec_command, sudo_stdin = self._prepare_command(command) effective_timeout = timeout or self.timeout @@ -216,27 +269,3 @@ class PersistentShellMixin: if stderr.strip(): parts.append(stderr.rstrip("\n")) return "\n".join(parts) - - def _cleanup_persistent_shell(self): - if self._shell_proc is None: - return - - if self._session_id: - self._cleanup_temp_files() - - try: - self._shell_proc.stdin.close() - except Exception: - pass - try: - self._shell_proc.terminate() - self._shell_proc.wait(timeout=3) - except subprocess.TimeoutExpired: - self._shell_proc.kill() - - self._shell_alive = False - self._shell_proc = None - - if hasattr(self, "_drain_thread") and self._drain_thread.is_alive(): - self._drain_thread.join(timeout=1.0) - diff --git a/tools/environments/ssh.py b/tools/environments/ssh.py index 1bcc41ee76..c48b385093 100644 --- a/tools/environments/ssh.py +++ b/tools/environments/ssh.py @@ -130,11 +130,11 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): pass def _cleanup_temp_files(self): + cmd = self._build_ssh_command() + cmd.append(f"rm -f {self._temp_prefix}-*") try: - cmd = self._build_ssh_command() - cmd.append(f"rm -f {self._temp_prefix}-*") subprocess.run(cmd, capture_output=True, timeout=5) - except (OSError, subprocess.SubprocessError): + except (subprocess.TimeoutExpired, OSError): pass def _execute_oneshot(self, command: str, cwd: str = "", *, @@ -155,74 +155,58 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): cmd = self._build_ssh_command() cmd.append(wrapped) - try: - kwargs = self._build_run_kwargs(timeout, effective_stdin) - kwargs.pop("timeout", None) - _output_chunks = [] - proc = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL, - text=True, - ) - - if effective_stdin: - try: - proc.stdin.write(effective_stdin) - proc.stdin.close() - except Exception: - pass - - def _drain(): - try: - for line in proc.stdout: - _output_chunks.append(line) - except Exception: - pass - - reader = threading.Thread(target=_drain, daemon=True) - reader.start() - deadline = time.monotonic() + effective_timeout - - while proc.poll() is None: - if is_interrupted(): - proc.terminate() - try: - proc.wait(timeout=1) - except subprocess.TimeoutExpired: - proc.kill() - reader.join(timeout=2) - return { - "output": "".join(_output_chunks) + "\n[Command interrupted]", - "returncode": 130, - } - if time.monotonic() > deadline: - proc.kill() - reader.join(timeout=2) - return self._timeout_result(effective_timeout) - time.sleep(0.2) - - reader.join(timeout=5) - return {"output": "".join(_output_chunks), "returncode": proc.returncode} - - except Exception as e: - return {"output": f"SSH execution error: {str(e)}", "returncode": 1} - - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - if self.persistent: - return self._execute_persistent( - command, cwd, timeout=timeout, stdin_data=stdin_data - ) - return self._execute_oneshot( - command, cwd, timeout=timeout, stdin_data=stdin_data + kwargs = self._build_run_kwargs(timeout, effective_stdin) + kwargs.pop("timeout", None) + _output_chunks = [] + proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL, + text=True, ) + if effective_stdin: + try: + proc.stdin.write(effective_stdin) + proc.stdin.close() + except (BrokenPipeError, OSError): + pass + + def _drain(): + try: + for line in proc.stdout: + _output_chunks.append(line) + except Exception: + pass + + reader = threading.Thread(target=_drain, daemon=True) + reader.start() + deadline = time.monotonic() + effective_timeout + + while proc.poll() is None: + if is_interrupted(): + proc.terminate() + try: + proc.wait(timeout=1) + except subprocess.TimeoutExpired: + proc.kill() + reader.join(timeout=2) + return { + "output": "".join(_output_chunks) + "\n[Command interrupted]", + "returncode": 130, + } + if time.monotonic() > deadline: + proc.kill() + reader.join(timeout=2) + return self._timeout_result(effective_timeout) + time.sleep(0.2) + + reader.join(timeout=5) + return {"output": "".join(_output_chunks), "returncode": proc.returncode} + def cleanup(self): - if self.persistent: - self._cleanup_persistent_shell() + super().cleanup() if self.control_socket.exists(): try: cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",