From d684d7ee7e07c7690e299c9863e659147bd9d17f Mon Sep 17 00:00:00 2001 From: alt-glitch Date: Wed, 8 Apr 2026 13:38:04 -0700 Subject: [PATCH 01/49] feat(environments): unified spawn-per-call execution layer Replace dual execution model (PersistentShellMixin + per-backend oneshot) with spawn-per-call + session snapshot for all backends except ManagedModal. Core changes: - Every command spawns a fresh bash process; session snapshot (env vars, functions, aliases) captured at init and re-sourced before each command - CWD persists via file-based read (local) or in-band stdout markers (remote) - ProcessHandle protocol + _ThreadedProcessHandle adapter for SDK backends - cancel_fn wired for Modal (sandbox.terminate) and Daytona (sandbox.stop) - Shared utilities extracted: _pipe_stdin, _popen_bash, _load_json_store, _save_json_store, _file_mtime_key, _SYNC_INTERVAL_SECONDS - Rate-limited file sync unified in base _before_execute() with _sync_files() hook - execute_oneshot() removed; all 11 call sites in code_execution_tool.py migrated to execute() - Daytona timeout wrapper replaced with SDK-native timeout parameter - persistent_shell.py deleted (291 lines) Backend-specific: - Local: process-group kill via os.killpg, file-based CWD read - Docker: -e env flags only on init_session, not per-command - SSH: shlex.quote transport, ControlMaster connection reuse - Singularity: apptainer exec with instance://, no forced --pwd - Modal: _AsyncWorker + _ThreadedProcessHandle, cancel_fn -> sandbox.terminate - Daytona: SDK-level timeout (not shell wrapper), cancel_fn -> sandbox.stop - ManagedModal: unchanged (gateway owns execution); docstring added explaining why --- tests/tools/test_base_environment.py | 174 ++++++ tests/tools/test_file_tools_live.py | 146 +---- tests/tools/test_local_persistent.py | 164 ------ tests/tools/test_managed_modal_environment.py | 8 +- tests/tools/test_threaded_process_handle.py | 144 +++++ tools/code_execution_tool.py | 26 +- tools/environments/base.py | 556 ++++++++++++++++-- tools/environments/daytona.py | 181 ++---- tools/environments/docker.py | 175 ++---- tools/environments/local.py | 352 +++-------- tools/environments/managed_modal.py | 2 +- tools/environments/modal.py | 233 +++----- .../{modal_common.py => modal_utils.py} | 12 +- tools/environments/persistent_shell.py | 290 --------- tools/environments/singularity.py | 200 ++----- tools/environments/ssh.py | 188 +----- tools/terminal_tool.py | 5 +- 17 files changed, 1170 insertions(+), 1686 deletions(-) create mode 100644 tests/tools/test_base_environment.py delete mode 100644 tests/tools/test_local_persistent.py create mode 100644 tests/tools/test_threaded_process_handle.py rename tools/environments/{modal_common.py => modal_utils.py} (91%) delete mode 100644 tools/environments/persistent_shell.py diff --git a/tests/tools/test_base_environment.py b/tests/tools/test_base_environment.py new file mode 100644 index 00000000000..913ad0387c5 --- /dev/null +++ b/tests/tools/test_base_environment.py @@ -0,0 +1,174 @@ +"""Tests for BaseEnvironment unified execution model. + +Tests _wrap_command(), _extract_cwd_from_output(), _embed_stdin_heredoc(), +init_session() failure handling, and the CWD marker contract. +""" + +import uuid +from unittest.mock import MagicMock + +from tools.environments.base import BaseEnvironment, _cwd_marker + + +class _TestableEnv(BaseEnvironment): + """Concrete subclass for testing base class methods.""" + + def __init__(self, cwd="/tmp", timeout=10): + super().__init__(cwd=cwd, timeout=timeout) + + def _run_bash(self, cmd_string, *, login=False, timeout=120, stdin_data=None): + raise NotImplementedError("Use mock") + + def cleanup(self): + pass + + +class TestWrapCommand: + def test_basic_shape(self): + env = _TestableEnv() + env._snapshot_ready = True + wrapped = env._wrap_command("echo hello", "/tmp") + + assert "source" in wrapped + assert "cd /tmp" in wrapped or "cd '/tmp'" in wrapped + assert "eval 'echo hello'" in wrapped + assert "__hermes_ec=$?" in wrapped + assert "export -p >" in wrapped + assert "pwd -P >" in wrapped + assert env._cwd_marker in wrapped + assert "exit $__hermes_ec" in wrapped + + def test_no_snapshot_skips_source(self): + env = _TestableEnv() + env._snapshot_ready = False + wrapped = env._wrap_command("echo hello", "/tmp") + + assert "source" not in wrapped + + def test_single_quote_escaping(self): + env = _TestableEnv() + env._snapshot_ready = True + wrapped = env._wrap_command("echo 'hello world'", "/tmp") + + assert "eval 'echo '\\''hello world'\\'''" in wrapped + + def test_tilde_not_quoted(self): + env = _TestableEnv() + env._snapshot_ready = True + wrapped = env._wrap_command("ls", "~") + + assert "cd ~" in wrapped + assert "cd '~'" not in wrapped + + def test_cd_failure_exit_126(self): + env = _TestableEnv() + env._snapshot_ready = True + wrapped = env._wrap_command("ls", "/nonexistent") + + assert "exit 126" in wrapped + + +class TestExtractCwdFromOutput: + def test_happy_path(self): + env = _TestableEnv() + marker = env._cwd_marker + result = { + "output": f"hello\n{marker}/home/user{marker}\n", + } + env._extract_cwd_from_output(result) + + assert env.cwd == "/home/user" + assert marker not in result["output"] + + def test_missing_marker(self): + env = _TestableEnv() + result = {"output": "hello world\n"} + env._extract_cwd_from_output(result) + + assert env.cwd == "/tmp" # unchanged + + def test_marker_in_command_output(self): + """If the marker appears in command output AND as the real marker, + rfind grabs the last (real) one.""" + env = _TestableEnv() + marker = env._cwd_marker + result = { + "output": f"user typed {marker} in their output\nreal output\n{marker}/correct/path{marker}\n", + } + env._extract_cwd_from_output(result) + + assert env.cwd == "/correct/path" + + def test_output_cleaned(self): + env = _TestableEnv() + marker = env._cwd_marker + result = { + "output": f"hello\n{marker}/tmp{marker}\n", + } + env._extract_cwd_from_output(result) + + assert "hello" in result["output"] + assert marker not in result["output"] + + +class TestEmbedStdinHeredoc: + def test_heredoc_format(self): + result = BaseEnvironment._embed_stdin_heredoc("cat", "hello world") + + assert result.startswith("cat << '") + assert "hello world" in result + assert "HERMES_STDIN_" in result + + def test_unique_delimiter_each_call(self): + r1 = BaseEnvironment._embed_stdin_heredoc("cat", "data") + r2 = BaseEnvironment._embed_stdin_heredoc("cat", "data") + + # Extract delimiters + d1 = r1.split("'")[1] + d2 = r2.split("'")[1] + assert d1 != d2 # UUID-based, should be unique + + +class TestInitSessionFailure: + def test_snapshot_ready_false_on_failure(self): + env = _TestableEnv() + + def failing_run_bash(*args, **kwargs): + raise RuntimeError("bash not found") + + env._run_bash = failing_run_bash + env.init_session() + + assert env._snapshot_ready is False + + def test_login_flag_when_snapshot_not_ready(self): + """When _snapshot_ready=False, execute() should pass login=True to _run_bash.""" + env = _TestableEnv() + env._snapshot_ready = False + + calls = [] + def mock_run_bash(cmd, *, login=False, timeout=120, stdin_data=None): + calls.append({"login": login}) + # Return a mock process handle + mock = MagicMock() + mock.poll.return_value = 0 + mock.returncode = 0 + mock.stdout = iter([]) + return mock + + env._run_bash = mock_run_bash + env.execute("echo test") + + assert len(calls) == 1 + assert calls[0]["login"] is True + + +class TestCwdMarker: + def test_marker_contains_session_id(self): + env = _TestableEnv() + assert env._session_id in env._cwd_marker + + def test_unique_per_instance(self): + env1 = _TestableEnv() + env2 = _TestableEnv() + assert env1._cwd_marker != env2._cwd_marker diff --git a/tests/tools/test_file_tools_live.py b/tests/tools/test_file_tools_live.py index 4daf19a0305..6c3500eb88a 100644 --- a/tests/tools/test_file_tools_live.py +++ b/tests/tools/test_file_tools_live.py @@ -22,21 +22,19 @@ import pytest sys.path.insert(0, str(Path(__file__).resolve().parents[2])) -from tools.environments.local import ( - LocalEnvironment, - _clean_shell_noise, - _extract_fenced_output, - _OUTPUT_FENCE, - _SHELL_NOISE_SUBSTRINGS, -) +from tools.environments.local import LocalEnvironment from tools.file_operations import ShellFileOperations # ── Shared noise detection ─────────────────────────────────────────────── -# Every known shell noise pattern. If ANY of these appear in output that -# isn't explicitly expected, the test fails with a clear message. +# Known shell noise patterns that should never appear in command output. -_ALL_NOISE_PATTERNS = list(_SHELL_NOISE_SUBSTRINGS) + [ +_ALL_NOISE_PATTERNS = [ + "bash: cannot set terminal process group", + "bash: no job control in this shell", + "no job control in this shell", + "cannot set terminal process group", + "tcsetattr: Inappropriate ioctl for device", "bash: ", "Inappropriate ioctl", "Auto-suggestions:", @@ -88,134 +86,6 @@ def populated_dir(tmp_path): return tmp_path -# ── _clean_shell_noise unit tests ──────────────────────────────────────── - -class TestCleanShellNoise: - def test_single_noise_line(self): - output = "bash: no job control in this shell\nhello world\n" - result = _clean_shell_noise(output) - assert result == "hello world\n" - - def test_double_noise_lines(self): - output = ( - "bash: cannot set terminal process group (-1): Inappropriate ioctl for device\n" - "bash: no job control in this shell\n" - "actual output here\n" - ) - result = _clean_shell_noise(output) - assert result == "actual output here\n" - _assert_clean(result) - - def test_tcsetattr_noise(self): - output = ( - "bash: [12345: 2 (255)] tcsetattr: Inappropriate ioctl for device\n" - "real content\n" - ) - result = _clean_shell_noise(output) - assert result == "real content\n" - _assert_clean(result) - - def test_triple_noise_lines(self): - output = ( - "bash: cannot set terminal process group (-1): Inappropriate ioctl for device\n" - "bash: no job control in this shell\n" - "bash: [999: 2 (255)] tcsetattr: Inappropriate ioctl for device\n" - "clean\n" - ) - result = _clean_shell_noise(output) - assert result == "clean\n" - - def test_no_noise_untouched(self): - assert _clean_shell_noise("hello\nworld\n") == "hello\nworld\n" - - def test_empty_string(self): - assert _clean_shell_noise("") == "" - - def test_only_noise_produces_empty(self): - output = "bash: no job control in this shell\n" - result = _clean_shell_noise(output) - _assert_clean(result) - - def test_noise_in_middle_not_stripped(self): - """Noise in the middle is real output and should be preserved.""" - output = "real\nbash: no job control in this shell\nmore real\n" - result = _clean_shell_noise(output) - assert result == output - - def test_zsh_restored_session(self): - output = "Restored session: Mon Mar 2 22:16:54 +03 2026\nhello\n" - result = _clean_shell_noise(output) - assert result == "hello\n" - - def test_zsh_saving_session_trailing(self): - output = "hello\nSaving session...completed.\n" - result = _clean_shell_noise(output) - assert result == "hello\n" - - def test_zsh_oh_my_zsh_banner(self): - output = "Oh My Zsh on! | Auto-suggestions: press right\nhello\n" - result = _clean_shell_noise(output) - assert result == "hello\n" - - def test_zsh_full_noise_sandwich(self): - """Both leading and trailing zsh noise stripped.""" - output = ( - "Restored session: Mon Mar 2\n" - "command not found: docker\n" - "Oh My Zsh on!\n" - "actual output\n" - "Saving session...completed.\n" - ) - result = _clean_shell_noise(output) - assert result == "actual output\n" - - def test_last_login_stripped(self): - output = "Last login: Mon Mar 2 22:00:00 on ttys001\nhello\n" - result = _clean_shell_noise(output) - assert result == "hello\n" - - -# ── _extract_fenced_output unit tests ──────────────────────────────────── - -class TestExtractFencedOutput: - def test_normal_fenced_output(self): - raw = f"noise\n{_OUTPUT_FENCE}hello world\n{_OUTPUT_FENCE}more noise\n" - assert _extract_fenced_output(raw) == "hello world\n" - - def test_no_trailing_newline(self): - """printf output with no trailing newline is preserved.""" - raw = f"noise{_OUTPUT_FENCE}exact{_OUTPUT_FENCE}noise" - assert _extract_fenced_output(raw) == "exact" - - def test_no_fences_falls_back(self): - """Without fences, falls back to pattern-based cleaning.""" - raw = "bash: no job control in this shell\nhello\n" - result = _extract_fenced_output(raw) - assert result == "hello\n" - - def test_only_start_fence(self): - """Only start fence (e.g. user command called exit).""" - raw = f"noise{_OUTPUT_FENCE}hello\nSaving session...\n" - result = _extract_fenced_output(raw) - assert result == "hello\n" - - def test_user_outputs_fence_string(self): - """If user command outputs the fence marker, it is preserved.""" - raw = f"noise{_OUTPUT_FENCE}{_OUTPUT_FENCE}real\n{_OUTPUT_FENCE}noise" - result = _extract_fenced_output(raw) - # first fence -> last fence captures the middle including user's fence - assert _OUTPUT_FENCE in result - assert "real\n" in result - - def test_empty_command_output(self): - raw = f"noise{_OUTPUT_FENCE}{_OUTPUT_FENCE}noise" - assert _extract_fenced_output(raw) == "" - - def test_multiline_output(self): - raw = f"noise\n{_OUTPUT_FENCE}line1\nline2\nline3\n{_OUTPUT_FENCE}noise\n" - assert _extract_fenced_output(raw) == "line1\nline2\nline3\n" - - # ── LocalEnvironment.execute() ─────────────────────────────────────────── class TestLocalEnvironmentExecute: diff --git a/tests/tools/test_local_persistent.py b/tests/tools/test_local_persistent.py deleted file mode 100644 index 5b9ce2e2380..00000000000 --- a/tests/tools/test_local_persistent.py +++ /dev/null @@ -1,164 +0,0 @@ -"""Tests for the local persistent shell backend.""" - -import glob as glob_mod - -import pytest - -from tools.environments.local import LocalEnvironment -from tools.environments.persistent_shell import PersistentShellMixin - - -class TestLocalConfig: - def test_local_persistent_default_false(self, monkeypatch): - monkeypatch.delenv("TERMINAL_LOCAL_PERSISTENT", raising=False) - from tools.terminal_tool import _get_env_config - assert _get_env_config()["local_persistent"] is False - - def test_local_persistent_true(self, monkeypatch): - monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "true") - from tools.terminal_tool import _get_env_config - assert _get_env_config()["local_persistent"] is True - - def test_local_persistent_yes(self, monkeypatch): - monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "yes") - from tools.terminal_tool import _get_env_config - assert _get_env_config()["local_persistent"] is True - - -class TestMergeOutput: - def test_stdout_only(self): - assert PersistentShellMixin._merge_output("out", "") == "out" - - def test_stderr_only(self): - assert PersistentShellMixin._merge_output("", "err") == "err" - - def test_both(self): - assert PersistentShellMixin._merge_output("out", "err") == "out\nerr" - - def test_empty(self): - assert PersistentShellMixin._merge_output("", "") == "" - - def test_strips_trailing_newlines(self): - assert PersistentShellMixin._merge_output("out\n\n", "err\n") == "out\nerr" - - -class TestLocalOneShotRegression: - def test_echo(self): - env = LocalEnvironment(persistent=False) - r = env.execute("echo hello") - assert r["returncode"] == 0 - assert "hello" in r["output"] - env.cleanup() - - def test_exit_code(self): - env = LocalEnvironment(persistent=False) - r = env.execute("exit 42") - assert r["returncode"] == 42 - env.cleanup() - - def test_state_does_not_persist(self): - env = LocalEnvironment(persistent=False) - env.execute("export HERMES_ONESHOT_LOCAL=yes") - r = env.execute("echo $HERMES_ONESHOT_LOCAL") - assert r["output"].strip() == "" - env.cleanup() - - def test_oneshot_heredoc_does_not_leak_fence_wrapper(self): - """Heredoc closing line must not be merged with the fence wrapper tail.""" - env = LocalEnvironment(persistent=False) - cmd = "cat <<'H_EOF'\nheredoc body line\nH_EOF" - r = env.execute(cmd) - env.cleanup() - assert r["returncode"] == 0 - assert "heredoc body line" in r["output"] - assert "__hermes_rc" not in r["output"] - assert "printf '" not in r["output"] - assert "exit $" not in r["output"] - - -class TestLocalPersistent: - @pytest.fixture - def env(self): - e = LocalEnvironment(persistent=True) - yield e - e.cleanup() - - def test_echo(self, env): - r = env.execute("echo hello-persistent") - assert r["returncode"] == 0 - assert "hello-persistent" in r["output"] - - def test_env_var_persists(self, env): - env.execute("export HERMES_LOCAL_PERSIST_TEST=works") - r = env.execute("echo $HERMES_LOCAL_PERSIST_TEST") - assert r["output"].strip() == "works" - - def test_cwd_persists(self, env): - env.execute("cd /tmp") - r = env.execute("pwd") - assert r["output"].strip() == "/tmp" - - def test_exit_code(self, env): - r = env.execute("(exit 42)") - assert r["returncode"] == 42 - - def test_stderr(self, env): - r = env.execute("echo oops >&2") - assert r["returncode"] == 0 - assert "oops" in r["output"] - - def test_multiline_output(self, env): - r = env.execute("echo a; echo b; echo c") - lines = r["output"].strip().splitlines() - assert lines == ["a", "b", "c"] - - def test_timeout_then_recovery(self, env): - r = env.execute("sleep 999", timeout=2) - assert r["returncode"] in (124, 130) - r = env.execute("echo alive") - assert r["returncode"] == 0 - assert "alive" in r["output"] - - def test_large_output(self, env): - r = env.execute("seq 1 1000") - assert r["returncode"] == 0 - lines = r["output"].strip().splitlines() - assert len(lines) == 1000 - assert lines[0] == "1" - assert lines[-1] == "1000" - - def test_shell_variable_persists(self, env): - env.execute("MY_LOCAL_VAR=hello123") - r = env.execute("echo $MY_LOCAL_VAR") - assert r["output"].strip() == "hello123" - - def test_cleanup_removes_temp_files(self, env): - env.execute("echo warmup") - prefix = env._temp_prefix - assert len(glob_mod.glob(f"{prefix}-*")) > 0 - env.cleanup() - remaining = glob_mod.glob(f"{prefix}-*") - assert remaining == [] - - def test_state_does_not_leak_between_instances(self): - env1 = LocalEnvironment(persistent=True) - env2 = LocalEnvironment(persistent=True) - try: - env1.execute("export LEAK_TEST=from_env1") - r = env2.execute("echo $LEAK_TEST") - assert r["output"].strip() == "" - finally: - env1.cleanup() - env2.cleanup() - - def test_special_characters_in_command(self, env): - r = env.execute("echo 'hello world'") - assert r["output"].strip() == "hello world" - - def test_pipe_command(self, env): - r = env.execute("echo hello | tr 'h' 'H'") - assert r["output"].strip() == "Hello" - - def test_multiple_commands_semicolon(self, env): - r = env.execute("X=42; echo $X") - assert r["output"].strip() == "42" diff --git a/tests/tools/test_managed_modal_environment.py b/tests/tools/test_managed_modal_environment.py index ded9cd3d4ba..1d7241e0b73 100644 --- a/tests/tools/test_managed_modal_environment.py +++ b/tests/tools/test_managed_modal_environment.py @@ -110,7 +110,7 @@ class _FakeResponse: def test_managed_modal_execute_polls_until_completed(monkeypatch): _install_fake_tools_package() managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py") - modal_common = sys.modules["tools.environments.modal_common"] + modal_common = sys.modules["tools.environments.modal_utils"] calls = [] poll_count = {"value": 0} @@ -173,7 +173,7 @@ def test_managed_modal_create_sends_a_stable_idempotency_key(monkeypatch): def test_managed_modal_execute_cancels_on_interrupt(monkeypatch): interrupt_event = _install_fake_tools_package() managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py") - modal_common = sys.modules["tools.environments.modal_common"] + modal_common = sys.modules["tools.environments.modal_utils"] calls = [] @@ -215,7 +215,7 @@ def test_managed_modal_execute_cancels_on_interrupt(monkeypatch): def test_managed_modal_execute_returns_descriptive_error_on_missing_exec(monkeypatch): _install_fake_tools_package() managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py") - modal_common = sys.modules["tools.environments.modal_common"] + modal_common = sys.modules["tools.environments.modal_utils"] def fake_request(method, url, headers=None, json=None, timeout=None): if method == "POST" and url.endswith("/v1/sandboxes"): @@ -293,7 +293,7 @@ def test_managed_modal_rejects_host_credential_passthrough(): def test_managed_modal_execute_times_out_and_cancels(monkeypatch): _install_fake_tools_package() managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py") - modal_common = sys.modules["tools.environments.modal_common"] + modal_common = sys.modules["tools.environments.modal_utils"] calls = [] monotonic_values = iter([0.0, 12.5]) diff --git a/tests/tools/test_threaded_process_handle.py b/tests/tools/test_threaded_process_handle.py new file mode 100644 index 00000000000..4e6fbdb0d61 --- /dev/null +++ b/tests/tools/test_threaded_process_handle.py @@ -0,0 +1,144 @@ +"""Tests for _ThreadedProcessHandle — the adapter for SDK backends.""" + +import threading +import time + +from tools.environments.base import _ThreadedProcessHandle + + +class TestBasicExecution: + def test_successful_execution(self): + def exec_fn(): + return ("hello world", 0) + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + + assert handle.returncode == 0 + output = handle.stdout.read() + assert "hello world" in output + + def test_nonzero_exit_code(self): + def exec_fn(): + return ("error occurred", 42) + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + + assert handle.returncode == 42 + output = handle.stdout.read() + assert "error occurred" in output + + def test_exception_in_exec_fn(self): + def exec_fn(): + raise RuntimeError("boom") + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + + assert handle.returncode == 1 + + def test_empty_output(self): + def exec_fn(): + return ("", 0) + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + + assert handle.returncode == 0 + output = handle.stdout.read() + assert output == "" + + +class TestPolling: + def test_poll_returns_none_while_running(self): + event = threading.Event() + + def exec_fn(): + event.wait(timeout=5) + return ("done", 0) + + handle = _ThreadedProcessHandle(exec_fn) + assert handle.poll() is None + + event.set() + handle.wait(timeout=5) + assert handle.poll() == 0 + + def test_poll_returns_returncode_when_done(self): + def exec_fn(): + return ("ok", 0) + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + assert handle.poll() == 0 + + +class TestCancelFn: + def test_cancel_fn_called_on_kill(self): + called = threading.Event() + + def cancel(): + called.set() + + def exec_fn(): + time.sleep(10) + return ("", 0) + + handle = _ThreadedProcessHandle(exec_fn, cancel_fn=cancel) + handle.kill() + assert called.is_set() + + def test_cancel_fn_none_is_safe(self): + def exec_fn(): + return ("ok", 0) + + handle = _ThreadedProcessHandle(exec_fn, cancel_fn=None) + handle.kill() # should not raise + handle.wait(timeout=5) + assert handle.returncode == 0 + + def test_cancel_fn_exception_swallowed(self): + def cancel(): + raise RuntimeError("cancel failed") + + def exec_fn(): + return ("ok", 0) + + handle = _ThreadedProcessHandle(exec_fn, cancel_fn=cancel) + handle.kill() # should not raise despite cancel raising + handle.wait(timeout=5) + + +class TestStdoutPipe: + def test_stdout_is_readable(self): + def exec_fn(): + return ("line1\nline2\nline3\n", 0) + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + + lines = handle.stdout.readlines() + assert len(lines) == 3 + assert lines[0] == "line1\n" + + def test_stdout_iterable(self): + def exec_fn(): + return ("a\nb\nc\n", 0) + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + + collected = list(handle.stdout) + assert len(collected) == 3 + + def test_unicode_output(self): + def exec_fn(): + return ("hello 世界 🌍\n", 0) + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + + output = handle.stdout.read() + assert "世界" in output + assert "🌍" in output diff --git a/tools/code_execution_tool.py b/tools/code_execution_tool.py index aa4cd0863f2..f0d61210ffa 100644 --- a/tools/code_execution_tool.py +++ b/tools/code_execution_tool.py @@ -18,7 +18,7 @@ Architecture (two transports): 2. Parent ships both files to the remote environment 3. Script runs inside the terminal backend (Docker/SSH/Modal/Daytona/etc.) 4. Tool calls are written as request files; a polling thread on the parent - reads them via execute_oneshot(), dispatches, and writes response files + reads them via env.execute(), dispatches, and writes response files 5. The script polls for response files and continues In both cases, only the script's stdout is returned to the LLM; intermediate @@ -536,7 +536,7 @@ def _ship_file_to_remote(env, remote_path: str, content: str) -> None: quotes are fine. """ encoded = base64.b64encode(content.encode("utf-8")).decode("ascii") - env.execute_oneshot( + env.execute( f"echo '{encoded}' | base64 -d > {remote_path}", cwd="/", timeout=30, @@ -555,9 +555,9 @@ def _rpc_poll_loop( ): """Poll the remote filesystem for tool call requests and dispatch them. - Runs in a background thread. Uses ``env.execute_oneshot()`` so it can - operate concurrently with the script-execution thread that holds - ``env.execute()`` (important for persistent-shell backends like SSH). + Runs in a background thread. Each ``env.execute()`` spawns an + independent process, so these calls run safely concurrent with the + script-execution thread. """ from model_tools import handle_function_call @@ -566,7 +566,7 @@ def _rpc_poll_loop( while not stop_event.is_set(): try: # List pending request files (skip .tmp partials) - ls_result = env.execute_oneshot( + ls_result = env.execute( f"ls -1 {rpc_dir}/req_* 2>/dev/null || true", cwd="/", timeout=10, @@ -590,7 +590,7 @@ def _rpc_poll_loop( call_start = time.monotonic() # Read request - read_result = env.execute_oneshot( + read_result = env.execute( f"cat {req_file}", cwd="/", timeout=10, @@ -600,7 +600,7 @@ def _rpc_poll_loop( except (json.JSONDecodeError, ValueError): logger.debug("Malformed RPC request in %s", req_file) # Remove bad request to avoid infinite retry - env.execute_oneshot(f"rm -f {req_file}", cwd="/", timeout=5) + env.execute(f"rm -f {req_file}", cwd="/", timeout=5) continue tool_name = request.get("tool", "") @@ -664,7 +664,7 @@ def _rpc_poll_loop( encoded_result = base64.b64encode( tool_result.encode("utf-8") ).decode("ascii") - env.execute_oneshot( + env.execute( f"echo '{encoded_result}' | base64 -d > {res_file}.tmp" f" && mv {res_file}.tmp {res_file}", cwd="/", @@ -672,7 +672,7 @@ def _rpc_poll_loop( ) # Remove the request file - env.execute_oneshot(f"rm -f {req_file}", cwd="/", timeout=5) + env.execute(f"rm -f {req_file}", cwd="/", timeout=5) except Exception as e: if not stop_event.is_set(): @@ -717,7 +717,7 @@ def _execute_remote( try: # Verify Python is available on the remote - py_check = env.execute_oneshot( + py_check = env.execute( "command -v python3 >/dev/null 2>&1 && echo OK", cwd="/", timeout=15, ) @@ -734,7 +734,7 @@ def _execute_remote( }) # Create sandbox directory on remote - env.execute_oneshot( + env.execute( f"mkdir -p {sandbox_dir}/rpc", cwd="/", timeout=10, ) @@ -806,7 +806,7 @@ def _execute_remote( # Clean up remote sandbox dir try: - env.execute_oneshot( + env.execute( f"rm -rf {sandbox_dir}", cwd="/", timeout=15, ) except Exception: diff --git a/tools/environments/base.py b/tools/environments/base.py index 21b698ec0c4..31ce0e17de8 100644 --- a/tools/environments/base.py +++ b/tools/environments/base.py @@ -1,11 +1,27 @@ -"""Base class for all Hermes execution environment backends.""" +"""Base class for all Hermes execution environment backends. -from abc import ABC, abstractmethod +Unified spawn-per-call model: every command spawns a fresh ``bash -c`` process. +A session snapshot (env vars, functions, aliases) is captured once at init and +re-sourced before each command. CWD persists via in-band stdout markers (remote) +or a temp file (local). +""" + +import json +import logging import os +import shlex import subprocess +import threading +import time +import uuid +from abc import ABC, abstractmethod from pathlib import Path +from typing import IO, Callable, Protocol from hermes_constants import get_hermes_home +from tools.interrupt import is_interrupted + +logger = logging.getLogger(__name__) def get_sandbox_dir() -> Path: @@ -23,30 +39,501 @@ def get_sandbox_dir() -> Path: return p -class BaseEnvironment(ABC): - """Common interface for all Hermes execution backends. +# --------------------------------------------------------------------------- +# Shared constants and utilities +# --------------------------------------------------------------------------- - Subclasses implement execute() and cleanup(). Shared helpers eliminate - duplicated subprocess boilerplate across backends. +_SYNC_INTERVAL_SECONDS = 5.0 + + +def _pipe_stdin(proc: subprocess.Popen, data: str) -> None: + """Write *data* to proc.stdin on a daemon thread to avoid pipe-buffer deadlocks.""" + + def _write(): + try: + proc.stdin.write(data) + proc.stdin.close() + except (BrokenPipeError, OSError): + pass + + threading.Thread(target=_write, daemon=True).start() + + +def _popen_bash( + cmd: list[str], stdin_data: str | None = None, **kwargs +) -> subprocess.Popen: + """Spawn a subprocess with standard stdout/stderr/stdin setup. + + If *stdin_data* is provided, writes it asynchronously via :func:`_pipe_stdin`. + Backends with special Popen needs (e.g. local's ``preexec_fn``) can bypass + this and call :func:`_pipe_stdin` directly. """ + proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + stdin=subprocess.PIPE if stdin_data is not None else subprocess.DEVNULL, + text=True, + **kwargs, + ) + if stdin_data is not None: + _pipe_stdin(proc, stdin_data) + return proc + + +def _load_json_store(path: Path) -> dict: + """Load a JSON file as a dict, returning ``{}`` on any error.""" + if path.exists(): + try: + return json.loads(path.read_text()) + except Exception: + pass + return {} + + +def _save_json_store(path: Path, data: dict) -> None: + """Write *data* as pretty-printed JSON to *path*.""" + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(data, indent=2)) + + +def _file_mtime_key(host_path: str) -> tuple[float, int] | None: + """Return ``(mtime, size)`` for cache comparison, or ``None`` if unreadable.""" + try: + st = Path(host_path).stat() + return (st.st_mtime, st.st_size) + except OSError: + return None + + +# --------------------------------------------------------------------------- +# ProcessHandle protocol +# --------------------------------------------------------------------------- + + +class ProcessHandle(Protocol): + """Duck type that every backend's _run_bash() must return. + + subprocess.Popen satisfies this natively. SDK backends (Modal, Daytona) + return _ThreadedProcessHandle which adapts their blocking calls. + """ + + def poll(self) -> int | None: ... + def kill(self) -> None: ... + def wait(self, timeout: float | None = None) -> int: ... + + @property + def stdout(self) -> IO[str] | None: ... + + @property + def returncode(self) -> int | None: ... + + +class _ThreadedProcessHandle: + """Adapter for SDK backends (Modal, Daytona) that have no real subprocess. + + Wraps a blocking ``exec_fn() -> (output_str, exit_code)`` in a background + thread and exposes a ProcessHandle-compatible interface. An optional + ``cancel_fn`` is invoked on ``kill()`` for backend-specific cancellation + (e.g. Modal sandbox.terminate, Daytona sandbox.stop). + """ + + def __init__( + self, + exec_fn: Callable[[], tuple[str, int]], + cancel_fn: Callable[[], None] | None = None, + ): + self._cancel_fn = cancel_fn + self._done = threading.Event() + self._returncode: int | None = None + self._error: Exception | None = None + + # Pipe for stdout — drain thread in _wait_for_process reads the read end. + read_fd, write_fd = os.pipe() + self._stdout = os.fdopen(read_fd, "r", encoding="utf-8", errors="replace") + self._write_fd = write_fd + + def _worker(): + try: + output, exit_code = exec_fn() + self._returncode = exit_code + # Write output into the pipe so drain thread picks it up. + try: + os.write(self._write_fd, output.encode("utf-8", errors="replace")) + except OSError: + pass + except Exception as exc: + self._error = exc + self._returncode = 1 + finally: + try: + os.close(self._write_fd) + except OSError: + pass + self._done.set() + + t = threading.Thread(target=_worker, daemon=True) + t.start() + + @property + def stdout(self): + return self._stdout + + @property + def returncode(self) -> int | None: + return self._returncode + + def poll(self) -> int | None: + return self._returncode if self._done.is_set() else None + + def kill(self): + if self._cancel_fn: + try: + self._cancel_fn() + except Exception: + pass + + def wait(self, timeout: float | None = None) -> int: + self._done.wait(timeout=timeout) + return self._returncode + + +# --------------------------------------------------------------------------- +# CWD marker for remote backends +# --------------------------------------------------------------------------- + + +def _cwd_marker(session_id: str) -> str: + return f"__HERMES_CWD_{session_id}__" + + +# --------------------------------------------------------------------------- +# BaseEnvironment +# --------------------------------------------------------------------------- + + +class BaseEnvironment(ABC): + """Common interface and unified execution flow for all Hermes backends. + + Subclasses implement ``_run_bash()`` and ``cleanup()``. The base class + provides ``execute()`` with session snapshot sourcing, CWD tracking, + interrupt handling, and timeout enforcement. + """ + + # Subclasses that embed stdin as a heredoc (Modal, Daytona) set this. + _stdin_mode: str = "pipe" # "pipe" or "heredoc" + + # Snapshot creation timeout (override for slow cold-starts). + _snapshot_timeout: int = 30 def __init__(self, cwd: str, timeout: int, env: dict = None): self.cwd = cwd self.timeout = timeout self.env = env or {} - @abstractmethod - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - """Execute a command, return {"output": str, "returncode": int}.""" - ... + self._session_id = uuid.uuid4().hex[:12] + self._snapshot_path = f"/tmp/hermes-snap-{self._session_id}.sh" + self._cwd_file = f"/tmp/hermes-cwd-{self._session_id}.txt" + self._cwd_marker = _cwd_marker(self._session_id) + self._snapshot_ready = False + self._last_sync_time: float | None = ( + None # set to 0 by backends that need file sync + ) + + # ------------------------------------------------------------------ + # Abstract methods + # ------------------------------------------------------------------ + + def _run_bash( + self, + cmd_string: str, + *, + login: bool = False, + timeout: int = 120, + stdin_data: str | None = None, + ) -> ProcessHandle: + """Spawn a bash process to run *cmd_string*. + + Returns a ProcessHandle (subprocess.Popen or _ThreadedProcessHandle). + Must be overridden by every backend. + """ + raise NotImplementedError(f"{type(self).__name__} must implement _run_bash()") @abstractmethod def cleanup(self): """Release backend resources (container, instance, connection).""" ... + # ------------------------------------------------------------------ + # Session snapshot (init_session) + # ------------------------------------------------------------------ + + def init_session(self): + """Capture login shell environment into a snapshot file. + + Called once after backend construction. On success, sets + ``_snapshot_ready = True`` so subsequent commands source the snapshot + instead of running with ``bash -l``. + """ + # Full capture: env vars, functions (filtered), aliases, shell options. + bootstrap = ( + f"export -p > {self._snapshot_path}\n" + f"declare -f | grep -vE '^_[^_]' >> {self._snapshot_path}\n" + f"alias -p >> {self._snapshot_path}\n" + f"echo 'shopt -s expand_aliases' >> {self._snapshot_path}\n" + f"echo 'set +e' >> {self._snapshot_path}\n" + f"echo 'set +u' >> {self._snapshot_path}\n" + f"pwd -P > {self._cwd_file} 2>/dev/null || true\n" + f"printf '\\n{self._cwd_marker}%s{self._cwd_marker}\\n' \"$(pwd -P)\"\n" + ) + try: + proc = self._run_bash(bootstrap, login=True, timeout=self._snapshot_timeout) + result = self._wait_for_process(proc, timeout=self._snapshot_timeout) + self._snapshot_ready = True + self._update_cwd(result) + logger.info( + "Session snapshot created (session=%s, cwd=%s)", + self._session_id, + self.cwd, + ) + except Exception as exc: + logger.warning( + "init_session failed (session=%s): %s — " + "falling back to bash -l per command", + self._session_id, + exc, + ) + self._snapshot_ready = False + + # ------------------------------------------------------------------ + # Command wrapping + # ------------------------------------------------------------------ + + def _wrap_command(self, command: str, cwd: str) -> str: + """Build the full bash script that sources snapshot, cd's, runs command, + re-dumps env vars, and emits CWD markers.""" + escaped = command.replace("'", "'\\''") + + parts = [] + + # Source snapshot (env vars from previous commands) + if self._snapshot_ready: + parts.append(f"source {self._snapshot_path} 2>/dev/null || true") + + # cd to working directory — let bash expand ~ natively + quoted_cwd = ( + shlex.quote(cwd) if cwd != "~" and not cwd.startswith("~/") else cwd + ) + parts.append(f"cd {quoted_cwd} || exit 126") + + # Run the actual command + parts.append(f"eval '{escaped}'") + parts.append("__hermes_ec=$?") + + # Re-dump env vars to snapshot (last-writer-wins for concurrent calls) + if self._snapshot_ready: + parts.append(f"export -p > {self._snapshot_path} 2>/dev/null || true") + + # Write CWD to file (local reads this) and stdout marker (remote parses this) + parts.append(f"pwd -P > {self._cwd_file} 2>/dev/null || true") + # Use a distinct line for the marker. The leading \n ensures + # the marker starts on its own line even if the command doesn't + # end with a newline (e.g. printf 'exact'). We'll strip this + # injected newline in _extract_cwd_from_output. + parts.append( + f"printf '\\n{self._cwd_marker}%s{self._cwd_marker}\\n' \"$(pwd -P)\"" + ) + parts.append("exit $__hermes_ec") + + return "\n".join(parts) + + # ------------------------------------------------------------------ + # Stdin heredoc embedding (for SDK backends) + # ------------------------------------------------------------------ + + @staticmethod + def _embed_stdin_heredoc(command: str, stdin_data: str) -> str: + """Append stdin_data as a shell heredoc to the command string.""" + delimiter = f"HERMES_STDIN_{uuid.uuid4().hex[:12]}" + return f"{command} << '{delimiter}'\n{stdin_data}\n{delimiter}" + + # ------------------------------------------------------------------ + # Process lifecycle + # ------------------------------------------------------------------ + + def _wait_for_process(self, proc: ProcessHandle, timeout: int = 120) -> dict: + """Poll-based wait with interrupt checking and stdout draining. + + Shared across all backends — not overridden. + """ + output_chunks: list[str] = [] + + def _drain(): + try: + for line in proc.stdout: + output_chunks.append(line) + except UnicodeDecodeError: + output_chunks.clear() + output_chunks.append( + "[binary output detected — raw bytes not displayable]" + ) + except (ValueError, OSError): + pass + + drain_thread = threading.Thread(target=_drain, daemon=True) + drain_thread.start() + deadline = time.monotonic() + timeout + + while proc.poll() is None: + if is_interrupted(): + self._kill_process(proc) + drain_thread.join(timeout=2) + return { + "output": "".join(output_chunks) + "\n[Command interrupted]", + "returncode": 130, + } + if time.monotonic() > deadline: + self._kill_process(proc) + drain_thread.join(timeout=2) + partial = "".join(output_chunks) + timeout_msg = f"\n[Command timed out after {timeout}s]" + return { + "output": partial + timeout_msg + if partial + else timeout_msg.lstrip(), + "returncode": 124, + } + time.sleep(0.2) + + drain_thread.join(timeout=5) + + try: + proc.stdout.close() + except Exception: + pass + + return {"output": "".join(output_chunks), "returncode": proc.returncode} + + def _kill_process(self, proc: ProcessHandle): + """Terminate a process. Subclasses may override for process-group kill.""" + try: + proc.kill() + except (ProcessLookupError, PermissionError, OSError): + pass + + # ------------------------------------------------------------------ + # CWD extraction + # ------------------------------------------------------------------ + + def _update_cwd(self, result: dict): + """Extract CWD from command output. Override for local file-based read.""" + self._extract_cwd_from_output(result) + + def _extract_cwd_from_output(self, result: dict): + """Parse the __HERMES_CWD_{session}__ marker from stdout output. + + Updates self.cwd and strips the marker from result["output"]. + Used by remote backends (Docker, SSH, Modal, Daytona, Singularity). + """ + output = result.get("output", "") + marker = self._cwd_marker + last = output.rfind(marker) + if last == -1: + return + + # Find the opening marker before this closing one + search_start = max(0, last - 4096) # CWD path won't be >4KB + first = output.rfind(marker, search_start, last) + if first == -1 or first == last: + return + + cwd_path = output[first + len(marker) : last].strip() + if cwd_path: + self.cwd = cwd_path + + # Strip the marker line AND the \n we injected before it. + # The wrapper emits: printf '\n__MARKER__%s__MARKER__\n' + # So the output looks like: \n__MARKER__path__MARKER__\n + # We want to remove everything from the injected \n onwards. + line_start = output.rfind("\n", 0, first) + if line_start == -1: + line_start = first + line_end = output.find("\n", last + len(marker)) + line_end = line_end + 1 if line_end != -1 else len(output) + + result["output"] = output[:line_start] + output[line_end:] + + # ------------------------------------------------------------------ + # Hooks + # ------------------------------------------------------------------ + + def _before_execute(self): + """Rate-limited file sync before each command. + + Backends that need pre-command sync set ``self._last_sync_time = 0`` + in ``__init__`` and override :meth:`_sync_files`. Backends needing + extra pre-exec logic (e.g. Daytona sandbox restart check) override + this method and call ``super()._before_execute()``. + """ + if self._last_sync_time is not None: + now = time.monotonic() + if now - self._last_sync_time >= _SYNC_INTERVAL_SECONDS: + self._sync_files() + self._last_sync_time = now + + def _sync_files(self): + """Push files to remote environment. Called rate-limited by _before_execute.""" + pass + + # ------------------------------------------------------------------ + # Unified execute() + # ------------------------------------------------------------------ + + def execute( + self, + command: str, + cwd: str = "", + *, + timeout: int | None = None, + stdin_data: str | None = None, + ) -> dict: + """Execute a command, return {"output": str, "returncode": int}.""" + self._before_execute() + + exec_command, sudo_stdin = self._prepare_command(command) + effective_timeout = timeout or self.timeout + effective_cwd = cwd or self.cwd + + # Merge sudo stdin with caller stdin + if sudo_stdin is not None and stdin_data is not None: + effective_stdin = sudo_stdin + stdin_data + elif sudo_stdin is not None: + effective_stdin = sudo_stdin + else: + effective_stdin = stdin_data + + # Embed stdin as heredoc for backends that need it + if effective_stdin and self._stdin_mode == "heredoc": + exec_command = self._embed_stdin_heredoc(exec_command, effective_stdin) + effective_stdin = None + + wrapped = self._wrap_command(exec_command, effective_cwd) + + # Use login shell if snapshot failed (so user's profile still loads) + login = not self._snapshot_ready + + proc = self._run_bash( + wrapped, login=login, timeout=effective_timeout, stdin_data=effective_stdin + ) + result = self._wait_for_process(proc, timeout=effective_timeout) + self._update_cwd(result) + + return result + + # ------------------------------------------------------------------ + # Shared helpers + # ------------------------------------------------------------------ + def stop(self): """Alias for cleanup (compat with older callers).""" self.cleanup() @@ -57,53 +544,12 @@ class BaseEnvironment(ABC): except Exception: pass - # ------------------------------------------------------------------ - # Shared helpers (eliminate duplication across backends) - # ------------------------------------------------------------------ - def _prepare_command(self, command: str) -> tuple[str, str | None]: - """Transform sudo commands if SUDO_PASSWORD is available. - - Returns: - (transformed_command, sudo_stdin) — see _transform_sudo_command - for the full contract. Callers that drive a subprocess directly - should prepend sudo_stdin (when not None) to any stdin_data they - pass to Popen. Callers that embed stdin via heredoc (modal, - daytona) handle sudo_stdin in their own execute() method. - """ + """Transform sudo commands if SUDO_PASSWORD is available.""" from tools.terminal_tool import _transform_sudo_command + return _transform_sudo_command(command) - def _build_run_kwargs(self, timeout: int | None, - stdin_data: str | None = None) -> dict: - """Build common subprocess.run kwargs for non-interactive execution.""" - kw = { - "text": True, - "timeout": timeout or self.timeout, - "encoding": "utf-8", - "errors": "replace", - "stdout": subprocess.PIPE, - "stderr": subprocess.STDOUT, - } - if stdin_data is not None: - kw["input"] = stdin_data - else: - kw["stdin"] = subprocess.DEVNULL - return kw - - def execute_oneshot(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - """Execute a command bypassing any persistent shell. - - Safe for concurrent use alongside a long-running execute() call. - Backends that maintain a persistent shell (SSH, Local) override this - to route through their oneshot path, avoiding the shell lock. - Non-persistent backends delegate to execute(). - """ - return self.execute(command, cwd=cwd, timeout=timeout, - stdin_data=stdin_data) - def _timeout_result(self, timeout: int | None) -> dict: """Standard return dict when a command times out.""" return { diff --git a/tools/environments/daytona.py b/tools/environments/daytona.py index e52459d8b5a..60958fd353e 100644 --- a/tools/environments/daytona.py +++ b/tools/environments/daytona.py @@ -6,17 +6,18 @@ and resumed on next creation, preserving the filesystem across sessions. """ import logging -import time import math import shlex import threading -import uuid import warnings from pathlib import Path from typing import Dict, Optional -from tools.environments.base import BaseEnvironment -from tools.interrupt import is_interrupted +from tools.environments.base import ( + BaseEnvironment, + _ThreadedProcessHandle, + _file_mtime_key, +) logger = logging.getLogger(__name__) @@ -24,22 +25,25 @@ logger = logging.getLogger(__name__) class DaytonaEnvironment(BaseEnvironment): """Daytona cloud sandbox execution backend. - Uses stopped/started sandbox lifecycle for filesystem persistence - instead of snapshots, making it faster and stateless on the host. + Spawn-per-call via _ThreadedProcessHandle wrapping blocking SDK calls. + cancel_fn wired to sandbox.stop() for interrupt support. + Shell timeout wrapper preserved (SDK timeout unreliable). """ + _stdin_mode = "heredoc" + def __init__( self, image: str, cwd: str = "/home/daytona", timeout: int = 60, cpu: int = 1, - memory: int = 5120, # MB (hermes convention) - disk: int = 10240, # MB (Daytona platform max is 10GB) + memory: int = 5120, + disk: int = 10240, persistent_filesystem: bool = True, task_id: str = "default", ): - self._requested_cwd = cwd + requested_cwd = cwd super().__init__(cwd=cwd, timeout=timeout) from daytona import ( @@ -53,16 +57,18 @@ class DaytonaEnvironment(BaseEnvironment): self._persistent = persistent_filesystem self._task_id = task_id self._SandboxState = SandboxState + self._DaytonaError = DaytonaError self._daytona = Daytona() self._sandbox = None self._lock = threading.Lock() + self._last_sync_time: float = 0 memory_gib = max(1, math.ceil(memory / 1024)) disk_gib = max(1, math.ceil(disk / 1024)) if disk_gib > 10: warnings.warn( f"Daytona: requested disk ({disk_gib}GB) exceeds platform limit (10GB). " - f"Capping to 10GB. Set container_disk: 10240 in config to silence this.", + f"Capping to 10GB.", stacklevel=2, ) disk_gib = 10 @@ -71,9 +77,7 @@ class DaytonaEnvironment(BaseEnvironment): labels = {"hermes_task_id": task_id} sandbox_name = f"hermes-{task_id}" - # Try to resume an existing sandbox for this task if self._persistent: - # 1. Try name-based lookup (new path) try: self._sandbox = self._daytona.get(sandbox_name) self._sandbox.start() @@ -86,7 +90,6 @@ class DaytonaEnvironment(BaseEnvironment): task_id, e) self._sandbox = None - # 2. Legacy fallback: find sandbox created before the naming migration if self._sandbox is None: try: page = self._daytona.list(labels=labels, page=1, limit=1) @@ -100,7 +103,6 @@ class DaytonaEnvironment(BaseEnvironment): task_id, e) self._sandbox = None - # Create a fresh sandbox if we don't have one if self._sandbox is None: self._sandbox = self._daytona.create( CreateSandboxFromImageParams( @@ -114,32 +116,25 @@ class DaytonaEnvironment(BaseEnvironment): logger.info("Daytona: created sandbox %s for task %s", self._sandbox.id, task_id) - # Detect remote home dir first so mounts go to the right place. + # Detect remote home dir self._remote_home = "/root" try: home = self._sandbox.process.exec("echo $HOME").result.strip() if home: self._remote_home = home - if self._requested_cwd in ("~", "/home/daytona"): + if requested_cwd in ("~", "/home/daytona"): self.cwd = home except Exception: pass logger.info("Daytona: resolved home to %s, cwd to %s", self._remote_home, self.cwd) - # Track synced files to avoid redundant uploads. - # Key: remote_path, Value: (mtime, size) self._synced_files: Dict[str, tuple] = {} - - # Upload credential files and skills directory into the sandbox. - self._sync_skills_and_credentials() + self._sync_files() + self.init_session() def _upload_if_changed(self, host_path: str, remote_path: str) -> bool: - """Upload a file if its mtime/size changed since last sync.""" - hp = Path(host_path) - try: - stat = hp.stat() - file_key = (stat.st_mtime, stat.st_size) - except OSError: + file_key = _file_mtime_key(host_path) + if file_key is None: return False if self._synced_files.get(remote_path) == file_key: return False @@ -153,20 +148,15 @@ class DaytonaEnvironment(BaseEnvironment): logger.debug("Daytona: upload failed %s: %s", host_path, e) return False - def _sync_skills_and_credentials(self) -> None: - """Upload changed credential files and skill files into the sandbox.""" + def _sync_files(self) -> None: container_base = f"{self._remote_home}/.hermes" try: from tools.credential_files import get_credential_file_mounts, iter_skills_files - for mount_entry in get_credential_file_mounts(): remote_path = mount_entry["container_path"].replace("/root/.hermes", container_base, 1) - if self._upload_if_changed(mount_entry["host_path"], remote_path): - logger.debug("Daytona: synced credential %s", remote_path) - + self._upload_if_changed(mount_entry["host_path"], remote_path) for entry in iter_skills_files(container_base=container_base): - if self._upload_if_changed(entry["host_path"], entry["container_path"]): - logger.debug("Daytona: synced skill %s", entry["container_path"]) + self._upload_if_changed(entry["host_path"], entry["container_path"]) except Exception as e: logger.debug("Daytona: could not sync skills/credentials: %s", e) @@ -177,111 +167,36 @@ class DaytonaEnvironment(BaseEnvironment): self._sandbox.start() logger.info("Daytona: restarted sandbox %s", self._sandbox.id) - def _exec_in_thread(self, exec_command: str, cwd: Optional[str], timeout: int) -> dict: - """Run exec in a background thread with interrupt polling. - - The Daytona SDK's exec(timeout=...) parameter is unreliable (the - server-side timeout is not enforced and the SDK has no client-side - fallback), so we wrap the command with the shell ``timeout`` utility - which reliably kills the process and returns exit code 124. - """ - # Wrap with shell `timeout` to enforce the deadline reliably. - # Add a small buffer so the shell timeout fires before any SDK-level - # timeout would, giving us a clean exit code 124. - timed_command = f"timeout {timeout} sh -c {shlex.quote(exec_command)}" - - result_holder: dict = {"value": None, "error": None} - - def _run(): - try: - response = self._sandbox.process.exec( - timed_command, cwd=cwd, - ) - result_holder["value"] = { - "output": response.result or "", - "returncode": response.exit_code, - } - except Exception as e: - result_holder["error"] = e - - t = threading.Thread(target=_run, daemon=True) - t.start() - # Wait for timeout + generous buffer for network/SDK overhead - deadline = time.monotonic() + timeout + 10 - while t.is_alive(): - t.join(timeout=0.2) - if is_interrupted(): - with self._lock: - try: - self._sandbox.stop() - except Exception: - pass - return { - "output": "[Command interrupted - Daytona sandbox stopped]", - "returncode": 130, - } - if time.monotonic() > deadline: - # Shell timeout didn't fire and SDK is hung — force stop - with self._lock: - try: - self._sandbox.stop() - except Exception: - pass - return self._timeout_result(timeout) - - if result_holder["error"]: - return {"error": result_holder["error"]} - return result_holder["value"] - - def execute(self, command: str, cwd: str = "", *, - timeout: Optional[int] = None, - stdin_data: Optional[str] = None) -> dict: + def _before_execute(self): + """Ensure sandbox is ready, then rate-limited file sync via base class.""" with self._lock: self._ensure_sandbox_ready() - # Incremental sync before each command so mid-session credential - # refreshes and skill updates are picked up. - self._sync_skills_and_credentials() + super()._before_execute() - if stdin_data is not None: - marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}" - while marker in stdin_data: - marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}" - command = f"{command} << '{marker}'\n{stdin_data}\n{marker}" + def _run_bash(self, cmd_string: str, *, login: bool = False, + timeout: int = 120, + stdin_data: str | None = None): + """Return a _ThreadedProcessHandle wrapping a blocking Daytona SDK call.""" + sandbox = self._sandbox + lock = self._lock - exec_command, sudo_stdin = self._prepare_command(command) + def cancel(): + with lock: + try: + sandbox.stop() + except Exception: + pass - # Daytona sandboxes execute commands via the Daytona SDK and cannot - # pipe subprocess stdin directly the way a local Popen can. When a - # sudo password is present, use a shell-level pipe from printf so that - # the password feeds sudo -S without appearing as an echo argument - # embedded in the shell string. The password is still visible in the - # remote sandbox's command line, but it is not exposed on the user's - # local machine — which is the primary threat being mitigated. - if sudo_stdin is not None: - import shlex - exec_command = ( - f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}" - ) - effective_cwd = cwd or self.cwd or None - effective_timeout = timeout or self.timeout + if login: + shell_cmd = f"bash -l -c {shlex.quote(cmd_string)}" + else: + shell_cmd = f"bash -c {shlex.quote(cmd_string)}" - result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout) + def exec_fn() -> tuple[str, int]: + response = sandbox.process.exec(shell_cmd, timeout=timeout) + return (response.result or "", response.exit_code) - if "error" in result: - from daytona import DaytonaError - err = result["error"] - if isinstance(err, DaytonaError): - with self._lock: - try: - self._ensure_sandbox_ready() - except Exception: - return {"output": f"Daytona execution error: {err}", "returncode": 1} - result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout) - if "error" not in result: - return result - return {"output": f"Daytona execution error: {err}", "returncode": 1} - - return result + return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel) def cleanup(self): with self._lock: diff --git a/tools/environments/docker.py b/tools/environments/docker.py index b97040d4e0b..59a23779612 100644 --- a/tools/environments/docker.py +++ b/tools/environments/docker.py @@ -8,18 +8,14 @@ persistence via bind mounts. import logging import os import re -import shlex import shutil import subprocess import sys -import threading -import time import uuid from typing import Optional -from tools.environments.base import BaseEnvironment +from tools.environments.base import BaseEnvironment, _popen_bash from tools.environments.local import _HERMES_PROVIDER_ENV_BLOCKLIST -from tools.interrupt import is_interrupted logger = logging.getLogger(__name__) @@ -431,6 +427,69 @@ class DockerEnvironment(BaseEnvironment): self._container_id = result.stdout.strip() logger.info(f"Started container {container_name} ({self._container_id[:12]})") + # Build the init-time env forwarding args (used only by init_session + # to inject host env vars into the snapshot; subsequent commands get + # them from the snapshot file). + self._init_env_args = self._build_init_env_args() + + # Initialize session snapshot inside the container + self.init_session() + + def _build_init_env_args(self) -> list[str]: + """Build -e KEY=VALUE args for injecting host env vars into init_session. + + These are used once during init_session() so that export -p captures + them into the snapshot. Subsequent execute() calls don't need -e flags. + """ + exec_env: dict[str, str] = dict(self._env) + + explicit_forward_keys = set(self._forward_env) + passthrough_keys: set[str] = set() + try: + from tools.env_passthrough import get_all_passthrough + passthrough_keys = set(get_all_passthrough()) + except Exception: + pass + # Explicit docker_forward_env entries are an intentional opt-in and must + # win over the generic Hermes secret blocklist. Only implicit passthrough + # keys are filtered. + forward_keys = explicit_forward_keys | (passthrough_keys - _HERMES_PROVIDER_ENV_BLOCKLIST) + hermes_env = _load_hermes_env_vars() if forward_keys else {} + for key in sorted(forward_keys): + value = os.getenv(key) + if value is None: + value = hermes_env.get(key) + if value is not None: + exec_env[key] = value + + args = [] + for key in sorted(exec_env): + args.extend(["-e", f"{key}={exec_env[key]}"]) + return args + + def _run_bash(self, cmd_string: str, *, login: bool = False, + timeout: int = 120, + stdin_data: str | None = None) -> subprocess.Popen: + """Spawn a bash process inside the Docker container.""" + assert self._container_id, "Container not started" + cmd = [self._docker_exe, "exec"] + if stdin_data is not None: + cmd.append("-i") + + # Only inject -e env args during init_session (login=True). + # Subsequent commands get env vars from the snapshot. + if login: + cmd.extend(self._init_env_args) + + cmd.extend([self._container_id]) + + if login: + cmd.extend(["bash", "-l", "-c", cmd_string]) + else: + cmd.extend(["bash", "-c", cmd_string]) + + return _popen_bash(cmd, stdin_data) + @staticmethod def _storage_opt_supported() -> bool: """Check if Docker's storage driver supports --storage-opt size=. @@ -471,112 +530,6 @@ class DockerEnvironment(BaseEnvironment): logger.debug("Docker --storage-opt support: %s", _storage_opt_ok) return _storage_opt_ok - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - exec_command, sudo_stdin = self._prepare_command(command) - work_dir = cwd or self.cwd - effective_timeout = timeout or self.timeout - - # Merge sudo password (if any) with caller-supplied stdin_data. - if sudo_stdin is not None and stdin_data is not None: - effective_stdin = sudo_stdin + stdin_data - elif sudo_stdin is not None: - effective_stdin = sudo_stdin - else: - effective_stdin = stdin_data - - # docker exec -w doesn't expand ~, so prepend a cd into the command. - # Keep ~ unquoted (for shell expansion) and quote only the subpath. - if work_dir == "~": - exec_command = f"cd ~ && {exec_command}" - work_dir = "/" - elif work_dir.startswith("~/"): - exec_command = f"cd ~/{shlex.quote(work_dir[2:])} && {exec_command}" - work_dir = "/" - - assert self._container_id, "Container not started" - cmd = [self._docker_exe, "exec"] - if effective_stdin is not None: - cmd.append("-i") - cmd.extend(["-w", work_dir]) - # Build the per-exec environment: start with explicit docker_env values - # (static config), then overlay docker_forward_env / skill env_passthrough - # (dynamic from host process). Forward values take precedence. - exec_env: dict[str, str] = dict(self._env) - - explicit_forward_keys = set(self._forward_env) - passthrough_keys: set[str] = set() - try: - from tools.env_passthrough import get_all_passthrough - passthrough_keys = set(get_all_passthrough()) - except Exception: - pass - # Explicit docker_forward_env entries are an intentional opt-in and must - # win over the generic Hermes secret blocklist. Only implicit passthrough - # keys are filtered. - forward_keys = explicit_forward_keys | (passthrough_keys - _HERMES_PROVIDER_ENV_BLOCKLIST) - hermes_env = _load_hermes_env_vars() if forward_keys else {} - for key in sorted(forward_keys): - value = os.getenv(key) - if value is None: - value = hermes_env.get(key) - if value is not None: - exec_env[key] = value - - for key in sorted(exec_env): - cmd.extend(["-e", f"{key}={exec_env[key]}"]) - cmd.extend([self._container_id, "bash", "-lc", exec_command]) - - try: - _output_chunks = [] - proc = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL, - text=True, - ) - if effective_stdin: - try: - proc.stdin.write(effective_stdin) - proc.stdin.close() - except Exception: - pass - - def _drain(): - try: - for line in proc.stdout: - _output_chunks.append(line) - except Exception: - pass - - reader = threading.Thread(target=_drain, daemon=True) - reader.start() - deadline = time.monotonic() + effective_timeout - - while proc.poll() is None: - if is_interrupted(): - proc.terminate() - try: - proc.wait(timeout=1) - except subprocess.TimeoutExpired: - proc.kill() - reader.join(timeout=2) - return { - "output": "".join(_output_chunks) + "\n[Command interrupted]", - "returncode": 130, - } - if time.monotonic() > deadline: - proc.kill() - reader.join(timeout=2) - return self._timeout_result(effective_timeout) - time.sleep(0.2) - - reader.join(timeout=5) - return {"output": "".join(_output_chunks), "returncode": proc.returncode} - except Exception as e: - return {"output": f"Docker execution error: {e}", "returncode": 1} - def cleanup(self): """Stop and remove the container. Bind-mount dirs persist if persistent=True.""" if self._container_id: diff --git a/tools/environments/local.py b/tools/environments/local.py index 27282b6ef67..d3bb3448291 100644 --- a/tools/environments/local.py +++ b/tools/environments/local.py @@ -1,42 +1,22 @@ -"""Local execution environment with interrupt support and non-blocking I/O.""" +"""Local execution environment — spawn-per-call with session snapshot.""" -import glob import os import platform import shutil import signal import subprocess -import threading -import time + +from tools.environments.base import BaseEnvironment, _pipe_stdin _IS_WINDOWS = platform.system() == "Windows" -from tools.environments.base import BaseEnvironment -from tools.environments.persistent_shell import PersistentShellMixin -from tools.interrupt import is_interrupted - -# Unique marker to isolate real command output from shell init/exit noise. -# printf (no trailing newline) keeps the boundaries clean for splitting. -_OUTPUT_FENCE = "__HERMES_FENCE_a9f7b3__" # Hermes-internal env vars that should NOT leak into terminal subprocesses. -# These are loaded from ~/.hermes/.env for Hermes' own LLM/provider calls -# but can break external CLIs (e.g. codex) that also honor them. -# See: https://github.com/NousResearch/hermes-agent/issues/1002 -# -# Built dynamically from the provider registry so new providers are -# automatically covered without manual blocklist maintenance. _HERMES_PROVIDER_ENV_FORCE_PREFIX = "_HERMES_FORCE_" def _build_provider_env_blocklist() -> frozenset: - """Derive the blocklist from provider, tool, and gateway config. - - Automatically picks up api_key_env_vars and base_url_env_var from - every registered provider, plus tool/messaging env vars from the - optional config registry, so new Hermes-managed secrets are blocked - in subprocesses without having to maintain multiple static lists. - """ + """Derive the blocklist from provider, tool, and gateway config.""" blocked: set[str] = set() try: @@ -59,33 +39,30 @@ def _build_provider_env_blocklist() -> frozenset: except ImportError: pass - # Vars not covered above but still Hermes-internal / conflict-prone. blocked.update({ "OPENAI_BASE_URL", "OPENAI_API_KEY", - "OPENAI_API_BASE", # legacy alias + "OPENAI_API_BASE", "OPENAI_ORG_ID", "OPENAI_ORGANIZATION", "OPENROUTER_API_KEY", "ANTHROPIC_BASE_URL", - "ANTHROPIC_TOKEN", # OAuth token (not in registry as env var) + "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN", "LLM_MODEL", - # Expanded isolation for other major providers (Issue #1002) - "GOOGLE_API_KEY", # Gemini / Google AI Studio - "DEEPSEEK_API_KEY", # DeepSeek - "MISTRAL_API_KEY", # Mistral AI - "GROQ_API_KEY", # Groq - "TOGETHER_API_KEY", # Together AI - "PERPLEXITY_API_KEY", # Perplexity - "COHERE_API_KEY", # Cohere - "FIREWORKS_API_KEY", # Fireworks AI - "XAI_API_KEY", # xAI (Grok) - "HELICONE_API_KEY", # LLM Observability proxy + "GOOGLE_API_KEY", + "DEEPSEEK_API_KEY", + "MISTRAL_API_KEY", + "GROQ_API_KEY", + "TOGETHER_API_KEY", + "PERPLEXITY_API_KEY", + "COHERE_API_KEY", + "FIREWORKS_API_KEY", + "XAI_API_KEY", + "HELICONE_API_KEY", "PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL", - # Gateway/runtime config not represented in OPTIONAL_ENV_VARS. "TELEGRAM_HOME_CHANNEL", "TELEGRAM_HOME_CHANNEL_NAME", "DISCORD_HOME_CHANNEL", @@ -115,12 +92,10 @@ def _build_provider_env_blocklist() -> frozenset: "EMAIL_HOME_ADDRESS", "EMAIL_HOME_ADDRESS_NAME", "GATEWAY_ALLOWED_USERS", - # Skills Hub / GitHub app auth paths and aliases. "GH_TOKEN", "GITHUB_APP_ID", "GITHUB_APP_PRIVATE_KEY_PATH", "GITHUB_APP_INSTALLATION_ID", - # Remote sandbox backend credentials. "MODAL_TOKEN_ID", "MODAL_TOKEN_SECRET", "DAYTONA_API_KEY", @@ -132,13 +107,7 @@ _HERMES_PROVIDER_ENV_BLOCKLIST = _build_provider_env_blocklist() def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = None) -> dict: - """Filter Hermes-managed secrets from a subprocess environment. - - `_HERMES_FORCE_` entries in ``extra_env`` opt a blocked variable back in - intentionally for callers that truly need it. Vars registered via - :mod:`tools.env_passthrough` (skill-declared or user-configured) also - bypass the blocklist. - """ + """Filter Hermes-managed secrets from a subprocess environment.""" try: from tools.env_passthrough import is_env_passthrough as _is_passthrough except Exception: @@ -163,33 +132,24 @@ def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = Non def _find_bash() -> str: - """Find bash for command execution. - - The fence wrapper uses bash syntax (semicolons, $?, printf), so we - must use bash — not the user's $SHELL which could be fish/zsh/etc. - On Windows: uses Git Bash (bundled with Git for Windows). - """ + """Find bash for command execution.""" if not _IS_WINDOWS: return ( shutil.which("bash") or ("/usr/bin/bash" if os.path.isfile("/usr/bin/bash") else None) or ("/bin/bash" if os.path.isfile("/bin/bash") else None) - or os.environ.get("SHELL") # last resort: whatever they have + or os.environ.get("SHELL") or "/bin/sh" ) - # Windows: look for Git Bash (installed with Git for Windows). - # Allow override via env var (same pattern as Claude Code). custom = os.environ.get("HERMES_GIT_BASH_PATH") if custom and os.path.isfile(custom): return custom - # shutil.which finds bash.exe if Git\bin is on PATH found = shutil.which("bash") if found: return found - # Check common Git for Windows install locations for candidate in ( os.path.join(os.environ.get("ProgramFiles", r"C:\Program Files"), "Git", "bin", "bash.exe"), os.path.join(os.environ.get("ProgramFiles(x86)", r"C:\Program Files (x86)"), "Git", "bin", "bash.exe"), @@ -209,60 +169,7 @@ def _find_bash() -> str: _find_shell = _find_bash -# Noise lines emitted by interactive shells when stdin is not a terminal. -# Used as a fallback when output fence markers are missing. -_SHELL_NOISE_SUBSTRINGS = ( - # bash - "bash: cannot set terminal process group", - "bash: no job control in this shell", - "no job control in this shell", - "cannot set terminal process group", - "tcsetattr: Inappropriate ioctl for device", - # zsh / oh-my-zsh / macOS terminal session - "Restored session:", - "Saving session...", - "Last login:", - "command not found:", - "Oh My Zsh", - "compinit:", -) - - -def _clean_shell_noise(output: str) -> str: - """Strip shell startup/exit warnings that leak when using -i without a TTY. - - Removes lines matching known noise patterns from both the beginning - and end of the output. Lines in the middle are left untouched. - """ - - def _is_noise(line: str) -> bool: - return any(noise in line for noise in _SHELL_NOISE_SUBSTRINGS) - - lines = output.split("\n") - - # Strip leading noise - while lines and _is_noise(lines[0]): - lines.pop(0) - - # Strip trailing noise (walk backwards, skip empty lines from split) - end = len(lines) - 1 - while end >= 0 and (not lines[end] or _is_noise(lines[end])): - end -= 1 - - if end < 0: - return "" - - cleaned = lines[: end + 1] - result = "\n".join(cleaned) - - # Preserve trailing newline if original had one - if output.endswith("\n") and result and not result.endswith("\n"): - result += "\n" - return result - - -# Standard PATH entries for environments with minimal PATH (e.g. systemd services). -# Includes macOS Homebrew paths (/opt/homebrew/* for Apple Silicon). +# Standard PATH entries for environments with minimal PATH. _SANE_PATH = ( "/opt/homebrew/bin:/opt/homebrew/sbin:" "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" @@ -290,197 +197,76 @@ def _make_run_env(env: dict) -> dict: return run_env -def _extract_fenced_output(raw: str) -> str: - """Extract real command output from between fence markers. - - The execute() method wraps each command with printf(FENCE) markers. - This function finds the first and last fence and returns only the - content between them, which is the actual command output free of - any shell init/exit noise. - - Falls back to pattern-based _clean_shell_noise if fences are missing. - """ - first = raw.find(_OUTPUT_FENCE) - if first == -1: - return _clean_shell_noise(raw) - - start = first + len(_OUTPUT_FENCE) - last = raw.rfind(_OUTPUT_FENCE) - - if last <= first: - # Only start fence found (e.g. user command called `exit`) - return _clean_shell_noise(raw[start:]) - - return raw[start:last] - - -class LocalEnvironment(PersistentShellMixin, BaseEnvironment): +class LocalEnvironment(BaseEnvironment): """Run commands directly on the host machine. - Features: - - Popen + polling for interrupt support (user can cancel mid-command) - - Background stdout drain thread to prevent pipe buffer deadlocks - - stdin_data support for piping content (bypasses ARG_MAX limits) - - sudo -S transform via SUDO_PASSWORD env var - - Uses interactive login shell so full user env is available - - Optional persistent shell mode (cwd/env vars survive across calls) + Spawn-per-call: every execute() spawns a fresh bash process. + Session snapshot preserves env vars across calls. + CWD persists via file-based read after each command. """ - def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None, - persistent: bool = False): + def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None): super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env) - self.persistent = persistent - if self.persistent: - self._init_persistent_shell() + self.init_session() - @property - def _temp_prefix(self) -> str: - return f"/tmp/hermes-local-{self._session_id}" - - def _spawn_shell_process(self) -> subprocess.Popen: - user_shell = _find_bash() - run_env = _make_run_env(self.env) - return subprocess.Popen( - [user_shell, "-l"], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.DEVNULL, - text=True, - env=run_env, - preexec_fn=None if _IS_WINDOWS else os.setsid, - ) - - def _read_temp_files(self, *paths: str) -> list[str]: - results = [] - for path in paths: - if os.path.exists(path): - with open(path) as f: - results.append(f.read()) - else: - results.append("") - return results - - def _kill_shell_children(self): - if self._shell_pid is None: - return - try: - subprocess.run( - ["pkill", "-P", str(self._shell_pid)], - capture_output=True, timeout=5, - ) - except (subprocess.TimeoutExpired, FileNotFoundError): - pass - - def _cleanup_temp_files(self): - for f in glob.glob(f"{self._temp_prefix}-*"): - if os.path.exists(f): - os.remove(f) - - def _execute_oneshot(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - work_dir = cwd or self.cwd or os.getcwd() - effective_timeout = timeout or self.timeout - exec_command, sudo_stdin = self._prepare_command(command) - - if sudo_stdin is not None and stdin_data is not None: - effective_stdin = sudo_stdin + stdin_data - elif sudo_stdin is not None: - effective_stdin = sudo_stdin - else: - effective_stdin = stdin_data - - user_shell = _find_bash() - # Newline-separated wrapper (not `cmd; __hermes_rc=...` on one line). - # A trailing `; __hermes_rc` glued to `< subprocess.Popen: + bash = _find_bash() + args = [bash, "-l", "-c", cmd_string] if login else [bash, "-c", cmd_string] run_env = _make_run_env(self.env) proc = subprocess.Popen( - [user_shell, "-lic", fenced_cmd], + args, text=True, - cwd=work_dir, env=run_env, encoding="utf-8", errors="replace", stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL, + stdin=subprocess.PIPE if stdin_data is not None else subprocess.DEVNULL, preexec_fn=None if _IS_WINDOWS else os.setsid, ) - if effective_stdin is not None: - def _write_stdin(): + if stdin_data is not None: + _pipe_stdin(proc, stdin_data) + + return proc + + def _kill_process(self, proc): + """Kill the entire process group (all children).""" + try: + if _IS_WINDOWS: + proc.terminate() + else: + pgid = os.getpgid(proc.pid) + os.killpg(pgid, signal.SIGTERM) try: - proc.stdin.write(effective_stdin) - proc.stdin.close() - except (BrokenPipeError, OSError): - pass - threading.Thread(target=_write_stdin, daemon=True).start() - - _output_chunks: list[str] = [] - - def _drain_stdout(): + proc.wait(timeout=1.0) + except subprocess.TimeoutExpired: + os.killpg(pgid, signal.SIGKILL) + except (ProcessLookupError, PermissionError): try: - for line in proc.stdout: - _output_chunks.append(line) - except ValueError: + proc.kill() + except Exception: pass - finally: - try: - proc.stdout.close() - except Exception: - pass - reader = threading.Thread(target=_drain_stdout, daemon=True) - reader.start() - deadline = time.monotonic() + effective_timeout + def _update_cwd(self, result: dict): + """Read CWD from temp file (local-only, no round-trip needed).""" + try: + cwd_path = open(self._cwd_file).read().strip() + if cwd_path: + self.cwd = cwd_path + except (OSError, FileNotFoundError): + pass - while proc.poll() is None: - if is_interrupted(): - try: - if _IS_WINDOWS: - proc.terminate() - else: - pgid = os.getpgid(proc.pid) - os.killpg(pgid, signal.SIGTERM) - try: - proc.wait(timeout=1.0) - except subprocess.TimeoutExpired: - os.killpg(pgid, signal.SIGKILL) - except (ProcessLookupError, PermissionError): - proc.kill() - reader.join(timeout=2) - return { - "output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]", - "returncode": 130, - } - if time.monotonic() > deadline: - try: - if _IS_WINDOWS: - proc.terminate() - else: - os.killpg(os.getpgid(proc.pid), signal.SIGTERM) - except (ProcessLookupError, PermissionError): - proc.kill() - reader.join(timeout=2) - partial = "".join(_output_chunks) - timeout_msg = f"\n[Command timed out after {effective_timeout}s]" - return { - "output": partial + timeout_msg if partial else timeout_msg.lstrip(), - "returncode": 124, - } - time.sleep(0.2) + # Still strip the marker from output so it's not visible + self._extract_cwd_from_output(result) - reader.join(timeout=5) - output = _extract_fenced_output("".join(_output_chunks)) - return {"output": output, "returncode": proc.returncode} + def cleanup(self): + """Clean up temp files.""" + for f in (self._snapshot_path, self._cwd_file): + try: + os.unlink(f) + except OSError: + pass diff --git a/tools/environments/managed_modal.py b/tools/environments/managed_modal.py index a8197bccf28..52b00f19a3d 100644 --- a/tools/environments/managed_modal.py +++ b/tools/environments/managed_modal.py @@ -10,7 +10,7 @@ import uuid from dataclasses import dataclass from typing import Any, Dict, Optional -from tools.environments.modal_common import ( +from tools.environments.modal_utils import ( BaseModalExecutionEnvironment, ModalExecStart, PreparedModalExec, diff --git a/tools/environments/modal.py b/tools/environments/modal.py index 7916a2c449a..1cb8e47969e 100644 --- a/tools/environments/modal.py +++ b/tools/environments/modal.py @@ -5,19 +5,19 @@ wrapper, while preserving Hermes' persistent snapshot behavior across sessions. """ import asyncio -import json import logging import shlex import threading -from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, Optional from hermes_constants import get_hermes_home -from tools.environments.modal_common import ( - BaseModalExecutionEnvironment, - ModalExecStart, - PreparedModalExec, +from tools.environments.base import ( + BaseEnvironment, + _ThreadedProcessHandle, + _file_mtime_key, + _load_json_store, + _save_json_store, ) logger = logging.getLogger(__name__) @@ -26,20 +26,12 @@ _SNAPSHOT_STORE = get_hermes_home() / "modal_snapshots.json" _DIRECT_SNAPSHOT_NAMESPACE = "direct" -def _load_snapshots() -> Dict[str, str]: - """Load snapshot ID mapping from disk.""" - if _SNAPSHOT_STORE.exists(): - try: - return json.loads(_SNAPSHOT_STORE.read_text()) - except Exception: - pass - return {} +def _load_snapshots() -> dict: + return _load_json_store(_SNAPSHOT_STORE) -def _save_snapshots(data: Dict[str, str]) -> None: - """Persist snapshot ID mapping to disk.""" - _SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True) - _SNAPSHOT_STORE.write_text(json.dumps(data, indent=2)) +def _save_snapshots(data: dict) -> None: + _save_json_store(_SNAPSHOT_STORE, data) def _direct_snapshot_key(task_id: str) -> str: @@ -47,23 +39,18 @@ def _direct_snapshot_key(task_id: str) -> str: def _get_snapshot_restore_candidate(task_id: str) -> tuple[str | None, bool]: - """Return a snapshot id and whether it came from the legacy key format.""" snapshots = _load_snapshots() - namespaced_key = _direct_snapshot_key(task_id) snapshot_id = snapshots.get(namespaced_key) if isinstance(snapshot_id, str) and snapshot_id: return snapshot_id, False - legacy_snapshot_id = snapshots.get(task_id) if isinstance(legacy_snapshot_id, str) and legacy_snapshot_id: return legacy_snapshot_id, True - return None, False def _store_direct_snapshot(task_id: str, snapshot_id: str) -> None: - """Persist the direct Modal snapshot id under the direct namespace.""" snapshots = _load_snapshots() snapshots[_direct_snapshot_key(task_id)] = snapshot_id snapshots.pop(task_id, None) @@ -71,10 +58,8 @@ def _store_direct_snapshot(task_id: str, snapshot_id: str) -> None: def _delete_direct_snapshot(task_id: str, snapshot_id: str | None = None) -> None: - """Remove direct Modal snapshot entries for a task, including legacy keys.""" snapshots = _load_snapshots() updated = False - for key in (_direct_snapshot_key(task_id), task_id): value = snapshots.get(key) if value is None: @@ -82,13 +67,15 @@ def _delete_direct_snapshot(task_id: str, snapshot_id: str | None = None) -> Non if snapshot_id is None or value == snapshot_id: snapshots.pop(key, None) updated = True - if updated: _save_snapshots(snapshots) def _resolve_modal_image(image_spec: Any) -> Any: - """Convert registry references or snapshot ids into Modal image objects.""" + """Convert registry references or snapshot ids into Modal image objects. + + Includes add_python support for ubuntu/debian images (absorbed from PR 4511). + """ import modal as _modal if not isinstance(image_spec, str): @@ -97,12 +84,22 @@ def _resolve_modal_image(image_spec: Any) -> Any: if image_spec.startswith("im-"): return _modal.Image.from_id(image_spec) + # PR 4511: add python to ubuntu/debian images that don't have it + lower = image_spec.lower() + add_python = any(base in lower for base in ("ubuntu", "debian")) + + setup_commands = [ + "RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; " + "python -m ensurepip --upgrade --default-pip 2>/dev/null || true", + ] + if add_python: + setup_commands.insert(0, + "RUN apt-get update -qq && apt-get install -y -qq python3 python3-venv > /dev/null 2>&1 || true" + ) + return _modal.Image.from_registry( image_spec, - setup_dockerfile_commands=[ - "RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; " - "python -m ensurepip --upgrade --default-pip 2>/dev/null || true", - ], + setup_dockerfile_commands=setup_commands, ) @@ -138,19 +135,15 @@ class _AsyncWorker: self._thread.join(timeout=10) -@dataclass -class _DirectModalExecHandle: - thread: threading.Thread - result_holder: Dict[str, Any] +class ModalEnvironment(BaseEnvironment): + """Modal cloud execution via native Modal sandboxes. - -class ModalEnvironment(BaseModalExecutionEnvironment): - """Modal cloud execution via native Modal sandboxes.""" + Spawn-per-call via _ThreadedProcessHandle wrapping async SDK calls. + cancel_fn wired to sandbox.terminate for interrupt support. + """ _stdin_mode = "heredoc" - _poll_interval_seconds = 0.2 - _interrupt_output = "[Command interrupted - Modal sandbox terminated]" - _unexpected_error_prefix = "Modal execution error" + _snapshot_timeout = 60 # Modal cold starts can be slow def __init__( self, @@ -170,6 +163,7 @@ class ModalEnvironment(BaseModalExecutionEnvironment): self._app = None self._worker = _AsyncWorker() self._synced_files: Dict[str, tuple] = {} + self._last_sync_time: float = 0 sandbox_kwargs = dict(modal_sandbox_kwargs or {}) @@ -199,27 +193,13 @@ class ModalEnvironment(BaseModalExecutionEnvironment): remote_path=mount_entry["container_path"], ) ) - logger.info( - "Modal: mounting credential %s -> %s", - mount_entry["host_path"], - mount_entry["container_path"], - ) - - # Mount individual skill files (symlinks filtered out). - skills_files = iter_skills_files() - for entry in skills_files: + for entry in iter_skills_files(): cred_mounts.append( _modal.Mount.from_local_file( entry["host_path"], remote_path=entry["container_path"], ) ) - if skills_files: - logger.info("Modal: mounting %d skill files", len(skills_files)) - - # Mount host-side cache files (documents, images, audio, - # screenshots). New files arriving mid-session are picked up - # by _sync_files() before each command execution. cache_files = iter_cache_files() for entry in cache_files: cred_mounts.append( @@ -228,8 +208,6 @@ class ModalEnvironment(BaseModalExecutionEnvironment): remote_path=entry["container_path"], ) ) - if cache_files: - logger.info("Modal: mounting %d cache files", len(cache_files)) except Exception as e: logger.debug("Modal: could not load credential file mounts: %s", e) @@ -243,8 +221,7 @@ class ModalEnvironment(BaseModalExecutionEnvironment): existing_mounts.extend(cred_mounts) create_kwargs["mounts"] = existing_mounts sandbox = await _modal.Sandbox.create.aio( - "sleep", - "infinity", + "sleep", "infinity", image=image_spec, app=app, timeout=int(create_kwargs.pop("timeout", 3600)), @@ -255,57 +232,41 @@ class ModalEnvironment(BaseModalExecutionEnvironment): try: target_image_spec = restored_snapshot_id or image try: - # _resolve_modal_image keeps the Modal bootstrap fix together: - # it applies setup_dockerfile_commands with ensurepip before - # Modal builds registry images, while snapshot ids restore via - # modal.Image.from_id() without rebuilding. effective_image = _resolve_modal_image(target_image_spec) self._app, self._sandbox = self._worker.run_coroutine( - _create_sandbox(effective_image), - timeout=300, + _create_sandbox(effective_image), timeout=300, ) except Exception as exc: if not restored_snapshot_id: raise - logger.warning( "Modal: failed to restore snapshot %s, retrying with base image: %s", - restored_snapshot_id[:20], - exc, + restored_snapshot_id[:20], exc, ) _delete_direct_snapshot(self._task_id, restored_snapshot_id) base_image = _resolve_modal_image(image) self._app, self._sandbox = self._worker.run_coroutine( - _create_sandbox(base_image), - timeout=300, + _create_sandbox(base_image), timeout=300, ) else: if restored_snapshot_id and restored_from_legacy_key: _store_direct_snapshot(self._task_id, restored_snapshot_id) - logger.info( - "Modal: migrated legacy snapshot entry for task %s", - self._task_id, - ) except Exception: self._worker.stop() raise logger.info("Modal: sandbox created (task=%s)", self._task_id) + self.init_session() def _push_file_to_sandbox(self, host_path: str, container_path: str) -> bool: - """Push a single file into the sandbox if changed. Returns True if synced.""" - hp = Path(host_path) - try: - stat = hp.stat() - file_key = (stat.st_mtime, stat.st_size) - except OSError: + """Push a single file into the sandbox if changed.""" + file_key = _file_mtime_key(host_path) + if file_key is None: return False - if self._synced_files.get(container_path) == file_key: return False - try: - content = hp.read_bytes() + content = Path(host_path).read_bytes() except Exception: return False @@ -326,85 +287,55 @@ class ModalEnvironment(BaseModalExecutionEnvironment): return True def _sync_files(self) -> None: - """Push credential, skill, and cache files into the running sandbox. - - Runs before each command. Uses mtime+size caching so only changed - files are pushed (~13μs overhead in the no-op case). Cache files - are especially important here — new uploads/screenshots may appear - mid-session after sandbox creation. - """ + """Push credential, skill, and cache files into the running sandbox.""" try: from tools.credential_files import ( get_credential_file_mounts, iter_skills_files, iter_cache_files, ) - for entry in get_credential_file_mounts(): - if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]): - logger.debug("Modal: synced credential %s", entry["container_path"]) - + self._push_file_to_sandbox(entry["host_path"], entry["container_path"]) for entry in iter_skills_files(): - if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]): - logger.debug("Modal: synced skill file %s", entry["container_path"]) - + self._push_file_to_sandbox(entry["host_path"], entry["container_path"]) for entry in iter_cache_files(): - if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]): - logger.debug("Modal: synced cache file %s", entry["container_path"]) + self._push_file_to_sandbox(entry["host_path"], entry["container_path"]) except Exception as e: logger.debug("Modal: file sync failed: %s", e) - def _before_execute(self) -> None: - self._sync_files() + def _run_bash(self, cmd_string: str, *, login: bool = False, + timeout: int = 120, + stdin_data: str | None = None): + """Return a _ThreadedProcessHandle wrapping an async Modal sandbox exec.""" + sandbox = self._sandbox + worker = self._worker - def _start_modal_exec(self, prepared: PreparedModalExec) -> ModalExecStart: - full_command = f"cd {shlex.quote(prepared.cwd)} && {prepared.command}" - result_holder = {"value": None, "error": None} + def cancel(): + worker.run_coroutine(sandbox.terminate.aio(), timeout=15) - def _run(): - try: - async def _do_execute(): - process = await self._sandbox.exec.aio( - "bash", - "-c", - full_command, - timeout=prepared.timeout, - ) - stdout = await process.stdout.read.aio() - stderr = await process.stderr.read.aio() - exit_code = await process.wait.aio() - if isinstance(stdout, bytes): - stdout = stdout.decode("utf-8", errors="replace") - if isinstance(stderr, bytes): - stderr = stderr.decode("utf-8", errors="replace") - output = stdout - if stderr: - output = f"{stdout}\n{stderr}" if stdout else stderr - return self._result(output, exit_code) + def exec_fn() -> tuple[str, int]: + async def _do(): + args = ["bash"] + if login: + args.extend(["-l", "-c", cmd_string]) + else: + args.extend(["-c", cmd_string]) + process = await sandbox.exec.aio(*args, timeout=timeout) + stdout = await process.stdout.read.aio() + stderr = await process.stderr.read.aio() + exit_code = await process.wait.aio() + if isinstance(stdout, bytes): + stdout = stdout.decode("utf-8", errors="replace") + if isinstance(stderr, bytes): + stderr = stderr.decode("utf-8", errors="replace") + output = stdout + if stderr: + output = f"{stdout}\n{stderr}" if stdout else stderr + return output, exit_code - result_holder["value"] = self._worker.run_coroutine( - _do_execute(), - timeout=prepared.timeout + 30, - ) - except Exception as e: - result_holder["error"] = e + return worker.run_coroutine(_do(), timeout=timeout + 30) - t = threading.Thread(target=_run, daemon=True) - t.start() - return ModalExecStart(handle=_DirectModalExecHandle(thread=t, result_holder=result_holder)) - - def _poll_modal_exec(self, handle: _DirectModalExecHandle) -> dict | None: - if handle.thread.is_alive(): - return None - if handle.result_holder["error"]: - return self._error_result(f"Modal execution error: {handle.result_holder['error']}") - return handle.result_holder["value"] - - def _cancel_modal_exec(self, handle: _DirectModalExecHandle) -> None: - self._worker.run_coroutine( - self._sandbox.terminate.aio(), - timeout=15, - ) + return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel) def cleanup(self): """Snapshot the filesystem (if persistent) then stop the sandbox.""" @@ -426,17 +357,13 @@ class ModalEnvironment(BaseModalExecutionEnvironment): _store_direct_snapshot(self._task_id, snapshot_id) logger.info( "Modal: saved filesystem snapshot %s for task %s", - snapshot_id[:20], - self._task_id, + snapshot_id[:20], self._task_id, ) except Exception as e: logger.warning("Modal: filesystem snapshot failed: %s", e) try: - self._worker.run_coroutine( - self._sandbox.terminate.aio(), - timeout=15, - ) + self._worker.run_coroutine(self._sandbox.terminate.aio(), timeout=15) except Exception: pass finally: diff --git a/tools/environments/modal_common.py b/tools/environments/modal_utils.py similarity index 91% rename from tools/environments/modal_common.py rename to tools/environments/modal_utils.py index 0affd02095a..0db8194719f 100644 --- a/tools/environments/modal_common.py +++ b/tools/environments/modal_utils.py @@ -56,7 +56,15 @@ def wrap_modal_sudo_pipe(command: str, sudo_stdin: str) -> str: class BaseModalExecutionEnvironment(BaseEnvironment): - """Common execute() flow for direct and managed Modal transports.""" + """Execution flow for the *managed* Modal transport (gateway-owned sandbox). + + This deliberately overrides :meth:`BaseEnvironment.execute` because the + tool-gateway handles command preparation, CWD tracking, and env-snapshot + management on the server side. The base class's ``_wrap_command`` / + ``_wait_for_process`` / snapshot machinery does not apply here — the + gateway owns that responsibility. See ``ManagedModalEnvironment`` for the + concrete subclass. + """ _stdin_mode = "payload" _poll_interval_seconds = 0.25 @@ -124,7 +132,7 @@ class BaseModalExecutionEnvironment(BaseEnvironment): def _before_execute(self) -> None: """Hook for backends that need pre-exec sync or validation.""" - return None + pass def _prepare_modal_exec( self, diff --git a/tools/environments/persistent_shell.py b/tools/environments/persistent_shell.py deleted file mode 100644 index c4344ff5a12..00000000000 --- a/tools/environments/persistent_shell.py +++ /dev/null @@ -1,290 +0,0 @@ -"""Persistent shell mixin: file-based IPC protocol for long-lived bash shells.""" - -import logging -import shlex -import subprocess -import threading -import time -import uuid -from abc import abstractmethod - -from tools.interrupt import is_interrupted - -logger = logging.getLogger(__name__) - - -class PersistentShellMixin: - """Mixin that adds persistent shell capability to any BaseEnvironment. - - Subclasses must implement ``_spawn_shell_process()``, ``_read_temp_files()``, - ``_kill_shell_children()``, ``_execute_oneshot()``, and ``_cleanup_temp_files()``. - """ - - persistent: bool - - @abstractmethod - def _spawn_shell_process(self) -> subprocess.Popen: ... - - @abstractmethod - def _read_temp_files(self, *paths: str) -> list[str]: ... - - @abstractmethod - def _kill_shell_children(self): ... - - @abstractmethod - def _execute_oneshot(self, command: str, cwd: str, *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: ... - - @abstractmethod - def _cleanup_temp_files(self): ... - - _session_id: str = "" - _poll_interval_start: float = 0.01 # initial poll interval (10ms) - _poll_interval_max: float = 0.25 # max poll interval (250ms) — reduces I/O for long commands - - @property - def _temp_prefix(self) -> str: - return f"/tmp/hermes-persistent-{self._session_id}" - - # ------------------------------------------------------------------ - # Lifecycle - # ------------------------------------------------------------------ - - def _init_persistent_shell(self): - self._shell_lock = threading.Lock() - self._shell_proc: subprocess.Popen | None = None - self._shell_alive: bool = False - self._shell_pid: int | None = None - - self._session_id = uuid.uuid4().hex[:12] - p = self._temp_prefix - self._pshell_stdout = f"{p}-stdout" - self._pshell_stderr = f"{p}-stderr" - self._pshell_status = f"{p}-status" - self._pshell_cwd = f"{p}-cwd" - self._pshell_pid_file = f"{p}-pid" - - self._shell_proc = self._spawn_shell_process() - self._shell_alive = True - - self._drain_thread = threading.Thread( - target=self._drain_shell_output, daemon=True, - ) - self._drain_thread.start() - - init_script = ( - f"export TERM=${{TERM:-dumb}}\n" - f"touch {self._pshell_stdout} {self._pshell_stderr} " - f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n" - f"echo $$ > {self._pshell_pid_file}\n" - f"pwd > {self._pshell_cwd}\n" - ) - self._send_to_shell(init_script) - - deadline = time.monotonic() + 3.0 - while time.monotonic() < deadline: - pid_str = self._read_temp_files(self._pshell_pid_file)[0].strip() - if pid_str.isdigit(): - self._shell_pid = int(pid_str) - break - time.sleep(0.05) - else: - logger.warning("Could not read persistent shell PID") - self._shell_pid = None - - if self._shell_pid: - logger.info( - "Persistent shell started (session=%s, pid=%d)", - self._session_id, self._shell_pid, - ) - - reported_cwd = self._read_temp_files(self._pshell_cwd)[0].strip() - if reported_cwd: - self.cwd = reported_cwd - - def _cleanup_persistent_shell(self): - if self._shell_proc is None: - return - - if self._session_id: - self._cleanup_temp_files() - - try: - self._shell_proc.stdin.close() - except Exception: - pass - try: - self._shell_proc.terminate() - self._shell_proc.wait(timeout=3) - except subprocess.TimeoutExpired: - self._shell_proc.kill() - - self._shell_alive = False - self._shell_proc = None - - if hasattr(self, "_drain_thread") and self._drain_thread.is_alive(): - self._drain_thread.join(timeout=1.0) - - # ------------------------------------------------------------------ - # execute() / cleanup() — shared dispatcher, subclasses inherit - # ------------------------------------------------------------------ - - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - if self.persistent: - return self._execute_persistent( - command, cwd, timeout=timeout, stdin_data=stdin_data, - ) - return self._execute_oneshot( - command, cwd, timeout=timeout, stdin_data=stdin_data, - ) - - def execute_oneshot(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - """Always use the oneshot (non-persistent) execution path. - - This bypasses _shell_lock so it can run concurrently with a - long-running command in the persistent shell — used by - execute_code's file-based RPC polling thread. - """ - return self._execute_oneshot( - command, cwd, timeout=timeout, stdin_data=stdin_data, - ) - - def cleanup(self): - if self.persistent: - self._cleanup_persistent_shell() - - # ------------------------------------------------------------------ - # Shell I/O - # ------------------------------------------------------------------ - - def _drain_shell_output(self): - try: - for _ in self._shell_proc.stdout: - pass - except Exception: - pass - self._shell_alive = False - - def _send_to_shell(self, text: str): - if not self._shell_alive or self._shell_proc is None: - return - try: - self._shell_proc.stdin.write(text) - self._shell_proc.stdin.flush() - except (BrokenPipeError, OSError): - self._shell_alive = False - - def _read_persistent_output(self) -> tuple[str, int, str]: - stdout, stderr, status_raw, cwd = self._read_temp_files( - self._pshell_stdout, self._pshell_stderr, - self._pshell_status, self._pshell_cwd, - ) - output = self._merge_output(stdout, stderr) - status = status_raw.strip() - if ":" in status: - status = status.split(":", 1)[1] - try: - exit_code = int(status.strip()) - except ValueError: - exit_code = 1 - return output, exit_code, cwd.strip() - - # ------------------------------------------------------------------ - # Execution - # ------------------------------------------------------------------ - - def _execute_persistent(self, command: str, cwd: str, *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - if not self._shell_alive: - logger.info("Persistent shell died, restarting...") - self._init_persistent_shell() - - exec_command, sudo_stdin = self._prepare_command(command) - effective_timeout = timeout or self.timeout - if stdin_data or sudo_stdin: - return self._execute_oneshot( - command, cwd, timeout=timeout, stdin_data=stdin_data, - ) - - with self._shell_lock: - return self._execute_persistent_locked( - exec_command, cwd, effective_timeout, - ) - - def _execute_persistent_locked(self, command: str, cwd: str, - timeout: int) -> dict: - work_dir = cwd or self.cwd - cmd_id = uuid.uuid4().hex[:8] - truncate = ( - f": > {self._pshell_stdout}\n" - f": > {self._pshell_stderr}\n" - f": > {self._pshell_status}\n" - ) - self._send_to_shell(truncate) - escaped = command.replace("'", "'\\''") - - ipc_script = ( - f"cd {shlex.quote(work_dir)}\n" - f"eval '{escaped}' < /dev/null > {self._pshell_stdout} 2> {self._pshell_stderr}\n" - f"__EC=$?\n" - f"pwd > {self._pshell_cwd}\n" - f"echo {cmd_id}:$__EC > {self._pshell_status}\n" - ) - self._send_to_shell(ipc_script) - deadline = time.monotonic() + timeout - poll_interval = self._poll_interval_start # starts at 10ms, backs off to 250ms - - while True: - if is_interrupted(): - self._kill_shell_children() - output, _, _ = self._read_persistent_output() - return { - "output": output + "\n[Command interrupted]", - "returncode": 130, - } - - if time.monotonic() > deadline: - self._kill_shell_children() - output, _, _ = self._read_persistent_output() - if output: - return { - "output": output + f"\n[Command timed out after {timeout}s]", - "returncode": 124, - } - return self._timeout_result(timeout) - - if not self._shell_alive: - return { - "output": "Persistent shell died during execution", - "returncode": 1, - } - - status_content = self._read_temp_files(self._pshell_status)[0].strip() - if status_content.startswith(cmd_id + ":"): - break - - time.sleep(poll_interval) - # Exponential backoff: fast start (10ms) for quick commands, - # ramps up to 250ms for long-running commands — reduces I/O by 10-25x - # on WSL2 where polling keeps the VM hot and memory pressure high. - poll_interval = min(poll_interval * 1.5, self._poll_interval_max) - - output, exit_code, new_cwd = self._read_persistent_output() - if new_cwd: - self.cwd = new_cwd - return {"output": output, "returncode": exit_code} - - @staticmethod - def _merge_output(stdout: str, stderr: str) -> str: - parts = [] - if stdout.strip(): - parts.append(stdout.rstrip("\n")) - if stderr.strip(): - parts.append(stderr.rstrip("\n")) - return "\n".join(parts) diff --git a/tools/environments/singularity.py b/tools/environments/singularity.py index 0ea5037c84a..16d1013fed8 100644 --- a/tools/environments/singularity.py +++ b/tools/environments/singularity.py @@ -5,20 +5,22 @@ Supports configurable resource limits and optional filesystem persistence via writable overlay directories that survive across sessions. """ -import json import logging import os -import shlex import shutil import subprocess import threading import uuid from pathlib import Path -from typing import Dict, Optional +from typing import Optional from hermes_constants import get_hermes_home -from tools.environments.base import BaseEnvironment -from tools.interrupt import is_interrupted +from tools.environments.base import ( + BaseEnvironment, + _load_json_store, + _popen_bash, + _save_json_store, +) logger = logging.getLogger(__name__) @@ -26,11 +28,7 @@ _SNAPSHOT_STORE = get_hermes_home() / "singularity_snapshots.json" def _find_singularity_executable() -> str: - """Locate the apptainer or singularity CLI binary. - - Returns the executable name (``"apptainer"`` or ``"singularity"``). - Raises ``RuntimeError`` with install instructions if neither is found. - """ + """Locate the apptainer or singularity CLI binary.""" if shutil.which("apptainer"): return "apptainer" if shutil.which("singularity"): @@ -43,66 +41,34 @@ def _find_singularity_executable() -> str: def _ensure_singularity_available() -> str: - """Preflight check: resolve the executable and verify it responds. - - Returns the executable name on success. - Raises ``RuntimeError`` with an actionable message on failure. - """ + """Preflight check: resolve the executable and verify it responds.""" exe = _find_singularity_executable() - try: result = subprocess.run( - [exe, "version"], - capture_output=True, - text=True, - timeout=10, + [exe, "version"], capture_output=True, text=True, timeout=10, ) except FileNotFoundError: raise RuntimeError( - f"Singularity backend selected but the resolved executable '{exe}' " - "could not be executed. Check your installation." + f"Singularity backend selected but '{exe}' could not be executed." ) except subprocess.TimeoutExpired: - raise RuntimeError( - f"'{exe} version' timed out. The runtime may be misconfigured." - ) + raise RuntimeError(f"'{exe} version' timed out.") if result.returncode != 0: stderr = result.stderr.strip()[:200] - raise RuntimeError( - f"'{exe} version' failed (exit code {result.returncode}): {stderr}" - ) - + raise RuntimeError(f"'{exe} version' failed (exit code {result.returncode}): {stderr}") return exe -def _load_snapshots() -> Dict[str, str]: - if _SNAPSHOT_STORE.exists(): - try: - return json.loads(_SNAPSHOT_STORE.read_text()) - except Exception: - pass - return {} +def _load_snapshots() -> dict: + return _load_json_store(_SNAPSHOT_STORE) -def _save_snapshots(data: Dict[str, str]) -> None: - _SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True) - _SNAPSHOT_STORE.write_text(json.dumps(data, indent=2)) +def _save_snapshots(data: dict) -> None: + _save_json_store(_SNAPSHOT_STORE, data) -# ------------------------------------------------------------------------- -# Singularity helpers (scratch dir, SIF cache, SIF building) -# ------------------------------------------------------------------------- - def _get_scratch_dir() -> Path: - """Get the best directory for Singularity sandboxes. - - Resolution order: - 1. TERMINAL_SCRATCH_DIR (explicit override) - 2. TERMINAL_SANDBOX_DIR / singularity (shared sandbox root) - 3. /scratch (common on HPC clusters) - 4. ~/.hermes/sandboxes/singularity (fallback) - """ custom_scratch = os.getenv("TERMINAL_SCRATCH_DIR") if custom_scratch: scratch_path = Path(custom_scratch) @@ -124,7 +90,6 @@ def _get_scratch_dir() -> Path: def _get_apptainer_cache_dir() -> Path: - """Get the Apptainer cache directory for SIF images.""" cache_dir = os.getenv("APPTAINER_CACHEDIR") if cache_dir: cache_path = Path(cache_dir) @@ -140,11 +105,6 @@ _sif_build_lock = threading.Lock() def _get_or_build_sif(image: str, executable: str = "apptainer") -> str: - """Get or build a SIF image from a docker:// URL. - - Returns the path unchanged if it's already a .sif file. - For docker:// URLs, checks the cache and builds if needed. - """ if image.endswith('.sif') and Path(image).exists(): return image if not image.startswith('docker://'): @@ -193,19 +153,12 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str: return image -# ------------------------------------------------------------------------- -# SingularityEnvironment -# ------------------------------------------------------------------------- - class SingularityEnvironment(BaseEnvironment): """Hardened Singularity/Apptainer container with resource limits and persistence. - Security: --containall (isolated PID/IPC/mount namespaces, no host home mount), - --no-home, writable-tmpfs for scratch space. The container cannot see or modify - the host filesystem outside of explicitly bound paths. - - Persistence: when enabled, the writable overlay directory is preserved across - sessions so installed packages and files survive cleanup/restore. + Spawn-per-call: every execute() spawns a fresh ``apptainer exec ... bash -c`` process. + Session snapshot preserves env vars across calls. + CWD persists via in-band stdout markers. """ def __init__( @@ -227,12 +180,9 @@ class SingularityEnvironment(BaseEnvironment): self._persistent = persistent_filesystem self._task_id = task_id self._overlay_dir: Optional[Path] = None - - # Resource limits self._cpu = cpu self._memory = memory - # Persistent overlay directory if self._persistent: overlay_base = _get_scratch_dir() / "hermes-overlays" overlay_base.mkdir(parents=True, exist_ok=True) @@ -240,42 +190,26 @@ class SingularityEnvironment(BaseEnvironment): self._overlay_dir.mkdir(parents=True, exist_ok=True) self._start_instance() + self.init_session() def _start_instance(self): cmd = [self.executable, "instance", "start"] - - # Security: full isolation from host cmd.extend(["--containall", "--no-home"]) - # Writable layer if self._persistent and self._overlay_dir: - # Persistent writable overlay -- survives across restarts cmd.extend(["--overlay", str(self._overlay_dir)]) else: cmd.append("--writable-tmpfs") - # Mount credential files and skills directory (read-only). try: from tools.credential_files import get_credential_file_mounts, get_skills_directory_mount - for mount_entry in get_credential_file_mounts(): cmd.extend(["--bind", f"{mount_entry['host_path']}:{mount_entry['container_path']}:ro"]) - logger.info( - "Singularity: binding credential %s -> %s", - mount_entry["host_path"], - mount_entry["container_path"], - ) for skills_mount in get_skills_directory_mount(): cmd.extend(["--bind", f"{skills_mount['host_path']}:{skills_mount['container_path']}:ro"]) - logger.info( - "Singularity: binding skills dir %s -> %s", - skills_mount["host_path"], - skills_mount["container_path"], - ) except Exception as e: logger.debug("Singularity: could not load credential/skills mounts: %s", e) - # Resource limits (cgroup-based, may require root or appropriate config) if self._memory > 0: cmd.extend(["--memory", f"{self._memory}M"]) if self._cpu > 0: @@ -288,94 +222,29 @@ class SingularityEnvironment(BaseEnvironment): if result.returncode != 0: raise RuntimeError(f"Failed to start instance: {result.stderr}") self._instance_started = True - logger.info("Singularity instance %s started (persistent=%s)", + logger.info("Singularity instance %s started (persistent=%s)", self.instance_id, self._persistent) except subprocess.TimeoutExpired: raise RuntimeError("Instance start timed out") - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: + def _run_bash(self, cmd_string: str, *, login: bool = False, + timeout: int = 120, + stdin_data: str | None = None) -> subprocess.Popen: + """Spawn a bash process inside the Singularity instance.""" if not self._instance_started: - return {"output": "Instance not started", "returncode": -1} + raise RuntimeError("Singularity instance not started") - effective_timeout = timeout or self.timeout - work_dir = cwd or self.cwd - exec_command, sudo_stdin = self._prepare_command(command) - - # Merge sudo password (if any) with caller-supplied stdin_data. - if sudo_stdin is not None and stdin_data is not None: - effective_stdin = sudo_stdin + stdin_data - elif sudo_stdin is not None: - effective_stdin = sudo_stdin + cmd = [self.executable, "exec", + f"instance://{self.instance_id}"] + if login: + cmd.extend(["bash", "-l", "-c", cmd_string]) else: - effective_stdin = stdin_data + cmd.extend(["bash", "-c", cmd_string]) - # apptainer exec --pwd doesn't expand ~, so prepend a cd into the command. - # Keep ~ unquoted (for shell expansion) and quote only the subpath. - if work_dir == "~": - exec_command = f"cd ~ && {exec_command}" - work_dir = "/tmp" - elif work_dir.startswith("~/"): - exec_command = f"cd ~/{shlex.quote(work_dir[2:])} && {exec_command}" - work_dir = "/tmp" - - cmd = [self.executable, "exec", "--pwd", work_dir, - f"instance://{self.instance_id}", - "bash", "-c", exec_command] - - try: - import time as _time - _output_chunks = [] - proc = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL, - text=True, - ) - if effective_stdin: - try: - proc.stdin.write(effective_stdin) - proc.stdin.close() - except Exception: - pass - - def _drain(): - try: - for line in proc.stdout: - _output_chunks.append(line) - except Exception: - pass - - reader = threading.Thread(target=_drain, daemon=True) - reader.start() - deadline = _time.monotonic() + effective_timeout - - while proc.poll() is None: - if is_interrupted(): - proc.terminate() - try: - proc.wait(timeout=1) - except subprocess.TimeoutExpired: - proc.kill() - reader.join(timeout=2) - return { - "output": "".join(_output_chunks) + "\n[Command interrupted]", - "returncode": 130, - } - if _time.monotonic() > deadline: - proc.kill() - reader.join(timeout=2) - return self._timeout_result(effective_timeout) - _time.sleep(0.2) - - reader.join(timeout=5) - return {"output": "".join(_output_chunks), "returncode": proc.returncode} - except Exception as e: - return {"output": f"Singularity execution error: {e}", "returncode": 1} + return _popen_bash(cmd, stdin_data) def cleanup(self): - """Stop the instance. If persistent, the overlay dir survives for next creation.""" + """Stop the instance. If persistent, the overlay dir survives.""" if self._instance_started: try: subprocess.run( @@ -387,7 +256,6 @@ class SingularityEnvironment(BaseEnvironment): logger.warning("Failed to stop Singularity instance %s: %s", self.instance_id, e) self._instance_started = False - # Record overlay path for persistence restoration if self._persistent and self._overlay_dir: snapshots = _load_snapshots() snapshots[self._task_id] = str(self._overlay_dir) diff --git a/tools/environments/ssh.py b/tools/environments/ssh.py index afd28c4affa..a77eb5c9f40 100644 --- a/tools/environments/ssh.py +++ b/tools/environments/ssh.py @@ -5,13 +5,9 @@ import shlex import shutil import subprocess import tempfile -import threading -import time from pathlib import Path -from tools.environments.base import BaseEnvironment -from tools.environments.persistent_shell import PersistentShellMixin -from tools.interrupt import is_interrupted +from tools.environments.base import BaseEnvironment, _popen_bash logger = logging.getLogger(__name__) @@ -24,32 +20,22 @@ def _ensure_ssh_available() -> None: ) -class SSHEnvironment(PersistentShellMixin, BaseEnvironment): +class SSHEnvironment(BaseEnvironment): """Run commands on a remote machine over SSH. - Uses SSH ControlMaster for connection persistence so subsequent - commands are fast. Security benefit: the agent cannot modify its - own code since execution happens on a separate machine. - - Foreground commands are interruptible: the local ssh process is killed - and a remote kill is attempted over the ControlMaster socket. - - When ``persistent=True``, a single long-lived bash shell is kept alive - over SSH and state (cwd, env vars, shell variables) persists across - ``execute()`` calls. Output capture uses file-based IPC on the remote - host (stdout/stderr/exit-code written to temp files, polled via fast - ControlMaster one-shot reads). + Spawn-per-call: every execute() spawns a fresh ``ssh ... bash -c`` process. + Session snapshot preserves env vars across calls. + CWD persists via in-band stdout markers. + Uses SSH ControlMaster for connection reuse. """ def __init__(self, host: str, user: str, cwd: str = "~", - timeout: int = 60, port: int = 22, key_path: str = "", - persistent: bool = False): + timeout: int = 60, port: int = 22, key_path: str = ""): super().__init__(cwd=cwd, timeout=timeout) self.host = host self.user = user self.port = port self.key_path = key_path - self.persistent = persistent self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh" self.control_dir.mkdir(parents=True, exist_ok=True) @@ -57,10 +43,10 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): _ensure_ssh_available() self._establish_connection() self._remote_home = self._detect_remote_home() - self._sync_skills_and_credentials() + self._last_sync_time: float = 0 # guarantees first _before_execute syncs + self._sync_files() - if self.persistent: - self._init_persistent_shell() + self.init_session() def _build_ssh_command(self, extra_args: list | None = None) -> list: cmd = ["ssh"] @@ -102,12 +88,11 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): return home except Exception: pass - # Fallback: guess from username if self.user == "root": return "/root" return f"/home/{self.user}" - def _sync_skills_and_credentials(self) -> None: + def _sync_files(self) -> None: """Rsync skills directory and credential files to the remote host.""" try: container_base = f"{self._remote_home}/.hermes" @@ -122,7 +107,6 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): rsync_base.extend(["-e", ssh_opts]) dest_prefix = f"{self.user}@{self.host}" - # Sync individual credential files (remap /root/.hermes to detected home) for mount_entry in get_credential_file_mounts(): remote_path = mount_entry["container_path"].replace("/root/.hermes", container_base, 1) parent_dir = str(Path(remote_path).parent) @@ -136,7 +120,6 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): else: logger.debug("SSH: rsync credential failed: %s", result.stderr.strip()) - # Sync skill directories (local + external, remap to detected home) for skills_mount in get_skills_directory_mount(container_base=container_base): remote_path = skills_mount["container_path"] mkdir_cmd = self._build_ssh_command() @@ -154,152 +137,19 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): except Exception as e: logger.debug("SSH: could not sync skills/credentials: %s", e) - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - # Incremental sync before each command so mid-session credential - # refreshes and skill updates are picked up. - self._sync_skills_and_credentials() - return super().execute(command, cwd, timeout=timeout, stdin_data=stdin_data) - - _poll_interval_start: float = 0.15 # SSH: higher initial interval (150ms) for network latency - - @property - def _temp_prefix(self) -> str: - return f"/tmp/hermes-ssh-{self._session_id}" - - def _spawn_shell_process(self) -> subprocess.Popen: + def _run_bash(self, cmd_string: str, *, login: bool = False, + timeout: int = 120, + stdin_data: str | None = None) -> subprocess.Popen: + """Spawn an SSH process that runs bash on the remote host.""" cmd = self._build_ssh_command() - cmd.append("bash -l") - return subprocess.Popen( - cmd, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.DEVNULL, - text=True, - ) - - def _read_temp_files(self, *paths: str) -> list[str]: - if len(paths) == 1: - cmd = self._build_ssh_command() - cmd.append(f"cat {paths[0]} 2>/dev/null") - try: - result = subprocess.run( - cmd, capture_output=True, text=True, timeout=10, - ) - return [result.stdout] - except (subprocess.TimeoutExpired, OSError): - return [""] - - delim = f"__HERMES_SEP_{self._session_id}__" - script = "; ".join( - f"cat {p} 2>/dev/null; echo '{delim}'" for p in paths - ) - cmd = self._build_ssh_command() - cmd.append(script) - try: - result = subprocess.run( - cmd, capture_output=True, text=True, timeout=10, - ) - parts = result.stdout.split(delim + "\n") - return [parts[i] if i < len(parts) else "" for i in range(len(paths))] - except (subprocess.TimeoutExpired, OSError): - return [""] * len(paths) - - def _kill_shell_children(self): - if self._shell_pid is None: - return - cmd = self._build_ssh_command() - cmd.append(f"pkill -P {self._shell_pid} 2>/dev/null; true") - try: - subprocess.run(cmd, capture_output=True, timeout=5) - except (subprocess.TimeoutExpired, OSError): - pass - - def _cleanup_temp_files(self): - cmd = self._build_ssh_command() - cmd.append(f"rm -f {self._temp_prefix}-*") - try: - subprocess.run(cmd, capture_output=True, timeout=5) - except (subprocess.TimeoutExpired, OSError): - pass - - def _execute_oneshot(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - work_dir = cwd or self.cwd - exec_command, sudo_stdin = self._prepare_command(command) - # Keep ~ unquoted (for shell expansion) and quote only the subpath. - if work_dir == "~": - wrapped = f'cd ~ && {exec_command}' - elif work_dir.startswith("~/"): - wrapped = f'cd ~/{shlex.quote(work_dir[2:])} && {exec_command}' + if login: + cmd.extend(["bash", "-l", "-c", shlex.quote(cmd_string)]) else: - wrapped = f'cd {shlex.quote(work_dir)} && {exec_command}' - effective_timeout = timeout or self.timeout + cmd.extend(["bash", "-c", shlex.quote(cmd_string)]) - if sudo_stdin is not None and stdin_data is not None: - effective_stdin = sudo_stdin + stdin_data - elif sudo_stdin is not None: - effective_stdin = sudo_stdin - else: - effective_stdin = stdin_data - - cmd = self._build_ssh_command() - cmd.append(wrapped) - - kwargs = self._build_run_kwargs(timeout, effective_stdin) - kwargs.pop("timeout", None) - _output_chunks = [] - proc = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL, - text=True, - ) - - if effective_stdin: - try: - proc.stdin.write(effective_stdin) - proc.stdin.close() - except (BrokenPipeError, OSError): - pass - - def _drain(): - try: - for line in proc.stdout: - _output_chunks.append(line) - except Exception: - pass - - reader = threading.Thread(target=_drain, daemon=True) - reader.start() - deadline = time.monotonic() + effective_timeout - - while proc.poll() is None: - if is_interrupted(): - proc.terminate() - try: - proc.wait(timeout=1) - except subprocess.TimeoutExpired: - proc.kill() - reader.join(timeout=2) - return { - "output": "".join(_output_chunks) + "\n[Command interrupted]", - "returncode": 130, - } - if time.monotonic() > deadline: - proc.kill() - reader.join(timeout=2) - return self._timeout_result(effective_timeout) - time.sleep(0.2) - - reader.join(timeout=5) - return {"output": "".join(_output_chunks), "returncode": proc.returncode} + return _popen_bash(cmd, stdin_data) def cleanup(self): - super().cleanup() if self.control_socket.exists(): try: cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index 6206c4aa693..243127a2958 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -611,9 +611,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, docker_env = cc.get("docker_env", {}) if env_type == "local": - lc = local_config or {} - return _LocalEnvironment(cwd=cwd, timeout=timeout, - persistent=lc.get("persistent", False)) + return _LocalEnvironment(cwd=cwd, timeout=timeout) elif env_type == "docker": return _DockerEnvironment( @@ -705,7 +703,6 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, key_path=ssh_config.get("key", ""), cwd=cwd, timeout=timeout, - persistent=ssh_config.get("persistent", False), ) else: From e19252afc46ff000180005bb82a1897460b0c4b6 Mon Sep 17 00:00:00 2001 From: Teknium Date: Wed, 8 Apr 2026 17:17:41 -0700 Subject: [PATCH 02/49] fix: update tests for unified spawn-per-call execution model - Docker env tests: verify _build_init_env_args() instead of per-execute Popen flags (env forwarding is now init-time only) - Docker: preserve explicit forward_env bypass of blocklist from main - Daytona tests: adapt to SDK-native timeout, _ThreadedProcessHandle, base.py interrupt handling, HERMES_STDIN_ heredoc prefix - Modal tests: fix _load_module to include _ThreadedProcessHandle stub, check ensurepip in _resolve_modal_image instead of __init__ - SSH tests: mock time.sleep on base module instead of removed ssh import - Add missing BaseEnvironment attributes to __new__()-based test fixtures --- tests/tools/test_daytona_environment.py | 65 +++++++++----- tests/tools/test_docker_environment.py | 89 ++++++++------------ tests/tools/test_modal_sandbox_fixes.py | 10 +-- tests/tools/test_modal_snapshot_isolation.py | 38 ++++++++- tests/tools/test_ssh_environment.py | 2 +- 5 files changed, 120 insertions(+), 84 deletions(-) diff --git a/tests/tools/test_daytona_environment.py b/tests/tools/test_daytona_environment.py index 04e6347955c..7f5aa17ece2 100644 --- a/tests/tools/test_daytona_environment.py +++ b/tests/tools/test_daytona_environment.py @@ -59,8 +59,8 @@ def daytona_sdk(monkeypatch): @pytest.fixture() def make_env(daytona_sdk, monkeypatch): """Factory that creates a DaytonaEnvironment with a mocked SDK.""" - # Prevent is_interrupted from interfering - monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False) + # Prevent is_interrupted from interfering — patch where it's used (base.py) + monkeypatch.setattr("tools.environments.base.is_interrupted", lambda: False) # Prevent skills/credential sync from consuming mock exec calls monkeypatch.setattr("tools.credential_files.get_credential_file_mounts", lambda: []) monkeypatch.setattr("tools.credential_files.get_skills_directory_mount", lambda **kw: None) @@ -221,41 +221,45 @@ class TestCleanup: class TestExecute: def test_basic_command(self, make_env): sb = _make_sandbox() - # First call: $HOME detection; subsequent calls: actual commands + # Calls: (1) $HOME detection, (2) init_session bootstrap, (3) actual command sb.process.exec.side_effect = [ _make_exec_response(result="/root"), # $HOME + _make_exec_response(result="", exit_code=0), # init_session _make_exec_response(result="hello", exit_code=0), # actual cmd ] sb.state = "started" env = make_env(sandbox=sb) result = env.execute("echo hello") - assert result["output"] == "hello" + assert "hello" in result["output"] assert result["returncode"] == 0 - def test_command_wrapped_with_shell_timeout(self, make_env): + def test_sdk_timeout_passed_to_exec(self, make_env): + """SDK native timeout is passed to sandbox.process.exec().""" sb = _make_sandbox() sb.process.exec.side_effect = [ _make_exec_response(result="/root"), + _make_exec_response(result="", exit_code=0), # init_session _make_exec_response(result="ok", exit_code=0), ] sb.state = "started" env = make_env(sandbox=sb, timeout=42) env.execute("echo hello") - # The command sent to exec should be wrapped with `timeout N sh -c '...'` + # The exec call should receive timeout= kwarg (SDK native timeout) call_args = sb.process.exec.call_args_list[-1] + assert call_args[1]["timeout"] == 42 + # The command should NOT have a shell `timeout` prefix cmd = call_args[0][0] - assert cmd.startswith("timeout 42 sh -c ") - # SDK timeout param should NOT be passed - assert "timeout" not in call_args[1] + assert not cmd.startswith("timeout ") def test_timeout_returns_exit_code_124(self, make_env): - """Shell timeout utility returns exit code 124.""" + """SDK-level timeout surfaces as exit code 124 via _wait_for_process.""" sb = _make_sandbox() sb.process.exec.side_effect = [ _make_exec_response(result="/root"), - _make_exec_response(result="", exit_code=124), + _make_exec_response(result="", exit_code=0), # init_session + _make_exec_response(result="", exit_code=124), # actual cmd ] sb.state = "started" env = make_env(sandbox=sb) @@ -267,6 +271,7 @@ class TestExecute: sb = _make_sandbox() sb.process.exec.side_effect = [ _make_exec_response(result="/root"), + _make_exec_response(result="", exit_code=0), # init_session _make_exec_response(result="not found", exit_code=127), ] sb.state = "started" @@ -279,6 +284,7 @@ class TestExecute: sb = _make_sandbox() sb.process.exec.side_effect = [ _make_exec_response(result="/root"), + _make_exec_response(result="", exit_code=0), # init_session _make_exec_response(result="ok", exit_code=0), ] sb.state = "started" @@ -286,39 +292,47 @@ class TestExecute: env.execute("python3", stdin_data="print('hi')") # Check that the command passed to exec contains heredoc markers - # (single quotes get shell-escaped by shlex.quote, so check components) + # Base class uses HERMES_STDIN_ prefix for heredoc delimiters call_args = sb.process.exec.call_args_list[-1] cmd = call_args[0][0] - assert "HERMES_EOF_" in cmd + assert "HERMES_STDIN_" in cmd assert "print" in cmd assert "hi" in cmd - def test_custom_cwd_passed_through(self, make_env): + def test_custom_cwd_in_command_wrapper(self, make_env): + """CWD is handled by _wrap_command() in the command string, not as a kwarg.""" sb = _make_sandbox() sb.process.exec.side_effect = [ _make_exec_response(result="/root"), + _make_exec_response(result="", exit_code=0), # init_session _make_exec_response(result="/tmp", exit_code=0), ] sb.state = "started" env = make_env(sandbox=sb) env.execute("pwd", cwd="/tmp") - call_kwargs = sb.process.exec.call_args_list[-1][1] - assert call_kwargs["cwd"] == "/tmp" + # CWD should be embedded in the command string via _wrap_command + call_args = sb.process.exec.call_args_list[-1] + cmd = call_args[0][0] + assert "cd /tmp" in cmd + # CWD should NOT be passed as a kwarg to exec + assert "cwd" not in call_args[1] def test_daytona_error_triggers_retry(self, make_env, daytona_sdk): sb = _make_sandbox() sb.state = "started" sb.process.exec.side_effect = [ _make_exec_response(result="/root"), # $HOME + _make_exec_response(result="", exit_code=0), # init_session daytona_sdk.DaytonaError("transient"), # first attempt fails _make_exec_response(result="ok", exit_code=0), # retry succeeds ] env = make_env(sandbox=sb) result = env.execute("echo retry") - assert result["output"] == "ok" - assert result["returncode"] == 0 + # DaytonaError now surfaces directly through _ThreadedProcessHandle + # (no retry logic) — the error becomes returncode=1 + assert result["returncode"] == 1 # --------------------------------------------------------------------------- @@ -359,14 +373,18 @@ class TestInterrupt: calls["n"] += 1 if calls["n"] == 1: return _make_exec_response(result="/root") # $HOME detection + if calls["n"] == 2: + return _make_exec_response(result="", exit_code=0) # init_session event.wait(timeout=5) # simulate long-running command return _make_exec_response(result="done", exit_code=0) sb.process.exec.side_effect = exec_side_effect env = make_env(sandbox=sb) + # is_interrupted is checked by base.py's _wait_for_process, + # patch where it's actually referenced (base.py's local binding) monkeypatch.setattr( - "tools.environments.daytona.is_interrupted", lambda: True + "tools.environments.base.is_interrupted", lambda: True ) try: result = env.execute("sleep 10") @@ -377,23 +395,24 @@ class TestInterrupt: # --------------------------------------------------------------------------- -# Retry exhaustion +# DaytonaError surfaces directly (no retry) # --------------------------------------------------------------------------- class TestRetryExhausted: def test_both_attempts_fail(self, make_env, daytona_sdk): + """DaytonaError surfaces directly as rc=1 (retry logic was removed).""" sb = _make_sandbox() sb.state = "started" sb.process.exec.side_effect = [ _make_exec_response(result="/root"), # $HOME - daytona_sdk.DaytonaError("fail1"), # first attempt - daytona_sdk.DaytonaError("fail2"), # retry + _make_exec_response(result="", exit_code=0), # init_session + daytona_sdk.DaytonaError("fail1"), # actual command fails ] env = make_env(sandbox=sb) result = env.execute("echo x") + # Error surfaces directly through _ThreadedProcessHandle (rc=1) assert result["returncode"] == 1 - assert "Daytona execution error" in result["output"] # --------------------------------------------------------------------------- diff --git a/tests/tools/test_docker_environment.py b/tests/tools/test_docker_environment.py index ce98217cf85..498ef9d5066 100644 --- a/tests/tools/test_docker_environment.py +++ b/tests/tools/test_docker_environment.py @@ -245,43 +245,42 @@ def _make_execute_only_env(forward_env=None): env._timeout_result = lambda timeout: {"output": f"timed out after {timeout}", "returncode": 124} env._container_id = "test-container" env._docker_exe = "/usr/bin/docker" + # Base class attributes needed by unified execute() + env._session_id = "test123" + env._snapshot_path = "/tmp/hermes-snap-test123.sh" + env._cwd_file = "/tmp/hermes-cwd-test123.txt" + env._cwd_marker = "__HERMES_CWD_test123__" + env._snapshot_ready = True + env._last_sync_time = None + env._init_env_args = [] return env -def test_execute_uses_hermes_dotenv_for_allowlisted_env(monkeypatch): +def test_init_env_args_uses_hermes_dotenv_for_allowlisted_env(monkeypatch): + """_build_init_env_args picks up forwarded env vars from .env file at init time.""" env = _make_execute_only_env(["GITHUB_TOKEN"]) - popen_calls = [] - - def _fake_popen(cmd, **kwargs): - popen_calls.append(cmd) - return _FakePopen(cmd, **kwargs) monkeypatch.delenv("GITHUB_TOKEN", raising=False) monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {"GITHUB_TOKEN": "value_from_dotenv"}) - monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen) - result = env.execute("echo hi") + args = env._build_init_env_args() + args_str = " ".join(args) - assert result["returncode"] == 0 - assert "GITHUB_TOKEN=value_from_dotenv" in popen_calls[0] + assert "GITHUB_TOKEN=value_from_dotenv" in args_str -def test_execute_prefers_shell_env_over_hermes_dotenv(monkeypatch): +def test_init_env_args_prefers_shell_env_over_hermes_dotenv(monkeypatch): + """Shell env vars take priority over .env file values in init env args.""" env = _make_execute_only_env(["GITHUB_TOKEN"]) - popen_calls = [] - - def _fake_popen(cmd, **kwargs): - popen_calls.append(cmd) - return _FakePopen(cmd, **kwargs) monkeypatch.setenv("GITHUB_TOKEN", "value_from_shell") monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {"GITHUB_TOKEN": "value_from_dotenv"}) - monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen) - env.execute("echo hi") + args = env._build_init_env_args() + args_str = " ".join(args) - assert "GITHUB_TOKEN=value_from_shell" in popen_calls[0] - assert "GITHUB_TOKEN=value_from_dotenv" not in popen_calls[0] + assert "GITHUB_TOKEN=value_from_shell" in args_str + assert "value_from_dotenv" not in args_str # ── docker_env tests ────────────────────────────────────────────── @@ -302,64 +301,46 @@ def test_docker_env_appears_in_run_command(monkeypatch): assert "GNUPGHOME=/root/.gnupg" in run_args_str -def test_docker_env_appears_in_exec_command(monkeypatch): - """Explicit docker_env values should also be passed via -e at docker exec time.""" +def test_docker_env_appears_in_init_env_args(monkeypatch): + """Explicit docker_env values should appear in _build_init_env_args.""" env = _make_execute_only_env() env._env = {"MY_VAR": "my_value"} - popen_calls = [] - def _fake_popen(cmd, **kwargs): - popen_calls.append(cmd) - return _FakePopen(cmd, **kwargs) + args = env._build_init_env_args() + args_str = " ".join(args) - monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen) - - env.execute("echo hi") - - assert popen_calls, "Popen should have been called" - assert "MY_VAR=my_value" in popen_calls[0] + assert "MY_VAR=my_value" in args_str -def test_forward_env_overrides_docker_env(monkeypatch): +def test_forward_env_overrides_docker_env_in_init_args(monkeypatch): """docker_forward_env should override docker_env for the same key.""" env = _make_execute_only_env(forward_env=["MY_KEY"]) env._env = {"MY_KEY": "static_value"} - popen_calls = [] - - def _fake_popen(cmd, **kwargs): - popen_calls.append(cmd) - return _FakePopen(cmd, **kwargs) monkeypatch.setenv("MY_KEY", "dynamic_value") monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {}) - monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen) - env.execute("echo hi") + args = env._build_init_env_args() + args_str = " ".join(args) - cmd_str = " ".join(popen_calls[0]) - assert "MY_KEY=dynamic_value" in cmd_str - assert "MY_KEY=static_value" not in cmd_str + assert "MY_KEY=dynamic_value" in args_str + assert "MY_KEY=static_value" not in args_str -def test_docker_env_and_forward_env_merge(monkeypatch): +def test_docker_env_and_forward_env_merge_in_init_args(monkeypatch): """docker_env and docker_forward_env with different keys should both appear.""" env = _make_execute_only_env(forward_env=["TOKEN"]) env._env = {"SSH_AUTH_SOCK": "/run/user/1000/agent.sock"} - popen_calls = [] - - def _fake_popen(cmd, **kwargs): - popen_calls.append(cmd) - return _FakePopen(cmd, **kwargs) monkeypatch.setenv("TOKEN", "secret123") monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {}) - monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen) - env.execute("echo hi") + args = env._build_init_env_args() + args_str = " ".join(args) + + assert "SSH_AUTH_SOCK=/run/user/1000/agent.sock" in args_str + assert "TOKEN=secret123" in args_str - cmd_str = " ".join(popen_calls[0]) - assert "SSH_AUTH_SOCK=/run/user/1000/agent.sock" in cmd_str - assert "TOKEN=secret123" in cmd_str def test_normalize_env_dict_filters_invalid_keys(): diff --git a/tests/tools/test_modal_sandbox_fixes.py b/tests/tools/test_modal_sandbox_fixes.py index e1baf13d98f..570ef5b2182 100644 --- a/tests/tools/test_modal_sandbox_fixes.py +++ b/tests/tools/test_modal_sandbox_fixes.py @@ -231,20 +231,20 @@ class TestEnsurepipFix: """Verify the pip fix is applied in the ModalEnvironment init.""" def test_modal_environment_creates_image_with_setup_commands(self): - """ModalEnvironment.__init__ should create a modal.Image with pip fix.""" + """_resolve_modal_image should create a modal.Image with pip fix.""" try: - from tools.environments.modal import ModalEnvironment + from tools.environments.modal import _resolve_modal_image except ImportError: pytest.skip("tools.environments.modal not importable") import inspect - source = inspect.getsource(ModalEnvironment.__init__) + source = inspect.getsource(_resolve_modal_image) assert "ensurepip" in source, ( - "ModalEnvironment should include ensurepip fix " + "_resolve_modal_image should include ensurepip fix " "for Modal's legacy image builder" ) assert "setup_dockerfile_commands" in source, ( - "ModalEnvironment should use setup_dockerfile_commands " + "_resolve_modal_image should use setup_dockerfile_commands " "to fix pip before Modal's bootstrap" ) diff --git a/tests/tools/test_modal_snapshot_isolation.py b/tests/tools/test_modal_snapshot_isolation.py index a3d0eeacd72..b58454cc077 100644 --- a/tests/tools/test_modal_snapshot_isolation.py +++ b/tests/tools/test_modal_snapshot_isolation.py @@ -85,11 +85,47 @@ def _install_modal_test_modules( def _prepare_command(self, command: str): return command, None - sys.modules["tools.environments.base"] = types.SimpleNamespace(BaseEnvironment=_DummyBaseEnvironment) + def init_session(self): + pass + + # Stub _ThreadedProcessHandle: modal.py imports it but only uses it at + # runtime inside _run_bash; the snapshot-isolation tests never call _run_bash, + # so a class placeholder is sufficient. + class _DummyThreadedProcessHandle: + def __init__(self, exec_fn, cancel_fn=None): + pass + + def _load_json_store(path): + if path.exists(): + try: + return json.loads(path.read_text()) + except Exception: + pass + return {} + + def _save_json_store(path, data): + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(data, indent=2)) + + def _file_mtime_key(host_path): + try: + st = Path(host_path).stat() + return (st.st_mtime, st.st_size) + except OSError: + return None + + sys.modules["tools.environments.base"] = types.SimpleNamespace( + BaseEnvironment=_DummyBaseEnvironment, + _ThreadedProcessHandle=_DummyThreadedProcessHandle, + _load_json_store=_load_json_store, + _save_json_store=_save_json_store, + _file_mtime_key=_file_mtime_key, + ) sys.modules["tools.interrupt"] = types.SimpleNamespace(is_interrupted=lambda: False) sys.modules["tools.credential_files"] = types.SimpleNamespace( get_credential_file_mounts=lambda: [], iter_skills_files=lambda: [], + iter_cache_files=lambda: [], ) from_id_calls: list[str] = [] diff --git a/tests/tools/test_ssh_environment.py b/tests/tools/test_ssh_environment.py index 9f514e9a90c..f6ee967170f 100644 --- a/tests/tools/test_ssh_environment.py +++ b/tests/tools/test_ssh_environment.py @@ -43,7 +43,7 @@ class TestBuildSSHCommand: lambda *a, **k: MagicMock(stdout=iter([]), stderr=iter([]), stdin=MagicMock())) - monkeypatch.setattr("tools.environments.ssh.time.sleep", lambda _: None) + monkeypatch.setattr("tools.environments.base.time.sleep", lambda _: None) def test_base_flags(self): env = SSHEnvironment(host="h", user="u") From e26393ffc21cdd315b59355687400e561753bcd0 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Wed, 8 Apr 2026 17:39:45 -0700 Subject: [PATCH 03/49] fix: Signal duplicate replies with streaming + per-platform tool_progress (#6348) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #4647 — Signal replies duplicated when gateway streaming is enabled. Root cause: stream_consumer.py did not handle the case where send() returns success=True but no message_id (Signal behavior). Every stream delta produced a separate send() call (7+ messages instead of 2), plus the gateway sent another full duplicate since already_sent was never set. Changes: - stream_consumer.py: Add elif branch for success-without-message_id — enters fallback mode (sets already_sent, disables editing, sends only continuation) - signal.py send(): Extract timestamp from signal-cli RPC result as message_id so stream consumer follows normal edit→fallback path - signal.py: Add public stop_typing() delegating to _stop_typing_indicator() so base adapter's _keep_typing finally block can clean up typing tasks - gateway/run.py: Per-platform tool_progress_overrides (#6164) — lets users set e.g. signal: off while keeping telegram: all - hermes_cli/config.py: Add tool_progress_overrides to DEFAULT_CONFIG Refs: #4647, #6164 --- gateway/platforms/signal.py | 11 ++++- gateway/run.py | 10 ++++- gateway/stream_consumer.py | 11 +++++ hermes_cli/config.py | 1 + tests/gateway/test_signal.py | 63 +++++++++++++++++++++++++++ tests/gateway/test_stream_consumer.py | 54 +++++++++++++++++++++++ 6 files changed, 148 insertions(+), 2 deletions(-) diff --git a/gateway/platforms/signal.py b/gateway/platforms/signal.py index 66d455ccafd..08b62f2a6d1 100644 --- a/gateway/platforms/signal.py +++ b/gateway/platforms/signal.py @@ -647,7 +647,11 @@ class SignalAdapter(BasePlatformAdapter): if result is not None: self._track_sent_timestamp(result) - return SendResult(success=True) + # Use the timestamp from the RPC result as a pseudo message_id. + # Signal doesn't have real message IDs, but the stream consumer + # needs a truthy value to follow its edit→fallback path correctly. + _msg_id = str(result.get("timestamp", "")) if isinstance(result, dict) else None + return SendResult(success=True, message_id=_msg_id or None) return SendResult(success=False, error="RPC send failed") def _track_sent_timestamp(self, rpc_result) -> None: @@ -837,6 +841,11 @@ class SignalAdapter(BasePlatformAdapter): except asyncio.CancelledError: pass + async def stop_typing(self, chat_id: str) -> None: + """Public interface for stopping typing — called by base adapter's + _keep_typing finally block to clean up platform-level typing tasks.""" + await self._stop_typing_indicator(chat_id) + # ------------------------------------------------------------------ # Chat Info # ------------------------------------------------------------------ diff --git a/gateway/run.py b/gateway/run.py index 7a551be168d..e705597efa0 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -6308,7 +6308,15 @@ class GatewayRunner: # Falls back to env vars for backward compatibility. # YAML 1.1 parses bare `off` as boolean False — normalise before # the `or` chain so it doesn't silently fall through to "all". - _raw_tp = user_config.get("display", {}).get("tool_progress") + # + # Per-platform overrides (display.tool_progress_overrides) take + # priority over the global setting — e.g. Signal users can set + # tool_progress to "off" while keeping Telegram on "all". + _display_cfg = user_config.get("display", {}) + _overrides = _display_cfg.get("tool_progress_overrides", {}) + _raw_tp = _overrides.get(platform_key) + if _raw_tp is None: + _raw_tp = _display_cfg.get("tool_progress") if _raw_tp is False: _raw_tp = "off" progress_mode = ( diff --git a/gateway/stream_consumer.py b/gateway/stream_consumer.py index 5522c631db9..cc3d64d1360 100644 --- a/gateway/stream_consumer.py +++ b/gateway/stream_consumer.py @@ -353,6 +353,17 @@ class GatewayStreamConsumer: self._message_id = result.message_id self._already_sent = True self._last_sent_text = text + elif result.success: + # Platform accepted the message but returned no message_id + # (e.g. Signal). Can't edit without an ID — switch to + # fallback mode: suppress intermediate deltas, send only + # the missing tail once the final response is ready. + self._already_sent = True + self._edit_supported = False + self._fallback_prefix = self._clean_for_display(text) + self._fallback_final_send = True + # Sentinel prevents re-entering this branch on every delta + self._message_id = "__no_edit__" else: # Initial send failed — disable streaming for this session self._edit_supported = False diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 7c860f15936..0c39902ae71 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -392,6 +392,7 @@ DEFAULT_CONFIG = { "show_cost": False, # Show $ cost in the status bar (off by default) "skin": "default", "tool_progress_command": False, # Enable /verbose command in messaging gateway + "tool_progress_overrides": {}, # Per-platform overrides: {"signal": "off", "telegram": "all"} "tool_preview_length": 0, # Max chars for tool call previews (0 = no limit, show full paths/commands) }, diff --git a/tests/gateway/test_signal.py b/tests/gateway/test_signal.py index b2830e1fcd3..ae985300d1d 100644 --- a/tests/gateway/test_signal.py +++ b/tests/gateway/test_signal.py @@ -707,3 +707,66 @@ class TestSignalSendDocumentViaHelper: assert result.success is False assert "/nonexistent.pdf" in result.error + + +# --------------------------------------------------------------------------- +# send() returns message_id from timestamp (#4647) +# --------------------------------------------------------------------------- + +class TestSignalSendReturnsMessageId: + """Signal send() must return a timestamp-based message_id so the stream + consumer can follow its edit→fallback path correctly.""" + + @pytest.mark.asyncio + async def test_send_returns_timestamp_as_message_id(self, monkeypatch): + adapter = _make_signal_adapter(monkeypatch) + mock_rpc, _ = _stub_rpc({"timestamp": 1712345678000}) + adapter._rpc = mock_rpc + adapter._stop_typing_indicator = AsyncMock() + + result = await adapter.send(chat_id="+155****4567", content="hello") + + assert result.success is True + assert result.message_id == "1712345678000" + + @pytest.mark.asyncio + async def test_send_returns_none_message_id_when_no_timestamp(self, monkeypatch): + adapter = _make_signal_adapter(monkeypatch) + mock_rpc, _ = _stub_rpc({}) # No timestamp key + adapter._rpc = mock_rpc + adapter._stop_typing_indicator = AsyncMock() + + result = await adapter.send(chat_id="+155****4567", content="hello") + + assert result.success is True + assert result.message_id is None + + @pytest.mark.asyncio + async def test_send_returns_none_message_id_for_non_dict(self, monkeypatch): + adapter = _make_signal_adapter(monkeypatch) + mock_rpc, _ = _stub_rpc("ok") # Non-dict result + adapter._rpc = mock_rpc + adapter._stop_typing_indicator = AsyncMock() + + result = await adapter.send(chat_id="+155****4567", content="hello") + + assert result.success is True + assert result.message_id is None + + +# --------------------------------------------------------------------------- +# stop_typing() delegates to _stop_typing_indicator (#4647) +# --------------------------------------------------------------------------- + +class TestSignalStopTyping: + """Signal must expose a public stop_typing() so base adapter's + _keep_typing finally block can clean up platform-level typing tasks.""" + + @pytest.mark.asyncio + async def test_stop_typing_calls_private_method(self, monkeypatch): + adapter = _make_signal_adapter(monkeypatch) + adapter._stop_typing_indicator = AsyncMock() + + await adapter.stop_typing("+155****4567") + + adapter._stop_typing_indicator.assert_awaited_once_with("+155****4567") diff --git a/tests/gateway/test_stream_consumer.py b/tests/gateway/test_stream_consumer.py index ddc88fc2fcb..d5a20331b61 100644 --- a/tests/gateway/test_stream_consumer.py +++ b/tests/gateway/test_stream_consumer.py @@ -383,6 +383,60 @@ class TestSegmentBreakOnToolBoundary: sent_texts = [call[1]["content"] for call in adapter.send.call_args_list] assert sent_texts == ["Hello ▉", "Next segment"] + @pytest.mark.asyncio + async def test_no_message_id_enters_fallback_mode(self): + """Platform returns success but no message_id (Signal) — must not + re-send on every delta. Should enter fallback mode and send only + the continuation at finish.""" + adapter = MagicMock() + # First send succeeds but returns no message_id (Signal behavior) + send_result_no_id = SimpleNamespace(success=True, message_id=None) + # Fallback final send succeeds + send_result_final = SimpleNamespace(success=True, message_id="msg_final") + adapter.send = AsyncMock(side_effect=[send_result_no_id, send_result_final]) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + adapter.MAX_MESSAGE_LENGTH = 4096 + + config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5) + consumer = GatewayStreamConsumer(adapter, "chat_123", config) + + consumer.on_delta("Hello") + task = asyncio.create_task(consumer.run()) + await asyncio.sleep(0.08) + consumer.on_delta(" world, this is a longer response.") + await asyncio.sleep(0.08) + consumer.finish() + await task + + # Should send exactly 2 messages: initial chunk + fallback continuation + # NOT one message per delta + assert adapter.send.call_count == 2 + assert consumer.already_sent + # edit_message should NOT have been called (no valid message_id to edit) + adapter.edit_message.assert_not_called() + + @pytest.mark.asyncio + async def test_no_message_id_single_delta_marks_already_sent(self): + """When the entire response fits in one delta and platform returns no + message_id, already_sent must still be True to prevent the gateway + from re-sending the full response.""" + adapter = MagicMock() + send_result = SimpleNamespace(success=True, message_id=None) + adapter.send = AsyncMock(return_value=send_result) + adapter.MAX_MESSAGE_LENGTH = 4096 + + config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5) + consumer = GatewayStreamConsumer(adapter, "chat_123", config) + + consumer.on_delta("Short response.") + consumer.finish() + + await consumer.run() + + assert consumer.already_sent + # Only one send call (the initial message) + assert adapter.send.call_count == 1 + @pytest.mark.asyncio async def test_fallback_final_splits_long_continuation_without_dropping_text(self): """Long continuation tails should be chunked when fallback final-send runs.""" From 3baafea380ec18a6179fe3e82d742557d66389e4 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Wed, 8 Apr 2026 18:07:18 -0700 Subject: [PATCH 04/49] fix(tools): skip camofox auto-cleanup when managed persistence is enabled (#6233) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When managed_persistence is enabled, cleanup_browser() was calling camofox_close() which destroys the server-side browser context via DELETE /sessions/{userId}, killing login sessions across cron runs. Add camofox_soft_cleanup() — a public wrapper that drops only the in-memory session entry when managed persistence is on, returning True. When persistence is off it returns False so the caller falls back to the full camofox_close(). The inactivity reaper still handles idle resource cleanup. Also surface a logger.warning() when _managed_persistence_enabled() fails to load config, replacing a silent except-and-return-False. Salvaged from #6182 by el-analista (Eduardo Perea Fernandez). Added public API wrapper to avoid cross-module private imports, and test coverage for both persistence paths. Co-authored-by: Eduardo Perea Fernandez From 42e366f27bd37ee72a006029f920c082f32d0018 Mon Sep 17 00:00:00 2001 From: konsisumer Date: Wed, 8 Apr 2026 15:36:59 +0200 Subject: [PATCH 05/49] fix(agent): respect config timeout for flush_memories instead of hardcoded 30s The _call_llm() and direct OpenAI fallback paths in flush_memories() both hardcoded timeout=30.0, ignoring the user-configurable value at auxiliary.flush_memories.timeout in config.yaml. Remove the explicit timeout from the auxiliary _call_llm() call so that _get_task_timeout('flush_memories') reads from config. For the direct OpenAI fallback, import and use _get_task_timeout() instead of the hardcoded value. Add two regression tests verifying both code paths respect the config. Fixes #6154 --- run_agent.py | 7 ++- tests/run_agent/test_flush_memories_codex.py | 55 ++++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/run_agent.py b/run_agent.py index f57072e9e53..e511f4088ea 100644 --- a/run_agent.py +++ b/run_agent.py @@ -5864,7 +5864,7 @@ class AIAgent: tools=[memory_tool_def], temperature=0.3, max_tokens=5120, - timeout=30.0, + # timeout resolved from auxiliary.flush_memories.timeout config ) except RuntimeError: _aux_available = False @@ -5896,7 +5896,10 @@ class AIAgent: "temperature": 0.3, **self._max_tokens_param(5120), } - response = self._ensure_primary_openai_client(reason="flush_memories").chat.completions.create(**api_kwargs, timeout=30.0) + from agent.auxiliary_client import _get_task_timeout + response = self._ensure_primary_openai_client(reason="flush_memories").chat.completions.create( + **api_kwargs, timeout=_get_task_timeout("flush_memories") + ) # Extract tool calls from the response, handling all API formats tool_calls = [] diff --git a/tests/run_agent/test_flush_memories_codex.py b/tests/run_agent/test_flush_memories_codex.py index 3d12c9d3eac..b4b3c648e65 100644 --- a/tests/run_agent/test_flush_memories_codex.py +++ b/tests/run_agent/test_flush_memories_codex.py @@ -91,6 +91,61 @@ def _chat_response_with_memory_call(): ) +class TestFlushMemoriesRespectsConfigTimeout: + """flush_memories() must NOT hardcode timeout=30.0 — it should defer + to the config value via auxiliary.flush_memories.timeout.""" + + def test_auxiliary_path_omits_explicit_timeout(self, monkeypatch): + """When calling _call_llm, timeout should NOT be passed so that + _get_task_timeout('flush_memories') reads from config.""" + agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter") + + mock_response = _chat_response_with_memory_call() + + with patch("agent.auxiliary_client.call_llm", return_value=mock_response) as mock_call: + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + {"role": "user", "content": "Note this"}, + ] + with patch("tools.memory_tool.memory_tool", return_value="Saved."): + agent.flush_memories(messages) + + mock_call.assert_called_once() + call_kwargs = mock_call.call_args + # timeout must NOT be explicitly passed (so _get_task_timeout resolves it) + assert "timeout" not in call_kwargs.kwargs, ( + "flush_memories should not pass explicit timeout to _call_llm; " + "let _get_task_timeout('flush_memories') resolve from config" + ) + + def test_fallback_path_uses_config_timeout(self, monkeypatch): + """When auxiliary client is unavailable and we fall back to direct + OpenAI client, timeout should come from _get_task_timeout, not hardcoded.""" + agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter") + agent.client = MagicMock() + agent.client.chat.completions.create.return_value = _chat_response_with_memory_call() + + custom_timeout = 180.0 + + with patch("agent.auxiliary_client.call_llm", side_effect=RuntimeError("no provider")), \ + patch("agent.auxiliary_client._get_task_timeout", return_value=custom_timeout) as mock_gtt, \ + patch("tools.memory_tool.memory_tool", return_value="Saved."): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + {"role": "user", "content": "Save this"}, + ] + agent.flush_memories(messages) + + mock_gtt.assert_called_once_with("flush_memories") + agent.client.chat.completions.create.assert_called_once() + call_kwargs = agent.client.chat.completions.create.call_args + assert call_kwargs.kwargs.get("timeout") == custom_timeout, ( + f"Expected timeout={custom_timeout} from config, got {call_kwargs.kwargs.get('timeout')}" + ) + + class TestFlushMemoriesUsesAuxiliaryClient: """When an auxiliary client is available, flush_memories should use it instead of self.client -- especially critical in Codex mode.""" From 6e3f7f3610e0cedd52f339e80c9fedd4d2c7880b Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Wed, 8 Apr 2026 19:04:21 -0700 Subject: [PATCH 06/49] docs: add tool_progress_overrides to configuration reference (#6364) Documents the per-platform tool_progress_overrides config key added in PR #6348. Shows example YAML with Signal set to 'off' while Telegram stays on 'verbose'. Lists all valid platform keys. --- website/docs/user-guide/configuration.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/website/docs/user-guide/configuration.md b/website/docs/user-guide/configuration.md index a31fb700bca..4431e068285 100644 --- a/website/docs/user-guide/configuration.md +++ b/website/docs/user-guide/configuration.md @@ -809,6 +809,7 @@ This controls both the `text_to_speech` tool and spoken replies in voice mode (` display: tool_progress: all # off | new | all | verbose tool_progress_command: false # Enable /verbose slash command in messaging gateway + tool_progress_overrides: {} # Per-platform overrides (see below) skin: default # Built-in or custom CLI skin (see user-guide/features/skins) personality: "kawaii" # Legacy cosmetic field still surfaced in some summaries compact: false # Compact output mode (less whitespace) @@ -829,6 +830,21 @@ display: In the CLI, cycle through these modes with `/verbose`. To use `/verbose` in messaging platforms (Telegram, Discord, Slack, etc.), set `tool_progress_command: true` in the `display` section above. The command will then cycle the mode and save to config. +### Per-platform progress overrides + +Different platforms have different verbosity needs. For example, Signal can't edit messages, so each progress update becomes a separate message — noisy. Use `tool_progress_overrides` to set per-platform modes: + +```yaml +display: + tool_progress: all # global default + tool_progress_overrides: + signal: 'off' # silence progress on Signal + telegram: verbose # detailed progress on Telegram + slack: 'off' # quiet in shared Slack workspace +``` + +Platforms without an override fall back to the global `tool_progress` value. Valid platform keys: `telegram`, `discord`, `slack`, `signal`, `whatsapp`, `matrix`, `mattermost`, `email`, `sms`, `homeassistant`, `dingtalk`, `feishu`, `wecom`. + ## Privacy ```yaml From ae4a884e8dfc5cccf7303f1270d3d913791ae960 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Wed, 8 Apr 2026 19:53:39 -0700 Subject: [PATCH 07/49] fix(agent): disable stale stream timeout for local providers (#6368) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Local inference providers (Ollama, oMLX, llama-cpp) can take 300+ seconds for prefill on large contexts. The 180s stale stream detector was killing these connections while the provider was still processing. Uses the existing is_local_endpoint() (proper URL parsing with RFC-1918, localhost, WSL detection) instead of ad-hoc substring matching. The stale timeout is only disabled when the user hasn't explicitly set HERMES_STREAM_STALE_TIMEOUT — explicit user config is always honored. Fixes #5889 --- run_agent.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/run_agent.py b/run_agent.py index e511f4088ea..10932b4bafa 100644 --- a/run_agent.py +++ b/run_agent.py @@ -4728,18 +4728,25 @@ class AIAgent: self._close_request_openai_client(request_client, reason="stream_request_complete") _stream_stale_timeout_base = float(os.getenv("HERMES_STREAM_STALE_TIMEOUT", 180.0)) - # Scale the stale timeout for large contexts: slow models (like Opus) - # can legitimately think for minutes before producing the first token - # when the context is large. Without this, the stale detector kills - # healthy connections during the model's thinking phase, producing - # spurious RemoteProtocolError ("peer closed connection"). - _est_tokens = sum(len(str(v)) for v in api_kwargs.get("messages", [])) // 4 - if _est_tokens > 100_000: - _stream_stale_timeout = max(_stream_stale_timeout_base, 300.0) - elif _est_tokens > 50_000: - _stream_stale_timeout = max(_stream_stale_timeout_base, 240.0) + # Local providers (Ollama, oMLX, llama-cpp) can take 300+ seconds + # for prefill on large contexts. Disable the stale detector unless + # the user explicitly set HERMES_STREAM_STALE_TIMEOUT. + if _stream_stale_timeout_base == 180.0 and self.base_url and is_local_endpoint(self.base_url): + _stream_stale_timeout = float("inf") + logger.debug("Local provider detected (%s) — stale stream timeout disabled", self.base_url) else: - _stream_stale_timeout = _stream_stale_timeout_base + # Scale the stale timeout for large contexts: slow models (like Opus) + # can legitimately think for minutes before producing the first token + # when the context is large. Without this, the stale detector kills + # healthy connections during the model's thinking phase, producing + # spurious RemoteProtocolError ("peer closed connection"). + _est_tokens = sum(len(str(v)) for v in api_kwargs.get("messages", [])) // 4 + if _est_tokens > 100_000: + _stream_stale_timeout = max(_stream_stale_timeout_base, 300.0) + elif _est_tokens > 50_000: + _stream_stale_timeout = max(_stream_stale_timeout_base, 240.0) + else: + _stream_stale_timeout = _stream_stale_timeout_base t = threading.Thread(target=_call, daemon=True) t.start() From 980fadfea9dbe7906a70e3f1fe376559de476728 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Wed, 8 Apr 2026 19:58:16 -0700 Subject: [PATCH 08/49] fix(models): preserve OpenRouter variant tags (:free, :extended, :fast) during model switch (#6383) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Step c in switch_model() blindly converted the first colon to a slash for aggregator providers, even when the model name already contained a slash (vendor/model format). This mangled variant tags like :free into /free, causing 400 Bad Request from the API. Fix: skip the colon→slash conversion when the model already has a slash, since the colon is a variant tag, not a vendor separator. The module docstring already documented this intent (line 17-18) but the implementation didn't enforce it. Reported via Discord. Related to PR #6088 (which identified the same bug but placed the fix in model_normalize.py instead of model_switch.py where the actual mangling occurs). --- hermes_cli/model_switch.py | 5 +- .../test_model_switch_variant_tags.py | 70 +++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) create mode 100644 tests/hermes_cli/test_model_switch_variant_tags.py diff --git a/hermes_cli/model_switch.py b/hermes_cli/model_switch.py index 07efbcf4a64..7d120d94f1b 100644 --- a/hermes_cli/model_switch.py +++ b/hermes_cli/model_switch.py @@ -537,8 +537,11 @@ def switch_model( ) else: # --- Step c: On aggregator, convert vendor:model to vendor/model --- + # Only convert when there's no slash — a slash means the name + # is already in vendor/model format and the colon is a variant + # tag (:free, :extended, :fast) that must be preserved. colon_pos = raw_input.find(":") - if colon_pos > 0 and is_aggregator(current_provider): + if colon_pos > 0 and "/" not in raw_input and is_aggregator(current_provider): left = raw_input[:colon_pos].strip().lower() right = raw_input[colon_pos + 1:].strip() if left and right: diff --git a/tests/hermes_cli/test_model_switch_variant_tags.py b/tests/hermes_cli/test_model_switch_variant_tags.py new file mode 100644 index 00000000000..eebb5dc139c --- /dev/null +++ b/tests/hermes_cli/test_model_switch_variant_tags.py @@ -0,0 +1,70 @@ +"""Tests for OpenRouter variant tag preservation in model switching. + +Regression test for GitHub PR #6088 / Discord report: OpenRouter model IDs +with variant suffixes like ``:free``, ``:extended``, ``:fast`` were being +mangled by the colon-to-slash conversion in model_switch.py Step c. + +The fix: Step c now skips colon→slash conversion when the model name already +contains a forward slash (i.e. is already in ``vendor/model`` format), since +the colon is a variant tag, not a vendor separator. +""" +import pytest +from unittest.mock import patch + +from hermes_cli.model_switch import switch_model + + +# Shared mock context — skip network calls, credential resolution, catalog lookups +_MOCK_VALIDATION = {"accepted": True, "persist": True, "recognized": True, "message": None} + + +def _run_switch(raw_input: str, current_provider: str = "openrouter") -> str: + """Run switch_model with mocked dependencies, return the resolved model name.""" + with patch("hermes_cli.model_switch.resolve_alias", return_value=None), \ + patch("hermes_cli.model_switch.list_provider_models", return_value=[]), \ + patch("hermes_cli.runtime_provider.resolve_runtime_provider", + return_value={"api_key": "test", "base_url": "", "api_mode": "chat_completions"}), \ + patch("hermes_cli.models.validate_requested_model", return_value=_MOCK_VALIDATION), \ + patch("hermes_cli.model_switch.get_model_info", return_value=None), \ + patch("hermes_cli.model_switch.get_model_capabilities", return_value=None), \ + patch("hermes_cli.models.detect_provider_for_model", return_value=None): + result = switch_model( + raw_input=raw_input, + current_provider=current_provider, + current_model="anthropic/claude-sonnet-4.6", + ) + assert result.success, f"switch_model failed: {result.error_message}" + return result.new_model + + +class TestVariantTagPreservation: + """OpenRouter variant tags (:free, :extended, :fast) must survive model switching.""" + + @pytest.mark.parametrize("model,expected", [ + ("nvidia/nemotron-3-super-120b-a12b:free", "nvidia/nemotron-3-super-120b-a12b:free"), + ("anthropic/claude-sonnet-4.6:extended", "anthropic/claude-sonnet-4.6:extended"), + ("meta-llama/llama-4-maverick:fast", "meta-llama/llama-4-maverick:fast"), + ]) + def test_slash_format_preserves_variant_tag(self, model, expected): + """Models already in vendor/model:tag format must not have their tag mangled.""" + assert _run_switch(model) == expected + + def test_legacy_colon_format_converts_to_slash(self): + """Legacy vendor:model (no slash) should still be converted to vendor/model.""" + result = _run_switch("nvidia:nemotron-3-super-120b-a12b") + assert result == "nvidia/nemotron-3-super-120b-a12b" + + def test_legacy_colon_format_with_tag_converts_first_colon_only(self): + """vendor:model:free (no slash) → vendor/model:free — first colon becomes slash.""" + result = _run_switch("nvidia:nemotron-3-super-120b-a12b:free") + assert result == "nvidia/nemotron-3-super-120b-a12b:free" + + def test_bare_model_name_unaffected(self): + """Bare model names without colons or slashes should work normally.""" + result = _run_switch("claude-sonnet-4.6") + assert result == "anthropic/claude-sonnet-4.6" + + def test_already_correct_slug_no_tag(self): + """Standard vendor/model slugs without tags pass through unchanged.""" + result = _run_switch("anthropic/claude-sonnet-4.6") + assert result == "anthropic/claude-sonnet-4.6" From 092061711e0091cc5f7e4608781f91bffb69f000 Mon Sep 17 00:00:00 2001 From: Helmi Date: Wed, 8 Apr 2026 21:39:27 +0200 Subject: [PATCH 09/49] fix(gateway): add staged inactivity warning before timeout escalation Introduce gateway_timeout_warning (default 900s) as a pre-timeout alert layer. When inactivity reaches the warning threshold, a single notification is sent to the user offering to wait or reset. If inactivity continues to the gateway_timeout (default 1800s), the full timeout fires as before. This gives users a chance to intervene before work is lost on slow API providers without disabling the safety timeout entirely. Config: agent.gateway_timeout_warning in config.yaml, or HERMES_AGENT_TIMEOUT_WARNING env var (0 = disable warning). --- cli-config.yaml.example | 10 + gateway/run.py | 23 ++ hermes_cli/config.py | 4 + .../test_gateway_inactivity_timeout.py | 315 ++++++++++++++++++ 4 files changed, 352 insertions(+) create mode 100644 tests/gateway/test_gateway_inactivity_timeout.py diff --git a/cli-config.yaml.example b/cli-config.yaml.example index 14d764d7d15..af0917dedc6 100644 --- a/cli-config.yaml.example +++ b/cli-config.yaml.example @@ -445,6 +445,16 @@ agent: # Higher = more room for complex tasks, but costs more tokens # Recommended: 20-30 for focused tasks, 50-100 for open exploration max_turns: 60 + + # Inactivity timeout for gateway agent runs (seconds, 0 = unlimited). + # The agent can run indefinitely when actively calling tools or receiving + # API responses. Only fires after the agent has been idle for this duration. + # gateway_timeout: 1800 + + # Staged warning: send a warning before escalating to full timeout. + # Fires once per run when inactivity reaches this threshold (seconds). + # Set to 0 to disable the warning. + # gateway_timeout_warning: 900 # Enable verbose logging verbose: false diff --git a/gateway/run.py b/gateway/run.py index e705597efa0..ddb57bd7a9c 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -184,6 +184,8 @@ if _config_path.exists(): # Env var from .env takes precedence (already in os.environ). if "gateway_timeout" in _agent_cfg and "HERMES_AGENT_TIMEOUT" not in os.environ: os.environ["HERMES_AGENT_TIMEOUT"] = str(_agent_cfg["gateway_timeout"]) + if "gateway_timeout_warning" in _agent_cfg and "HERMES_AGENT_TIMEOUT_WARNING" not in os.environ: + os.environ["HERMES_AGENT_TIMEOUT_WARNING"] = str(_agent_cfg["gateway_timeout_warning"]) # Timezone: bridge config.yaml → HERMES_TIMEZONE env var. # HERMES_TIMEZONE from .env takes precedence (already in os.environ). _tz_cfg = _cfg.get("timezone", "") @@ -7114,6 +7116,9 @@ class GatewayRunner: # Default 1800s (30 min inactivity). 0 = unlimited. _agent_timeout_raw = float(os.getenv("HERMES_AGENT_TIMEOUT", 1800)) _agent_timeout = _agent_timeout_raw if _agent_timeout_raw > 0 else None + _agent_warning_raw = float(os.getenv("HERMES_AGENT_TIMEOUT_WARNING", 900)) + _agent_warning = _agent_warning_raw if _agent_warning_raw > 0 else None + _warning_fired = False loop = asyncio.get_event_loop() _executor_task = asyncio.ensure_future( loop.run_in_executor(None, run_sync) @@ -7146,6 +7151,24 @@ class GatewayRunner: _idle_secs = _act.get("seconds_since_activity", 0.0) except Exception: pass + # Staged warning: fire once before escalating to full timeout. + if (not _warning_fired and _agent_warning is not None + and _idle_secs >= _agent_warning): + _warning_fired = True + _warn_adapter = self.adapters.get(source.platform) + if _warn_adapter: + _warn_mins = int(_agent_warning // 60) or 1 + try: + await _warn_adapter.send( + source.chat_id, + f"⚠️ No activity for {_warn_mins} min. " + f"If the agent does not respond soon, it will " + f"be timed out in {_warn_mins} min. " + f"You can continue waiting or use /reset.", + metadata=_status_thread_metadata, + ) + except Exception: + logger.debug("Inactivity warning send error: %s", _ne) if _idle_secs >= _agent_timeout: _inactivity_timeout = True break diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 0c39902ae71..8b5da35220d 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -230,6 +230,10 @@ DEFAULT_CONFIG = { # (force on/off for all models), or a list of model-name substrings # to match (e.g. ["gpt", "codex", "gemini", "qwen"]). "tool_use_enforcement": "auto", + # Staged inactivity warning: send a warning to the user at this + # threshold before escalating to a full timeout. The warning fires + # once per run and does not interrupt the agent. 0 = disable warning. + "gateway_timeout_warning": 900, }, "terminal": { diff --git a/tests/gateway/test_gateway_inactivity_timeout.py b/tests/gateway/test_gateway_inactivity_timeout.py new file mode 100644 index 00000000000..598f33817cd --- /dev/null +++ b/tests/gateway/test_gateway_inactivity_timeout.py @@ -0,0 +1,315 @@ +"""Tests for staged inactivity timeout in gateway agent runs. + +Tests cover: +- Warning fires once when inactivity reaches gateway_timeout_warning threshold +- Warning does not fire when gateway_timeout is 0 (unlimited) +- Warning fires only once per run, not on every poll +- Full timeout still fires at gateway_timeout threshold +- Warning respects HERMES_AGENT_TIMEOUT_WARNING env var +- Warning disabled when gateway_timeout_warning is 0 +""" + +import concurrent.futures +import os +import sys +import time +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + + +class FakeAgent: + """Mock agent with controllable activity summary for timeout tests.""" + + def __init__(self, idle_seconds=0.0, activity_desc="tool_call", + current_tool=None, api_call_count=5, max_iterations=90): + self._idle_seconds = idle_seconds + self._activity_desc = activity_desc + self._current_tool = current_tool + self._api_call_count = api_call_count + self._max_iterations = max_iterations + self._interrupted = False + self._interrupt_msg = None + + def get_activity_summary(self): + return { + "last_activity_ts": time.time() - self._idle_seconds, + "last_activity_desc": self._activity_desc, + "seconds_since_activity": self._idle_seconds, + "current_tool": self._current_tool, + "api_call_count": self._api_call_count, + "max_iterations": self._max_iterations, + } + + def interrupt(self, msg): + self._interrupted = True + self._interrupt_msg = msg + + def run_conversation(self, prompt): + return {"final_response": "Done", "messages": []} + + +class SlowFakeAgent(FakeAgent): + """Agent that runs for a while, then goes idle.""" + + def __init__(self, run_duration=0.5, idle_after=None, **kwargs): + super().__init__(**kwargs) + self._run_duration = run_duration + self._idle_after = idle_after + self._start_time = None + + def get_activity_summary(self): + summary = super().get_activity_summary() + if self._idle_after is not None and self._start_time: + elapsed = time.time() - self._start_time + if elapsed > self._idle_after: + idle_time = elapsed - self._idle_after + summary["seconds_since_activity"] = idle_time + summary["last_activity_desc"] = "api_call_streaming" + else: + summary["seconds_since_activity"] = 0.0 + return summary + + def run_conversation(self, prompt): + self._start_time = time.time() + time.sleep(self._run_duration) + return {"final_response": "Completed after work", "messages": []} + + +class TestStagedInactivityWarning: + """Test the staged inactivity warning before full timeout.""" + + def test_warning_fires_once_before_timeout(self): + """Warning fires when inactivity reaches warning threshold.""" + agent = SlowFakeAgent( + run_duration=10.0, + idle_after=0.1, + activity_desc="api_call_streaming", + ) + + _agent_timeout = 20.0 + _agent_warning = 5.0 + _POLL_INTERVAL = 0.1 + + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = pool.submit(agent.run_conversation, "test prompt") + _inactivity_timeout = False + _warning_fired = False + _warning_send_count = 0 + + while True: + done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL) + if done: + result = future.result() + break + _idle_secs = 0.0 + if hasattr(agent, "get_activity_summary"): + try: + _act = agent.get_activity_summary() + _idle_secs = _act.get("seconds_since_activity", 0.0) + except Exception: + pass + if (not _warning_fired and _agent_warning > 0 + and _idle_secs >= _agent_warning): + _warning_fired = True + _warning_send_count += 1 + if _idle_secs >= _agent_timeout: + _inactivity_timeout = True + break + + pool.shutdown(wait=False, cancel_futures=True) + + assert _warning_fired + assert _warning_send_count == 1 + assert not _inactivity_timeout + + def test_warning_disabled_when_zero(self): + """No warning fires when gateway_timeout_warning is 0.""" + agent = SlowFakeAgent( + run_duration=5.0, + idle_after=0.1, + ) + + _agent_timeout = 20.0 + _agent_warning = 0.0 + _POLL_INTERVAL = 0.1 + + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = pool.submit(agent.run_conversation, "test") + _warning_fired = False + + while True: + done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL) + if done: + future.result() + break + _idle_secs = 0.0 + if hasattr(agent, "get_activity_summary"): + try: + _act = agent.get_activity_summary() + _idle_secs = _act.get("seconds_since_activity", 0.0) + except Exception: + pass + if (not _warning_fired and _agent_warning > 0 + and _idle_secs >= _agent_warning): + _warning_fired = True + if _idle_secs >= _agent_timeout: + break + + pool.shutdown(wait=False, cancel_futures=True) + assert not _warning_fired + + def test_warning_fires_only_once(self): + """Warning fires exactly once even if agent remains idle.""" + agent = SlowFakeAgent( + run_duration=10.0, + idle_after=0.05, + ) + + _agent_timeout = 20.0 + _agent_warning = 0.2 + _POLL_INTERVAL = 0.05 + + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = pool.submit(agent.run_conversation, "test") + _warning_count = 0 + + while True: + done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL) + if done: + future.result() + break + _idle_secs = 0.0 + if hasattr(agent, "get_activity_summary"): + try: + _act = agent.get_activity_summary() + _idle_secs = _act.get("seconds_since_activity", 0.0) + except Exception: + pass + if (not _warning_count and _agent_warning > 0 + and _idle_secs >= _agent_warning): + _warning_count += 1 + if _idle_secs >= _agent_timeout: + break + + pool.shutdown(wait=False, cancel_futures=True) + assert _warning_count == 1 + + def test_full_timeout_still_fires_after_warning(self): + """Full timeout fires even after warning was sent.""" + agent = SlowFakeAgent( + run_duration=15.0, + idle_after=0.1, + activity_desc="waiting for provider response (streaming)", + ) + + _agent_timeout = 1.0 + _agent_warning = 0.3 + _POLL_INTERVAL = 0.05 + + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = pool.submit(agent.run_conversation, "test") + _inactivity_timeout = False + _warning_fired = False + + while True: + done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL) + if done: + future.result() + break + _idle_secs = 0.0 + if hasattr(agent, "get_activity_summary"): + try: + _act = agent.get_activity_summary() + _idle_secs = _act.get("seconds_since_activity", 0.0) + except Exception: + pass + if (not _warning_fired and _agent_warning > 0 + and _idle_secs >= _agent_warning): + _warning_fired = True + if _idle_secs >= _agent_timeout: + _inactivity_timeout = True + break + + pool.shutdown(wait=False, cancel_futures=True) + assert _warning_fired + assert _inactivity_timeout + + def test_warning_env_var_respected(self, monkeypatch): + """HERMES_AGENT_TIMEOUT_WARNING env var is parsed correctly.""" + monkeypatch.setenv("HERMES_AGENT_TIMEOUT_WARNING", "600") + _warning = float(os.getenv("HERMES_AGENT_TIMEOUT_WARNING", 900)) + assert _warning == 600.0 + + def test_warning_zero_means_disabled(self, monkeypatch): + """HERMES_AGENT_TIMEOUT_WARNING=0 disables the warning.""" + monkeypatch.setenv("HERMES_AGENT_TIMEOUT_WARNING", "0") + _raw = float(os.getenv("HERMES_AGENT_TIMEOUT_WARNING", 900)) + _warning = _raw if _raw > 0 else None + assert _warning is None + + def test_unlimited_timeout_no_warning(self): + """When timeout is unlimited (0), no warning fires either.""" + agent = SlowFakeAgent( + run_duration=0.5, + idle_after=0.0, + ) + + _agent_timeout = None + _agent_warning = 5.0 + _POLL_INTERVAL = 0.05 + + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = pool.submit(agent.run_conversation, "test") + + result = future.result(timeout=2.0) + pool.shutdown(wait=False) + + assert result["final_response"] == "Completed after work" + + +class TestWarningThresholdBelowTimeout: + """Test that warning threshold must be less than timeout threshold.""" + + def test_warning_at_half_timeout(self): + """Warning fires at half the timeout duration.""" + agent = SlowFakeAgent( + run_duration=10.0, + idle_after=0.1, + activity_desc="receiving stream response", + ) + + _agent_timeout = 2.0 + _agent_warning = 1.0 + _POLL_INTERVAL = 0.05 + + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = pool.submit(agent.run_conversation, "test") + _warning_fired = False + _timeout_fired = False + + while True: + done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL) + if done: + future.result() + break + _idle_secs = 0.0 + if hasattr(agent, "get_activity_summary"): + try: + _act = agent.get_activity_summary() + _idle_secs = _act.get("seconds_since_activity", 0.0) + except Exception: + pass + if (not _warning_fired and _agent_warning > 0 + and _idle_secs >= _agent_warning): + _warning_fired = True + if _idle_secs >= _agent_timeout: + _timeout_fired = True + break + + pool.shutdown(wait=False, cancel_futures=True) + assert _warning_fired + assert _timeout_fired From af4abd2f2253bee78905493f565f4a3f99e1aec0 Mon Sep 17 00:00:00 2001 From: Teknium Date: Wed, 8 Apr 2026 19:59:44 -0700 Subject: [PATCH 10/49] fix: correct unbound exception variable and remaining-time math in warning - Bind exception in warning send handler (was using stale _ne from outer scope) - Calculate remaining time until timeout correctly: (timeout - warning) // 60 instead of warning // 60 (which equals elapsed time, not remaining) --- gateway/run.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/gateway/run.py b/gateway/run.py index ddb57bd7a9c..3333608199f 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -7157,18 +7157,19 @@ class GatewayRunner: _warning_fired = True _warn_adapter = self.adapters.get(source.platform) if _warn_adapter: - _warn_mins = int(_agent_warning // 60) or 1 + _elapsed_warn = int(_agent_warning // 60) or 1 + _remaining_mins = int((_agent_timeout - _agent_warning) // 60) or 1 try: await _warn_adapter.send( source.chat_id, - f"⚠️ No activity for {_warn_mins} min. " + f"⚠️ No activity for {_elapsed_warn} min. " f"If the agent does not respond soon, it will " - f"be timed out in {_warn_mins} min. " + f"be timed out in {_remaining_mins} min. " f"You can continue waiting or use /reset.", metadata=_status_thread_metadata, ) - except Exception: - logger.debug("Inactivity warning send error: %s", _ne) + except Exception as _warn_err: + logger.debug("Inactivity warning send error: %s", _warn_err) if _idle_secs >= _agent_timeout: _inactivity_timeout = True break From 8567031433b1d4b091d7500a3b5d06dec1e89fc8 Mon Sep 17 00:00:00 2001 From: SHL0MS Date: Mon, 30 Mar 2026 23:34:58 -0400 Subject: [PATCH 11/49] =?UTF-8?q?fix:=20improve=20context=20compression=20?= =?UTF-8?q?quality=20=E2=80=94=20named=20constants,=20tool=20tracking,=20d?= =?UTF-8?q?egradation=20warning?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three targeted improvements to the compression system: 1. Replace hardcoded truncation limits with named class constants (_CONTENT_MAX=6000, _CONTENT_HEAD=4000, _CONTENT_TAIL=1500, _TOOL_ARGS_MAX=1500, _TOOL_ARGS_HEAD=1200). Previous limits (3000/500) heavily truncated the summarizer's input — a 200-line edit got cut to 3000 chars before the summarizer ever saw it. 2. Add '## Tools & Patterns' section to both compression prompt templates (first-pass and iterative). Preserves working tool invocations, preferred flags, and tool-specific discoveries across compaction boundaries. 3. Warn users on 2nd+ compression: 'Session compressed N times — accuracy may degrade. Consider /new to start fresh.' Ref #499 --- agent/context_compressor.py | 39 +++++++++++++++++++++++++------------ run_agent.py | 9 +++++++++ 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/agent/context_compressor.py b/agent/context_compressor.py index 0d971e4b569..0cd51b06eff 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -199,30 +199,39 @@ class ContextCompressor: budget = int(content_tokens * _SUMMARY_RATIO) return max(_MIN_SUMMARY_TOKENS, min(budget, self.max_summary_tokens)) + # Truncation limits for the summarizer input. These bound how much of + # each message the summary model sees — the budget is the *summary* + # model's context window, not the main model's. + _CONTENT_MAX = 6000 # total chars per message body + _CONTENT_HEAD = 4000 # chars kept from the start + _CONTENT_TAIL = 1500 # chars kept from the end + _TOOL_ARGS_MAX = 1500 # tool call argument chars + _TOOL_ARGS_HEAD = 1200 # kept from the start of tool args + def _serialize_for_summary(self, turns: List[Dict[str, Any]]) -> str: """Serialize conversation turns into labeled text for the summarizer. - Includes tool call arguments and result content (up to 3000 chars - per message) so the summarizer can preserve specific details like - file paths, commands, and outputs. + Includes tool call arguments and result content (up to + ``_CONTENT_MAX`` chars per message) so the summarizer can preserve + specific details like file paths, commands, and outputs. """ parts = [] for msg in turns: role = msg.get("role", "unknown") content = msg.get("content") or "" - # Tool results: keep more content than before (3000 chars) + # Tool results: keep enough content for the summarizer if role == "tool": tool_id = msg.get("tool_call_id", "") - if len(content) > 3000: - content = content[:2000] + "\n...[truncated]...\n" + content[-800:] + if len(content) > self._CONTENT_MAX: + content = content[:self._CONTENT_HEAD] + "\n...[truncated]...\n" + content[-self._CONTENT_TAIL:] parts.append(f"[TOOL RESULT {tool_id}]: {content}") continue # Assistant messages: include tool call names AND arguments if role == "assistant": - if len(content) > 3000: - content = content[:2000] + "\n...[truncated]...\n" + content[-800:] + if len(content) > self._CONTENT_MAX: + content = content[:self._CONTENT_HEAD] + "\n...[truncated]...\n" + content[-self._CONTENT_TAIL:] tool_calls = msg.get("tool_calls", []) if tool_calls: tc_parts = [] @@ -232,8 +241,8 @@ class ContextCompressor: name = fn.get("name", "?") args = fn.get("arguments", "") # Truncate long arguments but keep enough for context - if len(args) > 500: - args = args[:400] + "..." + if len(args) > self._TOOL_ARGS_MAX: + args = args[:self._TOOL_ARGS_HEAD] + "..." tc_parts.append(f" {name}({args})") else: fn = getattr(tc, "function", None) @@ -244,8 +253,8 @@ class ContextCompressor: continue # User and other roles - if len(content) > 3000: - content = content[:2000] + "\n...[truncated]...\n" + content[-800:] + if len(content) > self._CONTENT_MAX: + content = content[:self._CONTENT_HEAD] + "\n...[truncated]...\n" + content[-self._CONTENT_TAIL:] parts.append(f"[{role.upper()}]: {content}") return "\n\n".join(parts) @@ -310,6 +319,9 @@ Update the summary using this exact structure. PRESERVE all existing information ## Critical Context [Any specific values, error messages, configuration details, or data that would be lost without explicit preservation] +## Tools & Patterns +[Which tools were used, how they were used effectively, and any tool-specific discoveries. Accumulate across compactions.] + Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions. Write only the summary body. Do not include any preamble or prefix.""" @@ -348,6 +360,9 @@ Use this exact structure: ## Critical Context [Any specific values, error messages, configuration details, or data that would be lost without explicit preservation] +## Tools & Patterns +[Which tools were used, how they were used effectively, and any tool-specific discoveries (e.g., preferred flags, working invocations, successful command patterns)] + Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions. The goal is to prevent the next assistant from repeating work or losing important details. Write only the summary body. Do not include any preamble or prefix.""" diff --git a/run_agent.py b/run_agent.py index 10932b4bafa..b473b825ecb 100644 --- a/run_agent.py +++ b/run_agent.py @@ -6013,6 +6013,15 @@ class AIAgent: except Exception as e: logger.warning("Session DB compression split failed — new session will NOT be indexed: %s", e) + # Warn on repeated compressions (quality degrades with each pass) + _cc = self.context_compressor.compression_count + if _cc >= 2: + self._vprint( + f"{self.log_prefix}⚠️ Session compressed {_cc} times — " + f"accuracy may degrade. Consider /new to start fresh.", + force=True, + ) + # Update token estimate after compaction so pressure calculations # use the post-compression count, not the stale pre-compression one. _compressed_est = ( From 989d4ea43d8fd59f022db00303b2eae14f10ab3a Mon Sep 17 00:00:00 2001 From: Teknium Date: Wed, 8 Apr 2026 20:19:12 -0700 Subject: [PATCH 12/49] fix: set compression_count on mock to avoid TypeError in test The new degradation warning reads compression_count as an int, but the existing test's MagicMock returns a MagicMock object for that attribute, causing '>=' comparison to fail. --- tests/run_agent/test_context_pressure.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/run_agent/test_context_pressure.py b/tests/run_agent/test_context_pressure.py index 522603fdb5f..a946ddd9cb1 100644 --- a/tests/run_agent/test_context_pressure.py +++ b/tests/run_agent/test_context_pressure.py @@ -219,6 +219,7 @@ class TestContextPressureFlags: ] agent.context_compressor.context_length = 200_000 agent.context_compressor.threshold_tokens = 100_000 + agent.context_compressor.compression_count = 1 agent._todo_store = MagicMock() agent._todo_store.format_for_injection.return_value = None From ffeaf6ffae91289c9a75869f1b03a54cc54729e1 Mon Sep 17 00:00:00 2001 From: Hermes Agent Date: Thu, 9 Apr 2026 00:18:40 +0000 Subject: [PATCH 13/49] feat(discord): inherit forum channel topic in thread sessions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ORIGINAL INCIDENT: Discord forum descriptions (the topic field on ForumChannel) were invisible to the agent. When a user set project instructions in a forum's description (e.g. tool-evaluations), threads created in that forum had no Channel Topic in their session context. Discovered while evaluating per-forum auto-context injection for web-tap-terminal development threads. ISSUE IN THE CODE: In gateway/platforms/discord.py, all three session entry points (_handle_message, _build_slash_event, _dispatch_thread_session) read chat_topic via getattr(channel, 'topic', None). Discord Thread objects don't carry a topic — only the parent ForumChannel does. So chat_topic was always None for forum threads, and the Channel Topic line was never injected into build_session_context_prompt output. The infrastructure to handle this was already in place — _is_forum_parent() detects forum channels, _format_thread_chat_name() traverses to the parent, and build_session_context_prompt() renders Channel Topic when present. The forum parent was being identified; its topic just wasn't being read. HOW THIS COMMIT FIXES IT: Adds _get_effective_topic(channel, is_thread) helper that reads channel.topic first, then falls back to the parent forum's topic when the channel is a thread inside a forum. All three session entry points now call this helper instead of inlining getattr(channel, 'topic', None). Existing tests pass unchanged. Co-authored-by: dhabibi <9087935+dhabibi@users.noreply.github.com> --- gateway/platforms/discord.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index b802f5712cb..36984202e58 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -1767,8 +1767,9 @@ class DiscordAdapter(BasePlatformAdapter): if hasattr(interaction.channel, "guild") and interaction.channel.guild: chat_name = f"{interaction.channel.guild.name} / #{chat_name}" - # Get channel topic (if available) - chat_topic = getattr(interaction.channel, "topic", None) + # Get channel topic (if available). + # For forum threads, inherit the parent forum's topic. + chat_topic = self._get_effective_topic(interaction.channel, is_thread=is_thread) source = self.build_source( chat_id=str(interaction.channel_id), @@ -1842,6 +1843,10 @@ class DiscordAdapter(BasePlatformAdapter): chat_name = f"{guild_name} / {thread_name}" if guild_name else thread_name + # Inherit forum topic when the thread was created inside a forum channel. + _chan = getattr(interaction, "channel", None) + chat_topic = self._get_effective_topic(_chan, is_thread=True) if _chan else None + source = self.build_source( chat_id=thread_id, chat_name=chat_name, @@ -1849,6 +1854,7 @@ class DiscordAdapter(BasePlatformAdapter): user_id=str(interaction.user.id), user_name=interaction.user.display_name, thread_id=thread_id, + chat_topic=chat_topic, ) event = MessageEvent( @@ -2134,6 +2140,15 @@ class DiscordAdapter(BasePlatformAdapter): return True return False + def _get_effective_topic(self, channel: Any, is_thread: bool = False) -> Optional[str]: + """Return the channel topic, falling back to the parent forum's topic for forum threads.""" + topic = getattr(channel, "topic", None) + if not topic and is_thread: + parent = getattr(channel, "parent", None) + if parent and self._is_forum_parent(parent): + topic = getattr(parent, "topic", None) + return topic + def _format_thread_chat_name(self, thread: Any) -> str: """Build a readable chat name for thread-like Discord channels, including forum context when available.""" thread_name = getattr(thread, "name", None) or str(getattr(thread, "id", "thread")) @@ -2301,8 +2316,10 @@ class DiscordAdapter(BasePlatformAdapter): if hasattr(message.channel, "guild") and message.channel.guild: chat_name = f"{message.channel.guild.name} / #{chat_name}" - # Get channel topic (if available - TextChannels have topics, DMs/threads don't) - chat_topic = getattr(message.channel, "topic", None) + # Get channel topic (if available - TextChannels have topics, DMs/threads don't). + # For threads whose parent is a forum channel, inherit the parent's topic + # so forum descriptions (e.g. project instructions) appear in the session context. + chat_topic = self._get_effective_topic(message.channel, is_thread=is_thread) # Build source source = self.build_source( From 54db7cbbe1fe74a361485b95d2370b8b679bbd0a Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Wed, 8 Apr 2026 21:31:44 -0700 Subject: [PATCH 14/49] fix(agent): tiered context pressure warnings + gateway dedup (#6411) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Combines the approaches from PR #6309 (duan78) and PR #5963 (KUSH42): Tiered warnings (from #5963): - Replaces boolean _context_pressure_warned with float _context_pressure_warned_at - Fires at 85% (orange) and re-fires at 95% (red/critical) - Adds 'compacting context...' status message before compression Gateway dedup (from #6309): - Class-level dict _context_pressure_last_warned survives across AIAgent instances (gateway creates a new instance per message) - 5-minute cooldown per session prevents warning spam - Higher-tier warnings bypass the cooldown (85% → 95% always fires) - Compression reset clears the dedup entry for the session - Stale entries evicted (older than 2x cooldown) to prevent memory leak Does NOT inject into messages — purely user-facing via _safe_print (CLI) and status_callback (gateway). Zero prompt cache impact. Fixes #6309. Fixes #5963. --- run_agent.py | 45 +++++++-- tests/run_agent/test_context_pressure.py | 120 ++++++++++++++++++++++- 2 files changed, 155 insertions(+), 10 deletions(-) diff --git a/run_agent.py b/run_agent.py index b473b825ecb..02803890a6a 100644 --- a/run_agent.py +++ b/run_agent.py @@ -442,6 +442,13 @@ class AIAgent: for AI models that support function calling. """ + # ── Class-level context pressure dedup (survives across instances) ── + # The gateway creates a new AIAgent per message, so instance-level flags + # reset every time. This dict tracks {session_id: (warn_level, timestamp)} + # to suppress duplicate warnings within a cooldown window. + _context_pressure_last_warned: dict = {} + _CONTEXT_PRESSURE_COOLDOWN = 300 # seconds between re-warning same session + @property def base_url(self) -> str: return self._base_url @@ -673,7 +680,8 @@ class AIAgent: # Context pressure warnings: notify the USER (not the LLM) as context # fills up. Purely informational — displayed in CLI output and sent via # status_callback for gateway platforms. Does NOT inject into messages. - self._context_pressure_warned = False + # Tiered: fires at 85% and again at 95% of compaction threshold. + self._context_pressure_warned_at = 0.0 # highest tier already shown # Activity tracking — updated on each API call, tool execution, and # stream chunk. Used by the gateway timeout handler to report what the @@ -6034,12 +6042,16 @@ class AIAgent: # Only reset the pressure warning if compression actually brought # us below the warning level (85% of threshold). When compression # can't reduce enough (e.g. threshold is very low, or system prompt - # alone exceeds the warning level), keep the flag set to prevent + # alone exceeds the warning level), keep the tier set to prevent # spamming the user with repeated warnings every loop iteration. if self.context_compressor.threshold_tokens > 0: _post_progress = _compressed_est / self.context_compressor.threshold_tokens if _post_progress < 0.85: - self._context_pressure_warned = False + self._context_pressure_warned_at = 0.0 + # Clear class-level dedup for this session so a fresh + # warning cycle can start if context grows again. + _sid = self.session_id or "default" + AIAgent._context_pressure_last_warned.pop(_sid, None) # Clear the file-read dedup cache. After compression the original # read content is summarised away — if the model re-reads the same @@ -8979,13 +8991,34 @@ class AIAgent: # compaction fires, not the raw context window. # Does not inject into messages — just prints to CLI output # and fires status_callback for gateway platforms. + # Tiered: 85% (orange) and 95% (red/critical). if _compressor.threshold_tokens > 0: _compaction_progress = _real_tokens / _compressor.threshold_tokens - if _compaction_progress >= 0.85 and not self._context_pressure_warned: - self._context_pressure_warned = True - self._emit_context_pressure(_compaction_progress, _compressor) + # Determine the warning tier for this progress level + _warn_tier = 0.0 + if _compaction_progress >= 0.95: + _warn_tier = 0.95 + elif _compaction_progress >= 0.85: + _warn_tier = 0.85 + if _warn_tier > self._context_pressure_warned_at: + # Class-level dedup: check if this session was already + # warned at this tier within the cooldown window. + _sid = self.session_id or "default" + _last = AIAgent._context_pressure_last_warned.get(_sid) + _now = time.time() + if _last is None or _last[0] < _warn_tier or (_now - _last[1]) >= self._CONTEXT_PRESSURE_COOLDOWN: + self._context_pressure_warned_at = _warn_tier + AIAgent._context_pressure_last_warned[_sid] = (_warn_tier, _now) + self._emit_context_pressure(_compaction_progress, _compressor) + # Evict stale entries (older than 2x cooldown) + _cutoff = _now - self._CONTEXT_PRESSURE_COOLDOWN * 2 + AIAgent._context_pressure_last_warned = { + k: v for k, v in AIAgent._context_pressure_last_warned.items() + if v[1] > _cutoff + } if self.compression_enabled and _compressor.should_compress(_real_tokens): + self._safe_print(" ⟳ compacting context…") messages, active_system_prompt = self._compress_context( messages, system_message, approx_tokens=self.context_compressor.last_prompt_tokens, diff --git a/tests/run_agent/test_context_pressure.py b/tests/run_agent/test_context_pressure.py index a946ddd9cb1..4140749c519 100644 --- a/tests/run_agent/test_context_pressure.py +++ b/tests/run_agent/test_context_pressure.py @@ -150,8 +150,8 @@ def agent(): class TestContextPressureFlags: """Context pressure warning flag tracking on AIAgent.""" - def test_flag_initialized_false(self, agent): - assert agent._context_pressure_warned is False + def test_flag_initialized_zero(self, agent): + assert agent._context_pressure_warned_at == 0.0 def test_emit_calls_status_callback(self, agent): """status_callback should be invoked with event type and message.""" @@ -210,7 +210,7 @@ class TestContextPressureFlags: def test_flag_reset_on_compression(self, agent): """After _compress_context, context pressure flag should reset.""" - agent._context_pressure_warned = True + agent._context_pressure_warned_at = 0.85 agent.compression_enabled = True agent.context_compressor = MagicMock() @@ -234,7 +234,7 @@ class TestContextPressureFlags: ] agent._compress_context(messages, "system prompt") - assert agent._context_pressure_warned is False + assert agent._context_pressure_warned_at == 0.0 def test_emit_callback_error_handled(self, agent): """If status_callback raises, it should be caught gracefully.""" @@ -247,3 +247,115 @@ class TestContextPressureFlags: # Should not raise agent._emit_context_pressure(0.85, compressor) + + def test_tiered_reemits_at_95(self, agent): + """Warning fires at 85%, then fires again when crossing 95%.""" + agent._context_pressure_warned_at = 0.85 + # Simulate crossing 95%: the tier (0.95) > warned_at (0.85) + assert 0.95 > agent._context_pressure_warned_at + # After emission at 95%, the tier should update + agent._context_pressure_warned_at = 0.95 + assert agent._context_pressure_warned_at == 0.95 + + def test_tiered_no_double_emit_at_same_level(self, agent): + """Once warned at 85%, further 85%+ readings don't re-warn.""" + agent._context_pressure_warned_at = 0.85 + # At 88%, tier is 0.85, which is NOT > warned_at (0.85) + _warn_tier = 0.85 if 0.88 >= 0.85 else 0.0 + assert not (_warn_tier > agent._context_pressure_warned_at) + + def test_flag_not_reset_when_compression_insufficient(self, agent): + """When compression can't drop below 85%, keep the flag set.""" + agent._context_pressure_warned_at = 0.85 + agent.compression_enabled = True + + agent.context_compressor = MagicMock() + agent.context_compressor.compress.return_value = [ + {"role": "user", "content": "Summary of conversation so far."} + ] + agent.context_compressor.context_length = 200 + # Use a small threshold so the tiny compressed output still + # represents >= 85% of it (prevents flag reset). + agent.context_compressor.threshold_tokens = 10 + agent.context_compressor.compression_count = 1 + agent.context_compressor.last_prompt_tokens = 0 + + agent._todo_store = MagicMock() + agent._todo_store.format_for_injection.return_value = None + agent._build_system_prompt = MagicMock(return_value="system prompt") + agent._cached_system_prompt = "old system prompt" + agent._session_db = None + + messages = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi there"}, + ] + agent._compress_context(messages, "system prompt") + + # Post-compression is ~90% of threshold — flag should NOT reset + assert agent._context_pressure_warned_at == 0.85 + + +class TestContextPressureGatewayDedup: + """Class-level dedup prevents warning spam across AIAgent instances.""" + + def setup_method(self): + """Clear class-level dedup state between tests.""" + AIAgent._context_pressure_last_warned.clear() + + def test_second_instance_within_cooldown_suppressed(self): + """Same session, same tier, within cooldown — should be suppressed.""" + import time + sid = "test_session_dedup" + # Simulate first warning + AIAgent._context_pressure_last_warned[sid] = (0.85, time.time()) + # Second instance checking same tier within cooldown + _last = AIAgent._context_pressure_last_warned.get(sid) + _should_warn = _last is None or _last[0] < 0.85 or (time.time() - _last[1]) >= AIAgent._CONTEXT_PRESSURE_COOLDOWN + assert not _should_warn + + def test_higher_tier_fires_despite_cooldown(self): + """Same session, higher tier — should fire even within cooldown.""" + import time + sid = "test_session_tier" + AIAgent._context_pressure_last_warned[sid] = (0.85, time.time()) + _last = AIAgent._context_pressure_last_warned.get(sid) + # 0.95 > 0.85 stored tier → should warn + _should_warn = _last is None or _last[0] < 0.95 or (time.time() - _last[1]) >= AIAgent._CONTEXT_PRESSURE_COOLDOWN + assert _should_warn + + def test_warning_fires_after_cooldown_expires(self): + """Same session, same tier, after cooldown — should fire again.""" + import time + sid = "test_session_expired" + # Set a timestamp far in the past + AIAgent._context_pressure_last_warned[sid] = (0.85, time.time() - AIAgent._CONTEXT_PRESSURE_COOLDOWN - 1) + _last = AIAgent._context_pressure_last_warned.get(sid) + _should_warn = _last is None or _last[0] < 0.85 or (time.time() - _last[1]) >= AIAgent._CONTEXT_PRESSURE_COOLDOWN + assert _should_warn + + def test_compression_clears_dedup(self): + """After compression drops below 85%, dedup entry should be cleared.""" + import time + sid = "test_session_clear" + AIAgent._context_pressure_last_warned[sid] = (0.85, time.time()) + assert sid in AIAgent._context_pressure_last_warned + # Simulate what _compress_context does on reset + AIAgent._context_pressure_last_warned.pop(sid, None) + assert sid not in AIAgent._context_pressure_last_warned + + def test_eviction_removes_stale_entries(self): + """Stale entries older than 2x cooldown should be evicted.""" + import time + _now = time.time() + AIAgent._context_pressure_last_warned = { + "fresh": (0.85, _now), + "stale": (0.85, _now - AIAgent._CONTEXT_PRESSURE_COOLDOWN * 3), + } + _cutoff = _now - AIAgent._CONTEXT_PRESSURE_COOLDOWN * 2 + AIAgent._context_pressure_last_warned = { + k: v for k, v in AIAgent._context_pressure_last_warned.items() + if v[1] > _cutoff + } + assert "fresh" in AIAgent._context_pressure_last_warned + assert "stale" not in AIAgent._context_pressure_last_warned From e7d3e9d767b473b9fbcf7b85884aa90758a514f9 Mon Sep 17 00:00:00 2001 From: angelos Date: Thu, 9 Apr 2026 02:12:26 +0000 Subject: [PATCH 15/49] fix(terminal): persistent sandbox envs survive between turns MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `_cleanup_task_resources` was unconditionally calling `cleanup_vm()` at the end of every `run_conversation` (i.e. every user turn), tearing down the docker/daytona/modal sandbox container regardless of its `persistent_filesystem` setting. This contradicted the documented intent of `terminal.lifetime_seconds` (idle reaper) and `container_persistent`, and caused per-turn loss of `/workspace`, `~/.config`, agent CLI auth state, and any other content living inside the sandbox. The unconditional teardown was introduced in fbd3a2fd ("prevent leakage of morph instances between tasks", 2025-11-04) to plug a Morph backend leak, two days after `lifetime_seconds` shipped in faecbddd. It was later refactored into `_cleanup_task_resources` in 70dd3a16 without changing semantics. Code and docs have disagreed since. Fix: introduce `terminal_tool.is_persistent_env(task_id)` and skip the per-turn `cleanup_vm` when the active env is persistent. The idle reaper (`_cleanup_inactive_envs`) still tears persistent envs down once `terminal.lifetime_seconds` is exceeded. Non-persistent backends (Morph) are unchanged — still torn down per turn, preserving the original leak-prevention intent. --- run_agent.py | 22 +++++++++++++++++++--- tools/terminal_tool.py | 17 +++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/run_agent.py b/run_agent.py index 02803890a6a..793ddd6750f 100644 --- a/run_agent.py +++ b/run_agent.py @@ -66,7 +66,7 @@ from model_tools import ( handle_function_call, check_toolset_requirements, ) -from tools.terminal_tool import cleanup_vm, get_active_env +from tools.terminal_tool import cleanup_vm, get_active_env, is_persistent_env from tools.tool_result_storage import maybe_persist_tool_result, enforce_turn_budget from tools.interrupt import set_interrupt as _set_interrupt from tools.browser_tool import cleanup_browser @@ -1695,9 +1695,25 @@ class AIAgent: return None def _cleanup_task_resources(self, task_id: str) -> None: - """Clean up VM and browser resources for a given task.""" + """Clean up VM and browser resources for a given task. + + Skips ``cleanup_vm`` when the active terminal environment is marked + persistent (``persistent_filesystem=True``) so that long-lived sandbox + containers survive between turns. The idle reaper in + ``terminal_tool._cleanup_inactive_envs`` still tears them down once + ``terminal.lifetime_seconds`` is exceeded. Non-persistent backends are + torn down per-turn as before to prevent resource leakage (the original + intent of this hook for the Morph backend, see commit fbd3a2fd). + """ try: - cleanup_vm(task_id) + if is_persistent_env(task_id): + if self.verbose_logging: + logging.debug( + f"Skipping per-turn cleanup_vm for persistent env {task_id}; " + f"idle reaper will handle it." + ) + else: + cleanup_vm(task_id) except Exception as e: if self.verbose_logging: logging.warning(f"Failed to cleanup VM for task {task_id}: {e}") diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index 243127a2958..183e8983356 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -814,6 +814,23 @@ def get_active_env(task_id: str): return _active_environments.get(task_id) +def is_persistent_env(task_id: str) -> bool: + """Return True if the active environment for task_id is configured for + cross-turn persistence (``persistent_filesystem=True``). + + Used by the agent loop to skip per-turn teardown for backends whose whole + point is to survive between turns (docker with ``container_persistent``, + daytona, modal, etc.). Non-persistent backends (e.g. Morph) still get torn + down at end-of-turn to prevent leakage. The idle reaper + (``_cleanup_inactive_envs``) handles persistent envs once they exceed + ``terminal.lifetime_seconds``. + """ + env = get_active_env(task_id) + if env is None: + return False + return bool(getattr(env, "_persistent", False)) + + def get_active_environments_info() -> Dict[str, Any]: """Get information about currently active environments.""" info = { From e94008c404f8d8af76c972d295f61797820106fd Mon Sep 17 00:00:00 2001 From: helix4u <4317663+helix4u@users.noreply.github.com> Date: Wed, 8 Apr 2026 20:12:43 -0600 Subject: [PATCH 16/49] fix(terminal): guard invalid command values --- .../tools/test_terminal_none_command_guard.py | 21 +++++++++++ tools/terminal_tool.py | 35 ++++++++++++++++--- 2 files changed, 52 insertions(+), 4 deletions(-) create mode 100644 tests/tools/test_terminal_none_command_guard.py diff --git a/tests/tools/test_terminal_none_command_guard.py b/tests/tools/test_terminal_none_command_guard.py new file mode 100644 index 00000000000..05455836d12 --- /dev/null +++ b/tests/tools/test_terminal_none_command_guard.py @@ -0,0 +1,21 @@ +"""Regression tests for invalid/None terminal command handling.""" + +import json + +from tools.terminal_tool import _transform_sudo_command, terminal_tool + + +def test_transform_sudo_command_none_returns_cleanly(): + transformed, sudo_stdin = _transform_sudo_command(None) + + assert transformed is None + assert sudo_stdin is None + + +def test_terminal_tool_none_command_returns_clean_error(): + result = json.loads(terminal_tool(None)) # type: ignore[arg-type] + + assert result["exit_code"] == -1 + assert result["status"] == "error" + assert "expected string" in result["error"].lower() + assert "nonetype" in result["error"].lower() diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index 183e8983356..96a1147759c 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -327,7 +327,19 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str: del os.environ["HERMES_SPINNER_PAUSE"] -def _transform_sudo_command(command: str) -> tuple[str, str | None]: +def _safe_command_preview(command: Any, limit: int = 200) -> str: + """Return a log-safe preview for possibly-invalid command values.""" + if command is None: + return "" + if isinstance(command, str): + return command[:limit] + try: + return repr(command)[:limit] + except Exception: + return f"<{type(command).__name__}>" + + +def _transform_sudo_command(command: str | None) -> tuple[str | None, str | None]: """ Transform sudo commands to use -S flag if SUDO_PASSWORD is available. @@ -365,6 +377,9 @@ def _transform_sudo_command(command: str) -> tuple[str, str | None]: import re # Check if command even contains sudo + if command is None: + return None, None + if not re.search(r'\bsudo\b', command): return command, None # No sudo in command, nothing to do @@ -1050,6 +1065,18 @@ def terminal_tool( # Note: force parameter is internal only, not exposed to model API """ try: + if not isinstance(command, str): + logger.warning( + "Rejected invalid terminal command value: %s", + type(command).__name__, + ) + return json.dumps({ + "output": "", + "exit_code": -1, + "error": f"Invalid command: expected string, got {type(command).__name__}", + "status": "error", + }, ensure_ascii=False) + # Get configuration config = _get_env_config() env_type = config["env_type"] @@ -1207,7 +1234,7 @@ def terminal_tool( workdir_error = _validate_workdir(workdir) if workdir_error: logger.warning("Blocked dangerous workdir: %s (command: %s)", - workdir[:200], command[:200]) + workdir[:200], _safe_command_preview(command)) return json.dumps({ "output": "", "exit_code": -1, @@ -1347,12 +1374,12 @@ def terminal_tool( retry_count += 1 wait_time = 2 ** retry_count logger.warning("Execution error, retrying in %ds (attempt %d/%d) - Command: %s - Error: %s: %s - Task: %s, Backend: %s", - wait_time, retry_count, max_retries, command[:200], type(e).__name__, e, effective_task_id, env_type) + wait_time, retry_count, max_retries, _safe_command_preview(command), type(e).__name__, e, effective_task_id, env_type) time.sleep(wait_time) continue logger.error("Execution failed after %d retries - Command: %s - Error: %s: %s - Task: %s, Backend: %s", - max_retries, command[:200], type(e).__name__, e, effective_task_id, env_type) + max_retries, _safe_command_preview(command), type(e).__name__, e, effective_task_id, env_type) return json.dumps({ "output": "", "exit_code": -1, From 1d8d4f28ae05198e995433c2c0f30ed324093494 Mon Sep 17 00:00:00 2001 From: xingkongliang Date: Wed, 8 Apr 2026 11:21:24 +1000 Subject: [PATCH 17/49] fix(gateway): prevent background process notifications from triggering false pairing requests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a background process with notify_on_complete=True finishes, the gateway injects a synthetic MessageEvent to notify the session. This event was constructed without user_id, causing _is_user_authorized() to reject it and — for DM-origin sessions — trigger the pairing flow, sending "Hi~ I don't recognize you yet!" with a pairing code to the chat owner. Add an `internal` flag to MessageEvent that bypasses authorization checks for system-generated synthetic events. Only the process watcher sets this flag; no external/adapter code path can produce it. Includes 4 regression tests covering the fix and the normal pairing path. --- gateway/platforms/base.py | 4 + gateway/run.py | 8 +- .../test_internal_event_bypass_pairing.py | 219 ++++++++++++++++++ 3 files changed, 229 insertions(+), 2 deletions(-) create mode 100644 tests/gateway/test_internal_event_bypass_pairing.py diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index a888eede94e..c72fa513bb3 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -407,6 +407,10 @@ class MessageEvent: # Auto-loaded skill for topic/channel bindings (e.g., Telegram DM Topics) auto_skill: Optional[str] = None + # Internal flag — set for synthetic events (e.g. background process + # completion notifications) that must bypass user authorization checks. + internal: bool = False + # Timestamps timestamp: datetime = field(default_factory=datetime.now) diff --git a/gateway/run.py b/gateway/run.py index 3333608199f..c9d3f07e6ec 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1781,8 +1781,11 @@ class GatewayRunner: """ source = event.source - # Check if user is authorized - if not self._is_user_authorized(source): + # Internal events (e.g. background-process completion notifications) + # are system-generated and must skip user authorization. + if getattr(event, "internal", False): + pass + elif not self._is_user_authorized(source): logger.warning("Unauthorized user: %s (%s) on %s", source.user_id, source.user_name, source.platform.value) # In DMs: offer pairing code. In groups: silently ignore. if source.chat_type == "dm" and self._get_unauthorized_dm_behavior(source.platform) == "pair": @@ -6160,6 +6163,7 @@ class GatewayRunner: text=synth_text, message_type=MessageType.TEXT, source=_source, + internal=True, ) logger.info( "Process %s finished — injecting agent notification for session %s", diff --git a/tests/gateway/test_internal_event_bypass_pairing.py b/tests/gateway/test_internal_event_bypass_pairing.py new file mode 100644 index 00000000000..b910086f05e --- /dev/null +++ b/tests/gateway/test_internal_event_bypass_pairing.py @@ -0,0 +1,219 @@ +"""Tests that internal synthetic events (e.g. background process completion) +bypass user authorization and do not trigger DM pairing. + +Regression test for the bug where ``_run_process_watcher`` with +``notify_on_complete=True`` injected a ``MessageEvent`` without ``user_id``, +causing ``_is_user_authorized`` to reject it and the gateway to send a +pairing code to the chat. +""" + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + +from gateway.config import GatewayConfig, Platform +from gateway.platforms.base import MessageEvent +from gateway.run import GatewayRunner +from gateway.session import SessionSource + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class _FakeRegistry: + """Return pre-canned sessions, then None once exhausted.""" + + def __init__(self, sessions): + self._sessions = list(sessions) + + def get(self, session_id): + if self._sessions: + return self._sessions.pop(0) + return None + + +def _build_runner(monkeypatch, tmp_path) -> GatewayRunner: + """Create a GatewayRunner with notifications set to 'all'.""" + (tmp_path / "config.yaml").write_text( + "display:\n background_process_notifications: all\n", + encoding="utf-8", + ) + + import gateway.run as gateway_run + + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + + runner = GatewayRunner(GatewayConfig()) + adapter = SimpleNamespace(send=AsyncMock(), handle_message=AsyncMock()) + runner.adapters[Platform.DISCORD] = adapter + return runner + + +def _watcher_dict_with_notify(): + return { + "session_id": "proc_test_internal", + "check_interval": 0, + "session_key": "agent:main:discord:dm:123", + "platform": "discord", + "chat_id": "123", + "thread_id": "", + "notify_on_complete": True, + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_notify_on_complete_sets_internal_flag(monkeypatch, tmp_path): + """Synthetic completion event must have internal=True.""" + import tools.process_registry as pr_module + + sessions = [ + SimpleNamespace( + output_buffer="done\n", exited=True, exit_code=0, command="echo test" + ), + ] + monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions)) + + async def _instant_sleep(*_a, **_kw): + pass + monkeypatch.setattr(asyncio, "sleep", _instant_sleep) + + runner = _build_runner(monkeypatch, tmp_path) + adapter = runner.adapters[Platform.DISCORD] + + await runner._run_process_watcher(_watcher_dict_with_notify()) + + assert adapter.handle_message.await_count == 1 + event = adapter.handle_message.await_args.args[0] + assert isinstance(event, MessageEvent) + assert event.internal is True, "Synthetic completion event must be marked internal" + + +@pytest.mark.asyncio +async def test_internal_event_bypasses_authorization(monkeypatch, tmp_path): + """An internal event should skip _is_user_authorized entirely.""" + import gateway.run as gateway_run + + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + (tmp_path / "config.yaml").write_text("", encoding="utf-8") + + runner = GatewayRunner(GatewayConfig()) + + # Create an internal event with no user_id (simulates the bug scenario) + source = SessionSource( + platform=Platform.DISCORD, + chat_id="123", + chat_type="dm", + ) + event = MessageEvent( + text="[SYSTEM: Background process completed]", + source=source, + internal=True, + ) + + # Track if _is_user_authorized is called + auth_called = False + original_auth = GatewayRunner._is_user_authorized + + def tracking_auth(self, src): + nonlocal auth_called + auth_called = True + return original_auth(self, src) + + monkeypatch.setattr(GatewayRunner, "_is_user_authorized", tracking_auth) + + # _handle_message will proceed past auth check and eventually fail on + # downstream logic. We just need to verify auth is skipped. + try: + await runner._handle_message(event) + except Exception: + pass # Expected — downstream code needs more setup + + assert not auth_called, ( + "_is_user_authorized should NOT be called for internal events" + ) + + +@pytest.mark.asyncio +async def test_internal_event_does_not_trigger_pairing(monkeypatch, tmp_path): + """An internal event with no user_id must not generate a pairing code.""" + import gateway.run as gateway_run + + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + (tmp_path / "config.yaml").write_text("", encoding="utf-8") + + runner = GatewayRunner(GatewayConfig()) + # Add adapter so pairing would have somewhere to send + adapter = SimpleNamespace(send=AsyncMock()) + runner.adapters[Platform.DISCORD] = adapter + + source = SessionSource( + platform=Platform.DISCORD, + chat_id="123", + chat_type="dm", # DM would normally trigger pairing + ) + event = MessageEvent( + text="[SYSTEM: Background process completed]", + source=source, + internal=True, + ) + + # Track pairing code generation + generate_called = False + original_generate = runner.pairing_store.generate_code + + def tracking_generate(*args, **kwargs): + nonlocal generate_called + generate_called = True + return original_generate(*args, **kwargs) + + runner.pairing_store.generate_code = tracking_generate + + try: + await runner._handle_message(event) + except Exception: + pass # Expected — downstream code needs more setup + + assert not generate_called, ( + "Pairing code should NOT be generated for internal events" + ) + + +@pytest.mark.asyncio +async def test_non_internal_event_without_user_triggers_pairing(monkeypatch, tmp_path): + """Verify the normal (non-internal) path still triggers pairing for unknown users.""" + import gateway.run as gateway_run + + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + (tmp_path / "config.yaml").write_text("", encoding="utf-8") + + runner = GatewayRunner(GatewayConfig()) + adapter = SimpleNamespace(send=AsyncMock()) + runner.adapters[Platform.DISCORD] = adapter + + source = SessionSource( + platform=Platform.DISCORD, + chat_id="123", + chat_type="dm", + user_id="unknown_user_999", + ) + # Normal event (not internal) + event = MessageEvent( + text="hello", + source=source, + internal=False, + ) + + result = await runner._handle_message(event) + + # Should return None (unauthorized) and send pairing message + assert result is None + assert adapter.send.await_count == 1 + sent_text = adapter.send.await_args.args[1] + assert "don't recognize you" in sent_text From 5449c01d263556beb93bff6b525ad67f80a528ba Mon Sep 17 00:00:00 2001 From: Teknium Date: Wed, 8 Apr 2026 22:48:07 -0700 Subject: [PATCH 18/49] fix: clean env vars in pairing regression test The test_non_internal_event_without_user_triggers_pairing test relied on no Discord auth env vars being set, but gateway/run.py loads dotenv at module level. In environments with DISCORD_ALLOW_ALL_USERS=True in .env, the auth check passed instead of triggering the pairing flow. Clear DISCORD_ALLOW_ALL_USERS, DISCORD_ALLOWED_USERS, GATEWAY_ALLOW_ALL_USERS, and GATEWAY_ALLOWED_USERS via monkeypatch to ensure test isolation. --- tests/gateway/test_internal_event_bypass_pairing.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/gateway/test_internal_event_bypass_pairing.py b/tests/gateway/test_internal_event_bypass_pairing.py index b910086f05e..19ecd7059ee 100644 --- a/tests/gateway/test_internal_event_bypass_pairing.py +++ b/tests/gateway/test_internal_event_bypass_pairing.py @@ -193,6 +193,13 @@ async def test_non_internal_event_without_user_triggers_pairing(monkeypatch, tmp monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) (tmp_path / "config.yaml").write_text("", encoding="utf-8") + # Clear env vars that could let all users through (loaded by + # module-level dotenv in gateway/run.py from the real ~/.hermes/.env). + monkeypatch.delenv("DISCORD_ALLOW_ALL_USERS", raising=False) + monkeypatch.delenv("DISCORD_ALLOWED_USERS", raising=False) + monkeypatch.delenv("GATEWAY_ALLOW_ALL_USERS", raising=False) + monkeypatch.delenv("GATEWAY_ALLOWED_USERS", raising=False) + runner = GatewayRunner(GatewayConfig()) adapter = SimpleNamespace(send=AsyncMock()) runner.adapters[Platform.DISCORD] = adapter From 30a0fcaec8ff142813cc6826aa3da953c645e7ee Mon Sep 17 00:00:00 2001 From: helix4u <4317663+helix4u@users.noreply.github.com> Date: Wed, 8 Apr 2026 14:24:59 -0600 Subject: [PATCH 19/49] fix(slack): handle assistant thread lifecycle events --- gateway/platforms/slack.py | 157 +++++++++++++++++++++++++++++++++++- tests/gateway/test_slack.py | 90 ++++++++++++++++++++- 2 files changed, 242 insertions(+), 5 deletions(-) diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py index 7af313d325e..8685b92ed00 100644 --- a/gateway/platforms/slack.py +++ b/gateway/platforms/slack.py @@ -14,7 +14,7 @@ import logging import os import re import time -from typing import Dict, Optional, Any +from typing import Dict, Optional, Any, Tuple try: from slack_bolt.async_app import AsyncApp @@ -95,6 +95,11 @@ class SlackAdapter(BasePlatformAdapter): # respond to ALL subsequent messages in that thread automatically. self._mentioned_threads: set = set() self._MENTIONED_THREADS_MAX = 5000 + # Assistant thread metadata keyed by (channel_id, thread_ts). Slack's + # AI Assistant lifecycle events can arrive before/alongside message + # events, and they carry the user/thread identity needed for stable + # session + memory scoping. + self._assistant_threads: Dict[Tuple[str, str], Dict[str, str]] = {} async def connect(self) -> bool: """Connect to Slack via Socket Mode.""" @@ -181,6 +186,14 @@ class SlackAdapter(BasePlatformAdapter): async def handle_app_mention(event, say): pass + @self._app.event("assistant_thread_started") + async def handle_assistant_thread_started(event, say): + await self._handle_assistant_thread_lifecycle_event(event) + + @self._app.event("assistant_thread_context_changed") + async def handle_assistant_thread_context_changed(event, say): + await self._handle_assistant_thread_lifecycle_event(event) + # Register slash command handler @self._app.command("/hermes") async def handle_hermes_command(ack, command): @@ -755,6 +768,129 @@ class SlackAdapter(BasePlatformAdapter): # ----- Internal handlers ----- + def _assistant_thread_key(self, channel_id: str, thread_ts: str) -> Optional[Tuple[str, str]]: + """Return a stable cache key for Slack assistant thread metadata.""" + if not channel_id or not thread_ts: + return None + return (str(channel_id), str(thread_ts)) + + def _extract_assistant_thread_metadata(self, event: dict) -> Dict[str, str]: + """Extract Slack Assistant thread identity data from an event payload.""" + assistant_thread = event.get("assistant_thread") or {} + context = assistant_thread.get("context") or event.get("context") or {} + + channel_id = ( + assistant_thread.get("channel_id") + or event.get("channel") + or context.get("channel_id") + or "" + ) + thread_ts = ( + assistant_thread.get("thread_ts") + or event.get("thread_ts") + or event.get("message_ts") + or "" + ) + user_id = ( + assistant_thread.get("user_id") + or event.get("user") + or context.get("user_id") + or "" + ) + team_id = ( + event.get("team") + or event.get("team_id") + or assistant_thread.get("team_id") + or "" + ) + context_channel_id = context.get("channel_id") or "" + + return { + "channel_id": str(channel_id) if channel_id else "", + "thread_ts": str(thread_ts) if thread_ts else "", + "user_id": str(user_id) if user_id else "", + "team_id": str(team_id) if team_id else "", + "context_channel_id": str(context_channel_id) if context_channel_id else "", + } + + def _cache_assistant_thread_metadata(self, metadata: Dict[str, str]) -> None: + """Remember assistant thread identity data for later message events.""" + channel_id = metadata.get("channel_id", "") + thread_ts = metadata.get("thread_ts", "") + key = self._assistant_thread_key(channel_id, thread_ts) + if not key: + return + + existing = self._assistant_threads.get(key, {}) + merged = dict(existing) + merged.update({k: v for k, v in metadata.items() if v}) + self._assistant_threads[key] = merged + + team_id = merged.get("team_id", "") + if team_id and channel_id: + self._channel_team[channel_id] = team_id + + def _lookup_assistant_thread_metadata( + self, + event: dict, + channel_id: str = "", + thread_ts: str = "", + ) -> Dict[str, str]: + """Load cached assistant-thread metadata that matches the current event.""" + metadata = self._extract_assistant_thread_metadata(event) + if channel_id and not metadata.get("channel_id"): + metadata["channel_id"] = channel_id + if thread_ts and not metadata.get("thread_ts"): + metadata["thread_ts"] = thread_ts + + key = self._assistant_thread_key( + metadata.get("channel_id", ""), + metadata.get("thread_ts", ""), + ) + cached = self._assistant_threads.get(key, {}) if key else {} + if cached: + merged = dict(cached) + merged.update({k: v for k, v in metadata.items() if v}) + return merged + return metadata + + def _seed_assistant_thread_session(self, metadata: Dict[str, str]) -> None: + """Prime the session store so assistant threads get stable user scoping.""" + session_store = getattr(self, "_session_store", None) + if not session_store: + return + + channel_id = metadata.get("channel_id", "") + thread_ts = metadata.get("thread_ts", "") + user_id = metadata.get("user_id", "") + if not channel_id or not thread_ts or not user_id: + return + + source = self.build_source( + chat_id=channel_id, + chat_name=channel_id, + chat_type="dm", + user_id=user_id, + thread_id=thread_ts, + chat_topic=metadata.get("context_channel_id") or None, + ) + + try: + session_store.get_or_create_session(source) + except Exception: + logger.debug( + "[Slack] Failed to seed assistant thread session for %s/%s", + channel_id, + thread_ts, + exc_info=True, + ) + + async def _handle_assistant_thread_lifecycle_event(self, event: dict) -> None: + """Handle Slack Assistant lifecycle events that carry user/thread identity.""" + metadata = self._extract_assistant_thread_metadata(event) + self._cache_assistant_thread_metadata(metadata) + self._seed_assistant_thread_session(metadata) + async def _handle_slack_message(self, event: dict) -> None: """Handle an incoming Slack message event.""" # Dedup: Slack Socket Mode can redeliver events after reconnects (#4777) @@ -781,10 +917,21 @@ class SlackAdapter(BasePlatformAdapter): return text = event.get("text", "") - user_id = event.get("user", "") channel_id = event.get("channel", "") ts = event.get("ts", "") - team_id = event.get("team", "") + assistant_meta = self._lookup_assistant_thread_metadata( + event, + channel_id=channel_id, + thread_ts=event.get("thread_ts", ""), + ) + user_id = event.get("user") or assistant_meta.get("user_id", "") + if not channel_id: + channel_id = assistant_meta.get("channel_id", "") + team_id = ( + event.get("team") + or event.get("team_id") + or assistant_meta.get("team_id", "") + ) # Track which workspace owns this channel if team_id and channel_id: @@ -792,6 +939,8 @@ class SlackAdapter(BasePlatformAdapter): # Determine if this is a DM or channel message channel_type = event.get("channel_type", "") + if not channel_type and channel_id.startswith("D"): + channel_type = "im" is_dm = channel_type == "im" # Build thread_ts for session keying. @@ -800,7 +949,7 @@ class SlackAdapter(BasePlatformAdapter): # In DMs: only use the real thread_ts — top-level DMs should share # one continuous session, threaded DMs get their own session. if is_dm: - thread_ts = event.get("thread_ts") # None for top-level DMs + thread_ts = event.get("thread_ts") or assistant_meta.get("thread_ts") # None for top-level DMs else: thread_ts = event.get("thread_ts") or ts # ts fallback for channels diff --git a/tests/gateway/test_slack.py b/tests/gateway/test_slack.py index 89b44718344..0bad0abe567 100644 --- a/tests/gateway/test_slack.py +++ b/tests/gateway/test_slack.py @@ -96,7 +96,7 @@ class TestAppMentionHandler: """Verify that the app_mention event handler is registered.""" def test_app_mention_registered_on_connect(self): - """connect() should register both 'message' and 'app_mention' handlers.""" + """connect() should register message + assistant lifecycle handlers.""" config = PlatformConfig(enabled=True, token="xoxb-fake") adapter = SlackAdapter(config) @@ -145,6 +145,8 @@ class TestAppMentionHandler: assert "message" in registered_events assert "app_mention" in registered_events + assert "assistant_thread_started" in registered_events + assert "assistant_thread_context_changed" in registered_events assert "/hermes" in registered_commands @@ -840,6 +842,92 @@ class TestThreadReplyHandling: adapter.handle_message.assert_not_called() +# --------------------------------------------------------------------------- +# TestAssistantThreadLifecycle +# --------------------------------------------------------------------------- + + +class TestAssistantThreadLifecycle: + """Slack Assistant lifecycle events should seed session/user context.""" + + @pytest.fixture() + def mock_session_store(self): + store = MagicMock() + store._entries = {} + store._ensure_loaded = MagicMock() + store.config = MagicMock() + store.config.group_sessions_per_user = True + store.get_or_create_session = MagicMock() + return store + + @pytest.fixture() + def assistant_adapter(self, mock_session_store): + config = PlatformConfig(enabled=True, token="***") + a = SlackAdapter(config) + a._app = MagicMock() + a._app.client = AsyncMock() + a._bot_user_id = "U_BOT" + a._team_bot_user_ids = {"T_TEAM": "U_BOT"} + a._running = True + a.handle_message = AsyncMock() + a.set_session_store(mock_session_store) + return a + + @pytest.mark.asyncio + async def test_lifecycle_event_seeds_session_store(self, assistant_adapter, mock_session_store): + event = { + "type": "assistant_thread_started", + "team_id": "T_TEAM", + "assistant_thread": { + "channel_id": "D123", + "thread_ts": "171.000", + "user_id": "U_USER", + "context": {"channel_id": "C_ORIGIN"}, + }, + } + + await assistant_adapter._handle_assistant_thread_lifecycle_event(event) + + assert assistant_adapter._assistant_threads[("D123", "171.000")]["user_id"] == "U_USER" + mock_session_store.get_or_create_session.assert_called_once() + source = mock_session_store.get_or_create_session.call_args[0][0] + assert source.chat_id == "D123" + assert source.chat_type == "dm" + assert source.user_id == "U_USER" + assert source.thread_id == "171.000" + assert source.chat_topic == "C_ORIGIN" + + @pytest.mark.asyncio + async def test_message_uses_cached_assistant_thread_identity(self, assistant_adapter): + assistant_adapter._assistant_threads[("D123", "171.000")] = { + "channel_id": "D123", + "thread_ts": "171.000", + "user_id": "U_USER", + "team_id": "T_TEAM", + } + assistant_adapter._app.client.users_info = AsyncMock(return_value={ + "user": {"profile": {"display_name": "Tyler"}} + }) + assistant_adapter._app.client.reactions_add = AsyncMock() + assistant_adapter._app.client.reactions_remove = AsyncMock() + + event = { + "text": "hello from assistant dm", + "channel": "D123", + "channel_type": "im", + "thread_ts": "171.000", + "ts": "171.111", + "team": "T_TEAM", + } + + await assistant_adapter._handle_slack_message(event) + + msg_event = assistant_adapter.handle_message.call_args[0][0] + assert msg_event.source.user_id == "U_USER" + assert msg_event.source.thread_id == "171.000" + assert msg_event.source.user_name == "Tyler" + + # --------------------------------------------------------------------------- # TestUserNameResolution # --------------------------------------------------------------------------- From 241bd4fc7e48cfeb4417a41145bcf796e2df7a6a Mon Sep 17 00:00:00 2001 From: Teknium Date: Wed, 8 Apr 2026 22:44:22 -0700 Subject: [PATCH 20/49] fix: add size cap to assistant thread metadata cache Prevents unbounded memory growth in _assistant_threads dict. Evicts oldest entries when exceeding _ASSISTANT_THREADS_MAX (5000), matching the pattern used by _mentioned_threads and _seen_messages. --- gateway/platforms/slack.py | 7 +++++++ tests/gateway/test_slack.py | 22 ++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py index 8685b92ed00..26184b7eb52 100644 --- a/gateway/platforms/slack.py +++ b/gateway/platforms/slack.py @@ -100,6 +100,7 @@ class SlackAdapter(BasePlatformAdapter): # events, and they carry the user/thread identity needed for stable # session + memory scoping. self._assistant_threads: Dict[Tuple[str, str], Dict[str, str]] = {} + self._ASSISTANT_THREADS_MAX = 5000 async def connect(self) -> bool: """Connect to Slack via Socket Mode.""" @@ -826,6 +827,12 @@ class SlackAdapter(BasePlatformAdapter): merged.update({k: v for k, v in metadata.items() if v}) self._assistant_threads[key] = merged + # Evict oldest entries when the cache exceeds the limit + if len(self._assistant_threads) > self._ASSISTANT_THREADS_MAX: + excess = len(self._assistant_threads) - self._ASSISTANT_THREADS_MAX // 2 + for old_key in list(self._assistant_threads)[:excess]: + del self._assistant_threads[old_key] + team_id = merged.get("team_id", "") if team_id and channel_id: self._channel_team[channel_id] = team_id diff --git a/tests/gateway/test_slack.py b/tests/gateway/test_slack.py index 0bad0abe567..67c7cce1dce 100644 --- a/tests/gateway/test_slack.py +++ b/tests/gateway/test_slack.py @@ -927,6 +927,28 @@ class TestAssistantThreadLifecycle: assert msg_event.source.thread_id == "171.000" assert msg_event.source.user_name == "Tyler" + def test_assistant_threads_cache_eviction(self, assistant_adapter): + """Cache should evict oldest entries when exceeding the size limit.""" + assistant_adapter._ASSISTANT_THREADS_MAX = 10 + # Fill to the limit + for i in range(10): + assistant_adapter._cache_assistant_thread_metadata({ + "channel_id": f"D{i}", + "thread_ts": f"{i}.000", + "user_id": f"U{i}", + }) + assert len(assistant_adapter._assistant_threads) == 10 + + # Adding one more should trigger eviction (down to max // 2 = 5) + assistant_adapter._cache_assistant_thread_metadata({ + "channel_id": "D999", + "thread_ts": "999.000", + "user_id": "U999", + }) + assert len(assistant_adapter._assistant_threads) <= 10 + # The newest entry must survive eviction + assert ("D999", "999.000") in assistant_adapter._assistant_threads + # --------------------------------------------------------------------------- # TestUserNameResolution From d97f6cec7fa85038654b8b58529aed6307a104be Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Wed, 8 Apr 2026 23:54:03 -0700 Subject: [PATCH 21/49] feat(gateway): add BlueBubbles iMessage platform adapter (#6437) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds Apple iMessage as a gateway platform via BlueBubbles macOS server. Architecture: - Webhook-based inbound (event-driven, no polling/dedup needed) - Email/phone → chat GUID resolution for user-friendly addressing - Private API safety (checks helper_connected before tapback/typing) - Inbound attachment downloading (images, audio, documents cached locally) - Markdown stripping for clean iMessage delivery - Smart progress suppression for platforms without message editing Based on PR #5869 by @benjaminsehl (webhook architecture, GUID resolution, Private API safety, progress suppression) with inbound attachment downloading from PR #4588 by @1960697431 (attachment cache routing). Integration points: Platform enum, env config, adapter factory, auth maps, cron delivery, send_message routing, channel directory, platform hints, toolset definition, setup wizard, status display. 27 tests covering config, adapter, webhook parsing, GUID resolution, attachment download routing, toolset consistency, and prompt hints. --- agent/prompt_builder.py | 7 + cron/scheduler.py | 5 +- gateway/channel_directory.py | 2 +- gateway/config.py | 27 + gateway/platforms/bluebubbles.py | 828 ++++++++++++++++++++++++++++++ gateway/run.py | 27 +- hermes_cli/config.py | 1 + hermes_cli/gateway.py | 28 + hermes_cli/status.py | 1 + hermes_cli/tools_config.py | 1 + tests/gateway/test_bluebubbles.py | 361 +++++++++++++ tools/cronjob_tools.py | 2 +- tools/send_message_tool.py | 30 ++ toolsets.py | 8 +- 14 files changed, 1321 insertions(+), 7 deletions(-) create mode 100644 gateway/platforms/bluebubbles.py create mode 100644 tests/gateway/test_bluebubbles.py diff --git a/agent/prompt_builder.py b/agent/prompt_builder.py index b1b0891f592..8302973aac7 100644 --- a/agent/prompt_builder.py +++ b/agent/prompt_builder.py @@ -349,6 +349,13 @@ PLATFORM_HINTS = { "only — no markdown, no formatting. SMS messages are limited to ~1600 " "characters, so be brief and direct." ), + "bluebubbles": ( + "You are chatting via iMessage (BlueBubbles). iMessage does not render " + "markdown formatting — use plain text. Keep responses concise as they " + "appear as text messages. You can send media files natively: include " + "MEDIA:/absolute/path/to/file in your response. Images (.jpg, .png, " + ".heic) appear as photos and other files arrive as attachments." + ), } CONTEXT_FILE_MAX_CHARS = 20_000 diff --git a/cron/scheduler.py b/cron/scheduler.py index 33a9b899359..6a7f12acd6c 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -44,7 +44,7 @@ logger = logging.getLogger(__name__) _KNOWN_DELIVERY_PLATFORMS = frozenset({ "telegram", "discord", "slack", "whatsapp", "signal", "matrix", "mattermost", "homeassistant", "dingtalk", "feishu", - "wecom", "sms", "email", "webhook", + "wecom", "sms", "email", "webhook", "bluebubbles", }) from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run @@ -91,7 +91,7 @@ def _resolve_delivery_target(job: dict) -> Optional[dict]: } # Origin missing (e.g. job created via API/script) — try each # platform's home channel as a fallback instead of silently dropping. - for platform_name in ("matrix", "telegram", "discord", "slack"): + for platform_name in ("matrix", "telegram", "discord", "slack", "bluebubbles"): chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "") if chat_id: logger.info( @@ -236,6 +236,7 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option "wecom": Platform.WECOM, "email": Platform.EMAIL, "sms": Platform.SMS, + "bluebubbles": Platform.BLUEBUBBLES, } platform = platform_map.get(platform_name.lower()) if not platform: diff --git a/gateway/channel_directory.py b/gateway/channel_directory.py index 0d124721757..022ebcae4e1 100644 --- a/gateway/channel_directory.py +++ b/gateway/channel_directory.py @@ -77,7 +77,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]: logger.warning("Channel directory: failed to build %s: %s", platform.value, e) # Telegram, WhatsApp & Signal can't enumerate chats -- pull from session history - for plat_name in ("telegram", "whatsapp", "signal", "email", "sms"): + for plat_name in ("telegram", "whatsapp", "signal", "email", "sms", "bluebubbles"): if plat_name not in platforms: platforms[plat_name] = _build_from_sessions(plat_name) diff --git a/gateway/config.py b/gateway/config.py index 047ad542f5a..96ee831701f 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -63,6 +63,7 @@ class Platform(Enum): WEBHOOK = "webhook" FEISHU = "feishu" WECOM = "wecom" + BLUEBUBBLES = "bluebubbles" @dataclass @@ -287,6 +288,9 @@ class GatewayConfig: # WeCom uses extra dict for bot credentials elif platform == Platform.WECOM and config.extra.get("bot_id"): connected.append(platform) + # BlueBubbles uses extra dict for local server config + elif platform == Platform.BLUEBUBBLES and config.extra.get("server_url") and config.extra.get("password"): + connected.append(platform) return connected def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]: @@ -948,6 +952,29 @@ def _apply_env_overrides(config: GatewayConfig) -> None: name=os.getenv("WECOM_HOME_CHANNEL_NAME", "Home"), ) + # BlueBubbles (iMessage) + bluebubbles_server_url = os.getenv("BLUEBUBBLES_SERVER_URL") + bluebubbles_password = os.getenv("BLUEBUBBLES_PASSWORD") + if bluebubbles_server_url and bluebubbles_password: + if Platform.BLUEBUBBLES not in config.platforms: + config.platforms[Platform.BLUEBUBBLES] = PlatformConfig() + config.platforms[Platform.BLUEBUBBLES].enabled = True + config.platforms[Platform.BLUEBUBBLES].extra.update({ + "server_url": bluebubbles_server_url.rstrip("/"), + "password": bluebubbles_password, + "webhook_host": os.getenv("BLUEBUBBLES_WEBHOOK_HOST", "127.0.0.1"), + "webhook_port": int(os.getenv("BLUEBUBBLES_WEBHOOK_PORT", "8645")), + "webhook_path": os.getenv("BLUEBUBBLES_WEBHOOK_PATH", "/bluebubbles-webhook"), + "send_read_receipts": os.getenv("BLUEBUBBLES_SEND_READ_RECEIPTS", "true").lower() in ("true", "1", "yes"), + }) + bluebubbles_home = os.getenv("BLUEBUBBLES_HOME_CHANNEL") + if bluebubbles_home and Platform.BLUEBUBBLES in config.platforms: + config.platforms[Platform.BLUEBUBBLES].home_channel = HomeChannel( + platform=Platform.BLUEBUBBLES, + chat_id=bluebubbles_home, + name=os.getenv("BLUEBUBBLES_HOME_CHANNEL_NAME", "Home"), + ) + # Session settings idle_minutes = os.getenv("SESSION_IDLE_MINUTES") if idle_minutes: diff --git a/gateway/platforms/bluebubbles.py b/gateway/platforms/bluebubbles.py new file mode 100644 index 00000000000..83f94d3bf87 --- /dev/null +++ b/gateway/platforms/bluebubbles.py @@ -0,0 +1,828 @@ +"""BlueBubbles iMessage platform adapter. + +Uses the local BlueBubbles macOS server for outbound REST sends and inbound +webhooks. Supports text messaging, media attachments (images, voice, video, +documents), tapback reactions, typing indicators, and read receipts. + +Architecture based on PR #5869 (benjaminsehl) with inbound attachment +downloading from PR #4588 (YuhangLin). +""" + +import asyncio +import json +import logging +import os +import re +import uuid +from datetime import datetime +from typing import Any, Dict, List, Optional +from urllib.parse import quote + +import httpx + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + SendResult, + cache_image_from_bytes, + cache_audio_from_bytes, + cache_document_from_bytes, +) + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +DEFAULT_WEBHOOK_HOST = "127.0.0.1" +DEFAULT_WEBHOOK_PORT = 8645 +DEFAULT_WEBHOOK_PATH = "/bluebubbles-webhook" +MAX_TEXT_LENGTH = 4000 + +# Tapback reaction codes (BlueBubbles associatedMessageType values) +_TAPBACK_ADDED = { + 2000: "love", 2001: "like", 2002: "dislike", + 2003: "laugh", 2004: "emphasize", 2005: "question", +} +_TAPBACK_REMOVED = { + 3000: "love", 3001: "like", 3002: "dislike", + 3003: "laugh", 3004: "emphasize", 3005: "question", +} + +# Webhook event types that carry user messages +_MESSAGE_EVENTS = {"new-message", "message", "updated-message"} + +# Log redaction patterns +_PHONE_RE = re.compile(r"\+?\d{7,15}") +_EMAIL_RE = re.compile(r"[\w.+-]+@[\w-]+\.[\w.]+") + + +def _redact(text: str) -> str: + """Redact phone numbers and emails from log output.""" + text = _PHONE_RE.sub("[REDACTED]", text) + text = _EMAIL_RE.sub("[REDACTED]", text) + return text + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def check_bluebubbles_requirements() -> bool: + try: + import aiohttp # noqa: F401 + import httpx as _httpx # noqa: F401 + except ImportError: + return False + return True + + +def _normalize_server_url(raw: str) -> str: + value = (raw or "").strip() + if not value: + return "" + if not re.match(r"^https?://", value, flags=re.I): + value = f"http://{value}" + return value.rstrip("/") + + +def _strip_markdown(text: str) -> str: + """Strip common markdown formatting for iMessage plain-text delivery.""" + text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL) + text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL) + text = re.sub(r"__(.+?)__", r"\1", text, flags=re.DOTALL) + text = re.sub(r"_(.+?)_", r"\1", text, flags=re.DOTALL) + text = re.sub(r"```[a-zA-Z0-9_+-]*\n?", "", text) + text = re.sub(r"`(.+?)`", r"\1", text) + text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE) + text = re.sub(r"\[([^\]]+)\]\(([^\)]+)\)", r"\1", text) + text = re.sub(r"\n{3,}", "\n\n", text) + return text.strip() + + +# --------------------------------------------------------------------------- +# Adapter +# --------------------------------------------------------------------------- + +class BlueBubblesAdapter(BasePlatformAdapter): + platform = Platform.BLUEBUBBLES + MAX_MESSAGE_LENGTH = MAX_TEXT_LENGTH + + def __init__(self, config: PlatformConfig): + super().__init__(config, Platform.BLUEBUBBLES) + extra = config.extra or {} + self.server_url = _normalize_server_url( + extra.get("server_url") or os.getenv("BLUEBUBBLES_SERVER_URL", "") + ) + self.password = extra.get("password") or os.getenv("BLUEBUBBLES_PASSWORD", "") + self.webhook_host = ( + extra.get("webhook_host") + or os.getenv("BLUEBUBBLES_WEBHOOK_HOST", DEFAULT_WEBHOOK_HOST) + ) + self.webhook_port = int( + extra.get("webhook_port") + or os.getenv("BLUEBUBBLES_WEBHOOK_PORT", str(DEFAULT_WEBHOOK_PORT)) + ) + self.webhook_path = ( + extra.get("webhook_path") + or os.getenv("BLUEBUBBLES_WEBHOOK_PATH", DEFAULT_WEBHOOK_PATH) + ) + if not str(self.webhook_path).startswith("/"): + self.webhook_path = f"/{self.webhook_path}" + self.send_read_receipts = bool(extra.get("send_read_receipts", True)) + self.client: Optional[httpx.AsyncClient] = None + self._runner = None + self._private_api_enabled: Optional[bool] = None + self._helper_connected: bool = False + self._guid_cache: Dict[str, str] = {} + + # ------------------------------------------------------------------ + # API helpers + # ------------------------------------------------------------------ + + def _api_url(self, path: str) -> str: + sep = "&" if "?" in path else "?" + return f"{self.server_url}{path}{sep}password={quote(self.password, safe='')}" + + async def _api_get(self, path: str) -> Dict[str, Any]: + assert self.client is not None + res = await self.client.get(self._api_url(path)) + res.raise_for_status() + return res.json() + + async def _api_post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]: + assert self.client is not None + res = await self.client.post(self._api_url(path), json=payload) + res.raise_for_status() + return res.json() + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def connect(self) -> bool: + if not self.server_url or not self.password: + logger.error( + "[bluebubbles] BLUEBUBBLES_SERVER_URL and BLUEBUBBLES_PASSWORD are required" + ) + return False + from aiohttp import web + + self.client = httpx.AsyncClient(timeout=30.0) + try: + await self._api_get("/api/v1/ping") + info = await self._api_get("/api/v1/server/info") + server_data = (info or {}).get("data", {}) + self._private_api_enabled = bool(server_data.get("private_api")) + self._helper_connected = bool(server_data.get("helper_connected")) + logger.info( + "[bluebubbles] connected to %s (private_api=%s, helper=%s)", + self.server_url, + self._private_api_enabled, + self._helper_connected, + ) + except Exception as exc: + logger.error( + "[bluebubbles] cannot reach server at %s: %s", self.server_url, exc + ) + if self.client: + await self.client.aclose() + self.client = None + return False + + app = web.Application() + app.router.add_get("/health", lambda _: web.Response(text="ok")) + app.router.add_post(self.webhook_path, self._handle_webhook) + self._runner = web.AppRunner(app) + await self._runner.setup() + site = web.TCPSite(self._runner, self.webhook_host, self.webhook_port) + await site.start() + self._mark_connected() + logger.info( + "[bluebubbles] webhook listening on http://%s:%s%s", + self.webhook_host, + self.webhook_port, + self.webhook_path, + ) + return True + + async def disconnect(self) -> None: + if self.client: + await self.client.aclose() + self.client = None + if self._runner: + await self._runner.cleanup() + self._runner = None + self._mark_disconnected() + + # ------------------------------------------------------------------ + # Chat GUID resolution + # ------------------------------------------------------------------ + + async def _resolve_chat_guid(self, target: str) -> Optional[str]: + """Resolve an email/phone to a BlueBubbles chat GUID. + + If *target* already contains a semicolon (raw GUID format like + ``iMessage;-;user@example.com``), it is returned as-is. Otherwise + the adapter queries the BlueBubbles chat list and matches on + ``chatIdentifier`` or participant address. + """ + target = (target or "").strip() + if not target: + return None + # Already a raw GUID + if ";" in target: + return target + if target in self._guid_cache: + return self._guid_cache[target] + try: + payload = await self._api_post( + "/api/v1/chat/query", + {"limit": 100, "offset": 0, "with": ["participants"]}, + ) + for chat in payload.get("data", []) or []: + guid = chat.get("guid") or chat.get("chatGuid") + identifier = chat.get("chatIdentifier") or chat.get("identifier") + if identifier == target: + if guid: + self._guid_cache[target] = guid + return guid + for part in chat.get("participants", []) or []: + if (part.get("address") or "").strip() == target and guid: + self._guid_cache[target] = guid + return guid + except Exception: + pass + return None + + async def _create_chat_for_handle( + self, address: str, message: str + ) -> SendResult: + """Create a new chat by sending the first message to *address*.""" + payload = { + "addresses": [address], + "message": message, + "tempGuid": f"temp-{datetime.utcnow().timestamp()}", + } + try: + res = await self._api_post("/api/v1/chat/new", payload) + data = res.get("data") or {} + msg_id = data.get("guid") or data.get("messageGuid") or "ok" + return SendResult(success=True, message_id=str(msg_id), raw_response=res) + except Exception as exc: + return SendResult(success=False, error=str(exc)) + + # ------------------------------------------------------------------ + # Text sending + # ------------------------------------------------------------------ + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + text = _strip_markdown(content or "") + if not text: + return SendResult(success=False, error="BlueBubbles send requires text") + chunks = self.truncate_message(text, max_length=self.MAX_MESSAGE_LENGTH) + last = SendResult(success=True) + for chunk in chunks: + guid = await self._resolve_chat_guid(chat_id) + if not guid: + # If the target looks like an address, try creating a new chat + if self._private_api_enabled and ( + "@" in chat_id or re.match(r"^\+\d+", chat_id) + ): + return await self._create_chat_for_handle(chat_id, chunk) + return SendResult( + success=False, + error=f"BlueBubbles chat not found for target: {chat_id}", + ) + payload: Dict[str, Any] = { + "chatGuid": guid, + "tempGuid": f"temp-{datetime.utcnow().timestamp()}", + "message": chunk, + } + if reply_to and self._private_api_enabled and self._helper_connected: + payload["method"] = "private-api" + payload["selectedMessageGuid"] = reply_to + payload["partIndex"] = 0 + try: + res = await self._api_post("/api/v1/message/text", payload) + data = res.get("data") or {} + msg_id = data.get("guid") or data.get("messageGuid") or "ok" + last = SendResult( + success=True, message_id=str(msg_id), raw_response=res + ) + except Exception as exc: + return SendResult(success=False, error=str(exc)) + return last + + # ------------------------------------------------------------------ + # Media sending (outbound) + # ------------------------------------------------------------------ + + async def _send_attachment( + self, + chat_id: str, + file_path: str, + filename: Optional[str] = None, + caption: Optional[str] = None, + is_audio_message: bool = False, + ) -> SendResult: + """Send a file attachment via BlueBubbles multipart upload.""" + if not self.client: + return SendResult(success=False, error="Not connected") + if not os.path.isfile(file_path): + return SendResult(success=False, error=f"File not found: {file_path}") + + guid = await self._resolve_chat_guid(chat_id) + if not guid: + return SendResult(success=False, error=f"Chat not found: {chat_id}") + + fname = filename or os.path.basename(file_path) + try: + with open(file_path, "rb") as f: + files = {"attachment": (fname, f, "application/octet-stream")} + data: Dict[str, str] = { + "chatGuid": guid, + "name": fname, + "tempGuid": uuid.uuid4().hex, + } + if is_audio_message: + data["isAudioMessage"] = "true" + res = await self.client.post( + self._api_url("/api/v1/message/attachment"), + files=files, + data=data, + timeout=120, + ) + res.raise_for_status() + result = res.json() + + if caption: + await self.send(chat_id, caption) + + if result.get("status") == 200: + rdata = result.get("data") or {} + msg_id = rdata.get("guid") if isinstance(rdata, dict) else None + return SendResult( + success=True, message_id=msg_id, raw_response=result + ) + return SendResult( + success=False, + error=result.get("message", "Attachment upload failed"), + ) + except Exception as e: + return SendResult(success=False, error=str(e)) + + async def send_image( + self, + chat_id: str, + image_url: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + try: + from gateway.platforms.base import cache_image_from_url + + local_path = await cache_image_from_url(image_url) + return await self._send_attachment(chat_id, local_path, caption=caption) + except Exception: + return await super().send_image(chat_id, image_url, caption, reply_to) + + async def send_image_file( + self, + chat_id: str, + image_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + return await self._send_attachment(chat_id, image_path, caption=caption) + + async def send_voice( + self, + chat_id: str, + audio_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + return await self._send_attachment( + chat_id, audio_path, caption=caption, is_audio_message=True + ) + + async def send_video( + self, + chat_id: str, + video_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + return await self._send_attachment(chat_id, video_path, caption=caption) + + async def send_document( + self, + chat_id: str, + file_path: str, + caption: Optional[str] = None, + file_name: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + return await self._send_attachment( + chat_id, file_path, filename=file_name, caption=caption + ) + + async def send_animation( + self, + chat_id: str, + animation_url: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + return await self.send_image( + chat_id, animation_url, caption, reply_to, metadata + ) + + # ------------------------------------------------------------------ + # Typing indicators + # ------------------------------------------------------------------ + + async def send_typing(self, chat_id: str, metadata=None) -> None: + if not self._private_api_enabled or not self._helper_connected or not self.client: + return + try: + guid = await self._resolve_chat_guid(chat_id) + if guid: + encoded = quote(guid, safe="") + await self.client.post( + self._api_url(f"/api/v1/chat/{encoded}/typing"), timeout=5 + ) + except Exception: + pass + + async def stop_typing(self, chat_id: str) -> None: + if not self._private_api_enabled or not self._helper_connected or not self.client: + return + try: + guid = await self._resolve_chat_guid(chat_id) + if guid: + encoded = quote(guid, safe="") + await self.client.delete( + self._api_url(f"/api/v1/chat/{encoded}/typing"), timeout=5 + ) + except Exception: + pass + + # ------------------------------------------------------------------ + # Read receipts + # ------------------------------------------------------------------ + + async def mark_read(self, chat_id: str) -> bool: + if not self._private_api_enabled or not self._helper_connected or not self.client: + return False + try: + guid = await self._resolve_chat_guid(chat_id) + if guid: + encoded = quote(guid, safe="") + await self.client.post( + self._api_url(f"/api/v1/chat/{encoded}/read"), timeout=5 + ) + return True + except Exception: + pass + return False + + # ------------------------------------------------------------------ + # Tapback reactions + # ------------------------------------------------------------------ + + async def send_reaction( + self, + chat_id: str, + message_guid: str, + reaction: str, + part_index: int = 0, + ) -> SendResult: + """Send a tapback reaction (requires Private API helper).""" + if not self._private_api_enabled or not self._helper_connected: + return SendResult( + success=False, error="Private API helper not connected" + ) + guid = await self._resolve_chat_guid(chat_id) + if not guid: + return SendResult(success=False, error=f"Chat not found: {chat_id}") + try: + res = await self._api_post( + "/api/v1/message/react", + { + "chatGuid": guid, + "selectedMessageGuid": message_guid, + "reaction": reaction, + "partIndex": part_index, + }, + ) + return SendResult(success=True, raw_response=res) + except Exception as exc: + return SendResult(success=False, error=str(exc)) + + # ------------------------------------------------------------------ + # Chat info + # ------------------------------------------------------------------ + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + is_group = ";+;" in (chat_id or "") + info: Dict[str, Any] = { + "name": chat_id, + "type": "group" if is_group else "dm", + } + try: + guid = await self._resolve_chat_guid(chat_id) + if guid: + encoded = quote(guid, safe="") + res = await self._api_get( + f"/api/v1/chat/{encoded}?with=participants" + ) + data = (res or {}).get("data", {}) + display_name = ( + data.get("displayName") + or data.get("chatIdentifier") + or chat_id + ) + participants = [] + for p in data.get("participants", []) or []: + addr = (p.get("address") or "").strip() + if addr: + participants.append(addr) + info["name"] = display_name + if participants: + info["participants"] = participants + except Exception: + pass + return info + + def format_message(self, content: str) -> str: + return _strip_markdown(content) + + # ------------------------------------------------------------------ + # Inbound attachment downloading (from #4588) + # ------------------------------------------------------------------ + + async def _download_attachment( + self, att_guid: str, att_meta: Dict[str, Any] + ) -> Optional[str]: + """Download an attachment from BlueBubbles and cache it locally. + + Returns the local file path on success, None on failure. + """ + if not self.client: + return None + try: + encoded = quote(att_guid, safe="") + resp = await self.client.get( + self._api_url(f"/api/v1/attachment/{encoded}/download"), + timeout=60, + follow_redirects=True, + ) + resp.raise_for_status() + data = resp.content + + mime = (att_meta.get("mimeType") or "").lower() + transfer_name = att_meta.get("transferName", "") + + if mime.startswith("image/"): + ext_map = { + "image/jpeg": ".jpg", + "image/png": ".png", + "image/gif": ".gif", + "image/webp": ".webp", + "image/heic": ".jpg", + "image/heif": ".jpg", + "image/tiff": ".jpg", + } + ext = ext_map.get(mime, ".jpg") + return cache_image_from_bytes(data, ext) + + if mime.startswith("audio/"): + ext_map = { + "audio/mp3": ".mp3", + "audio/mpeg": ".mp3", + "audio/ogg": ".ogg", + "audio/wav": ".wav", + "audio/x-caf": ".mp3", + "audio/mp4": ".m4a", + "audio/aac": ".m4a", + } + ext = ext_map.get(mime, ".mp3") + return cache_audio_from_bytes(data, ext) + + # Videos, documents, and everything else + filename = transfer_name or f"file_{uuid.uuid4().hex[:8]}" + return cache_document_from_bytes(data, filename) + + except Exception as exc: + logger.warning( + "[bluebubbles] failed to download attachment %s: %s", + _redact(att_guid), + exc, + ) + return None + + # ------------------------------------------------------------------ + # Webhook handling + # ------------------------------------------------------------------ + + def _extract_payload_record( + self, payload: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + data = payload.get("data") + if isinstance(data, dict): + return data + if isinstance(data, list): + for item in data: + if isinstance(item, dict): + return item + if isinstance(payload.get("message"), dict): + return payload.get("message") + return payload if isinstance(payload, dict) else None + + @staticmethod + def _value(*candidates: Any) -> Optional[str]: + for candidate in candidates: + if isinstance(candidate, str) and candidate.strip(): + return candidate.strip() + return None + + async def _handle_webhook(self, request): + from aiohttp import web + + token = ( + request.query.get("password") + or request.query.get("guid") + or request.headers.get("x-password") + or request.headers.get("x-guid") + or request.headers.get("x-bluebubbles-guid") + ) + if token != self.password: + return web.json_response({"error": "unauthorized"}, status=401) + try: + raw = await request.read() + body = raw.decode("utf-8", errors="replace") + try: + payload = json.loads(body) + except Exception: + from urllib.parse import parse_qs + + form = parse_qs(body) + payload_str = ( + form.get("payload") + or form.get("data") + or form.get("message") + or [""] + )[0] + payload = json.loads(payload_str) if payload_str else {} + except Exception as exc: + logger.error("[bluebubbles] webhook parse error: %s", exc) + return web.json_response({"error": "invalid payload"}, status=400) + + event_type = self._value(payload.get("type"), payload.get("event")) or "" + # Only process message events; silently acknowledge everything else + if event_type and event_type not in _MESSAGE_EVENTS: + return web.Response(text="ok") + + record = self._extract_payload_record(payload) or {} + is_from_me = bool( + record.get("isFromMe") + or record.get("fromMe") + or record.get("is_from_me") + ) + if is_from_me: + return web.Response(text="ok") + + # Skip tapback reactions delivered as messages + assoc_type = record.get("associatedMessageType") + if isinstance(assoc_type, int) and assoc_type in { + **_TAPBACK_ADDED, + **_TAPBACK_REMOVED, + }: + return web.Response(text="ok") + + text = ( + self._value( + record.get("text"), record.get("message"), record.get("body") + ) + or "" + ) + + # --- Inbound attachment handling --- + attachments = record.get("attachments") or [] + media_urls: List[str] = [] + media_types: List[str] = [] + msg_type = MessageType.TEXT + + for att in attachments: + att_guid = att.get("guid", "") + if not att_guid: + continue + cached = await self._download_attachment(att_guid, att) + if cached: + mime = (att.get("mimeType") or "").lower() + media_urls.append(cached) + media_types.append(mime) + if mime.startswith("image/"): + msg_type = MessageType.PHOTO + elif mime.startswith("audio/") or (att.get("uti") or "").endswith( + "caf" + ): + msg_type = MessageType.VOICE + elif mime.startswith("video/"): + msg_type = MessageType.VIDEO + else: + msg_type = MessageType.DOCUMENT + + # With multiple attachments, prefer PHOTO if any images present + if len(media_urls) > 1: + mime_prefixes = {(m or "").split("/")[0] for m in media_types} + if "image" in mime_prefixes: + msg_type = MessageType.PHOTO + + if not text and media_urls: + text = "(attachment)" + # --- End attachment handling --- + + chat_guid = self._value( + record.get("chatGuid"), + payload.get("chatGuid"), + record.get("chat_guid"), + payload.get("chat_guid"), + payload.get("guid"), + ) + chat_identifier = self._value( + record.get("chatIdentifier"), + record.get("identifier"), + payload.get("chatIdentifier"), + payload.get("identifier"), + ) + sender = ( + self._value( + record.get("handle", {}).get("address") + if isinstance(record.get("handle"), dict) + else None, + record.get("sender"), + record.get("from"), + record.get("address"), + ) + or chat_identifier + or chat_guid + ) + if not (chat_guid or chat_identifier) and sender: + chat_identifier = sender + if not sender or not (chat_guid or chat_identifier) or not text: + return web.json_response({"error": "missing message fields"}, status=400) + + session_chat_id = chat_guid or chat_identifier + is_group = bool(record.get("isGroup")) or (";+;" in (chat_guid or "")) + source = self.build_source( + chat_id=session_chat_id, + chat_name=chat_identifier or sender, + chat_type="group" if is_group else "dm", + user_id=sender, + user_name=sender, + chat_id_alt=chat_identifier, + ) + event = MessageEvent( + text=text, + message_type=msg_type, + source=source, + raw_message=payload, + message_id=self._value( + record.get("guid"), + record.get("messageGuid"), + record.get("id"), + ), + reply_to_message_id=self._value( + record.get("threadOriginatorGuid"), + record.get("associatedMessageGuid"), + ), + media_urls=media_urls, + media_types=media_types, + ) + task = asyncio.create_task(self.handle_message(event)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + + # Fire-and-forget read receipt + if self.send_read_receipts and session_chat_id: + asyncio.create_task(self.mark_read(session_chat_id)) + + return web.Response(text="ok") diff --git a/gateway/run.py b/gateway/run.py index c9d3f07e6ec..27703a10248 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1075,6 +1075,7 @@ class GatewayRunner: "MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS", "FEISHU_ALLOWED_USERS", "WECOM_ALLOWED_USERS", + "BLUEBUBBLES_ALLOWED_USERS", "GATEWAY_ALLOWED_USERS") ) _allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any( @@ -1085,7 +1086,8 @@ class GatewayRunner: "SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS", "MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS", "FEISHU_ALLOW_ALL_USERS", - "WECOM_ALLOW_ALL_USERS") + "WECOM_ALLOW_ALL_USERS", + "BLUEBUBBLES_ALLOW_ALL_USERS") ) if not _any_allowlist and not _allow_all: logger.warning( @@ -1656,6 +1658,13 @@ class GatewayRunner: adapter.gateway_runner = self # For cross-platform delivery return adapter + elif platform == Platform.BLUEBUBBLES: + from gateway.platforms.bluebubbles import BlueBubblesAdapter, check_bluebubbles_requirements + if not check_bluebubbles_requirements(): + logger.warning("BlueBubbles: aiohttp/httpx missing or BLUEBUBBLES_SERVER_URL/BLUEBUBBLES_PASSWORD not configured") + return None + return BlueBubblesAdapter(config) + return None def _is_user_authorized(self, source: SessionSource) -> bool: @@ -1694,6 +1703,7 @@ class GatewayRunner: Platform.DINGTALK: "DINGTALK_ALLOWED_USERS", Platform.FEISHU: "FEISHU_ALLOWED_USERS", Platform.WECOM: "WECOM_ALLOWED_USERS", + Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOWED_USERS", } platform_allow_all_map = { Platform.TELEGRAM: "TELEGRAM_ALLOW_ALL_USERS", @@ -1708,6 +1718,7 @@ class GatewayRunner: Platform.DINGTALK: "DINGTALK_ALLOW_ALL_USERS", Platform.FEISHU: "FEISHU_ALLOW_ALL_USERS", Platform.WECOM: "WECOM_ALLOW_ALL_USERS", + Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOW_ALL_USERS", } # Per-platform allow-all flag (e.g., DISCORD_ALLOW_ALL_USERS=true) @@ -5523,7 +5534,7 @@ class GatewayRunner: Platform.TELEGRAM, Platform.DISCORD, Platform.SLACK, Platform.WHATSAPP, Platform.SIGNAL, Platform.MATTERMOST, Platform.MATRIX, Platform.HOMEASSISTANT, Platform.EMAIL, Platform.SMS, Platform.DINGTALK, - Platform.FEISHU, Platform.WECOM, Platform.LOCAL, + Platform.FEISHU, Platform.WECOM, Platform.BLUEBUBBLES, Platform.LOCAL, }) async def _handle_update_command(self, event: MessageEvent) -> str: @@ -6426,6 +6437,18 @@ class GatewayRunner: if not adapter: return + # Skip tool progress for platforms that don't support message + # editing (e.g. iMessage/BlueBubbles) — each progress update + # would become a separate message bubble, which is noisy. + from gateway.platforms.base import BasePlatformAdapter as _BaseAdapter + if type(adapter).edit_message is _BaseAdapter.edit_message: + while not progress_queue.empty(): + try: + progress_queue.get_nowait() + except Exception: + break + return + progress_lines = [] # Accumulated tool lines progress_msg_id = None # ID of the progress message to edit can_edit = True # False once an edit fails (platform doesn't support it) diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 8b5da35220d..4357119a2ad 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -39,6 +39,7 @@ _EXTRA_ENV_KEYS = frozenset({ "DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET", "FEISHU_APP_ID", "FEISHU_APP_SECRET", "FEISHU_ENCRYPT_KEY", "FEISHU_VERIFICATION_TOKEN", "WECOM_BOT_ID", "WECOM_SECRET", + "BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_PASSWORD", "TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT", "WHATSAPP_MODE", "WHATSAPP_ENABLED", "MATTERMOST_HOME_CHANNEL", "MATTERMOST_REPLY_MODE", diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index 89b01b18c57..82689f8fffd 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -1588,6 +1588,34 @@ _PLATFORMS = [ "help": "Chat ID for scheduled results and notifications."}, ], }, + { + "key": "bluebubbles", + "label": "BlueBubbles (iMessage)", + "emoji": "💬", + "token_var": "BLUEBUBBLES_SERVER_URL", + "setup_instructions": [ + "1. Install BlueBubbles on a Mac that will act as your iMessage server:", + " https://bluebubbles.app/", + "2. Complete the BlueBubbles setup wizard — sign in with your Apple ID", + "3. In BlueBubbles Settings → API, note the Server URL and password", + "4. The server URL is typically http://:1234", + "5. Hermes connects via the BlueBubbles REST API and receives", + " incoming messages via a local webhook", + "6. To authorize users, use DM pairing: hermes pairing generate bluebubbles", + " Share the code — the user sends it via iMessage to get approved", + ], + "vars": [ + {"name": "BLUEBUBBLES_SERVER_URL", "prompt": "BlueBubbles server URL (e.g. http://192.168.1.10:1234)", "password": False, + "help": "The URL shown in BlueBubbles Settings → API."}, + {"name": "BLUEBUBBLES_PASSWORD", "prompt": "BlueBubbles server password", "password": True, + "help": "The password shown in BlueBubbles Settings → API."}, + {"name": "BLUEBUBBLES_ALLOWED_USERS", "prompt": "Pre-authorized phone numbers or iMessage IDs (comma-separated, or leave empty for DM pairing)", "password": False, + "is_allowlist": True, + "help": "Optional — pre-authorize specific users. Leave empty to use DM pairing instead (recommended)."}, + {"name": "BLUEBUBBLES_HOME_CHANNEL", "prompt": "Home channel (phone number or iMessage ID for cron/notifications, or empty)", "password": False, + "help": "Phone number or Apple ID to deliver cron results and notifications to."}, + ], + }, ] diff --git a/hermes_cli/status.py b/hermes_cli/status.py index 6fe8f7df0b4..eed89885d29 100644 --- a/hermes_cli/status.py +++ b/hermes_cli/status.py @@ -302,6 +302,7 @@ def show_status(args): "DingTalk": ("DINGTALK_CLIENT_ID", None), "Feishu": ("FEISHU_APP_ID", "FEISHU_HOME_CHANNEL"), "WeCom": ("WECOM_BOT_ID", "WECOM_HOME_CHANNEL"), + "BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"), } for name, (token_var, home_var) in platforms.items(): diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index 65525d27d00..9a50a2c5d5f 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -126,6 +126,7 @@ PLATFORMS = { "slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"}, "whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"}, "signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"}, + "bluebubbles": {"label": "💙 BlueBubbles", "default_toolset": "hermes-bluebubbles"}, "homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"}, "email": {"label": "📧 Email", "default_toolset": "hermes-email"}, "matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"}, diff --git a/tests/gateway/test_bluebubbles.py b/tests/gateway/test_bluebubbles.py new file mode 100644 index 00000000000..939a69ff152 --- /dev/null +++ b/tests/gateway/test_bluebubbles.py @@ -0,0 +1,361 @@ +"""Tests for the BlueBubbles iMessage gateway adapter.""" +import pytest + +from gateway.config import Platform, PlatformConfig + + +def _make_adapter(monkeypatch, **extra): + monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234") + monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret") + from gateway.platforms.bluebubbles import BlueBubblesAdapter + + cfg = PlatformConfig( + enabled=True, + extra={ + "server_url": "http://localhost:1234", + "password": "secret", + **extra, + }, + ) + return BlueBubblesAdapter(cfg) + + +class TestBlueBubblesPlatformEnum: + def test_bluebubbles_enum_exists(self): + assert Platform.BLUEBUBBLES.value == "bluebubbles" + + +class TestBlueBubblesConfigLoading: + def test_apply_env_overrides_bluebubbles(self, monkeypatch): + monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234") + monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret") + monkeypatch.setenv("BLUEBUBBLES_WEBHOOK_PORT", "9999") + from gateway.config import GatewayConfig, _apply_env_overrides + + config = GatewayConfig() + _apply_env_overrides(config) + assert Platform.BLUEBUBBLES in config.platforms + bc = config.platforms[Platform.BLUEBUBBLES] + assert bc.enabled is True + assert bc.extra["server_url"] == "http://localhost:1234" + assert bc.extra["password"] == "secret" + assert bc.extra["webhook_port"] == 9999 + + def test_connected_platforms_includes_bluebubbles(self, monkeypatch): + monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234") + monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret") + from gateway.config import GatewayConfig, _apply_env_overrides + + config = GatewayConfig() + _apply_env_overrides(config) + assert Platform.BLUEBUBBLES in config.get_connected_platforms() + + def test_home_channel_set_from_env(self, monkeypatch): + monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234") + monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret") + monkeypatch.setenv("BLUEBUBBLES_HOME_CHANNEL", "user@example.com") + from gateway.config import GatewayConfig, _apply_env_overrides + + config = GatewayConfig() + _apply_env_overrides(config) + hc = config.platforms[Platform.BLUEBUBBLES].home_channel + assert hc is not None + assert hc.chat_id == "user@example.com" + + def test_not_connected_without_password(self, monkeypatch): + monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234") + monkeypatch.delenv("BLUEBUBBLES_PASSWORD", raising=False) + from gateway.config import GatewayConfig, _apply_env_overrides + + config = GatewayConfig() + _apply_env_overrides(config) + assert Platform.BLUEBUBBLES not in config.get_connected_platforms() + + +class TestBlueBubblesHelpers: + def test_check_requirements(self, monkeypatch): + monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234") + monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret") + from gateway.platforms.bluebubbles import check_bluebubbles_requirements + + assert check_bluebubbles_requirements() is True + + def test_format_message_strips_markdown(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + assert adapter.format_message("**Hello** `world`") == "Hello world" + + def test_strip_markdown_headers(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + assert adapter.format_message("## Heading\ntext") == "Heading\ntext" + + def test_strip_markdown_links(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + assert adapter.format_message("[click here](http://example.com)") == "click here" + + def test_init_normalizes_webhook_path(self, monkeypatch): + adapter = _make_adapter(monkeypatch, webhook_path="bluebubbles-webhook") + assert adapter.webhook_path == "/bluebubbles-webhook" + + def test_init_preserves_leading_slash(self, monkeypatch): + adapter = _make_adapter(monkeypatch, webhook_path="/my-hook") + assert adapter.webhook_path == "/my-hook" + + def test_server_url_normalized(self, monkeypatch): + adapter = _make_adapter(monkeypatch, server_url="http://localhost:1234/") + assert adapter.server_url == "http://localhost:1234" + + def test_server_url_adds_scheme(self, monkeypatch): + adapter = _make_adapter(monkeypatch, server_url="localhost:1234") + assert adapter.server_url == "http://localhost:1234" + + +class TestBlueBubblesWebhookParsing: + def test_webhook_prefers_chat_guid_over_message_guid(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + payload = { + "guid": "MESSAGE-GUID", + "chatGuid": "iMessage;-;user@example.com", + "chatIdentifier": "user@example.com", + } + record = adapter._extract_payload_record(payload) or {} + chat_guid = adapter._value( + record.get("chatGuid"), + payload.get("chatGuid"), + record.get("chat_guid"), + payload.get("chat_guid"), + payload.get("guid"), + ) + assert chat_guid == "iMessage;-;user@example.com" + + def test_webhook_can_fall_back_to_sender_when_chat_fields_missing(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + payload = { + "data": { + "guid": "MESSAGE-GUID", + "text": "hello", + "handle": {"address": "user@example.com"}, + "isFromMe": False, + } + } + record = adapter._extract_payload_record(payload) or {} + chat_guid = adapter._value( + record.get("chatGuid"), + payload.get("chatGuid"), + record.get("chat_guid"), + payload.get("chat_guid"), + payload.get("guid"), + ) + chat_identifier = adapter._value( + record.get("chatIdentifier"), + record.get("identifier"), + payload.get("chatIdentifier"), + payload.get("identifier"), + ) + sender = ( + adapter._value( + record.get("handle", {}).get("address") + if isinstance(record.get("handle"), dict) + else None, + record.get("sender"), + record.get("from"), + record.get("address"), + ) + or chat_identifier + or chat_guid + ) + if not (chat_guid or chat_identifier) and sender: + chat_identifier = sender + assert chat_identifier == "user@example.com" + + def test_extract_payload_record_accepts_list_data(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + payload = { + "type": "new-message", + "data": [ + { + "text": "hello", + "chatGuid": "iMessage;-;user@example.com", + "chatIdentifier": "user@example.com", + } + ], + } + record = adapter._extract_payload_record(payload) + assert record == payload["data"][0] + + def test_extract_payload_record_dict_data(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + payload = {"data": {"text": "hello", "chatGuid": "iMessage;-;+1234"}} + record = adapter._extract_payload_record(payload) + assert record["text"] == "hello" + + def test_extract_payload_record_fallback_to_message(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + payload = {"message": {"text": "hello"}} + record = adapter._extract_payload_record(payload) + assert record["text"] == "hello" + + +class TestBlueBubblesGuidResolution: + def test_raw_guid_returned_as_is(self, monkeypatch): + """If target already contains ';' it's a raw GUID — return unchanged.""" + adapter = _make_adapter(monkeypatch) + import asyncio + + result = asyncio.get_event_loop().run_until_complete( + adapter._resolve_chat_guid("iMessage;-;user@example.com") + ) + assert result == "iMessage;-;user@example.com" + + def test_empty_target_returns_none(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + import asyncio + + result = asyncio.get_event_loop().run_until_complete( + adapter._resolve_chat_guid("") + ) + assert result is None + + +class TestBlueBubblesToolsetIntegration: + def test_toolset_exists(self): + from toolsets import TOOLSETS + + assert "hermes-bluebubbles" in TOOLSETS + + def test_toolset_in_gateway_composite(self): + from toolsets import TOOLSETS + + gateway = TOOLSETS["hermes-gateway"] + assert "hermes-bluebubbles" in gateway["includes"] + + +class TestBlueBubblesPromptHint: + def test_platform_hint_exists(self): + from agent.prompt_builder import PLATFORM_HINTS + + assert "bluebubbles" in PLATFORM_HINTS + hint = PLATFORM_HINTS["bluebubbles"] + assert "iMessage" in hint + assert "plain text" in hint + + +class TestBlueBubblesAttachmentDownload: + """Verify _download_attachment routes to the correct cache helper.""" + + def test_download_image_uses_image_cache(self, monkeypatch): + """Image MIME routes to cache_image_from_bytes.""" + adapter = _make_adapter(monkeypatch) + import asyncio + import httpx + + # Mock the HTTP client response + class MockResponse: + status_code = 200 + content = b"\x89PNG\r\n\x1a\n" + + def raise_for_status(self): + pass + + async def mock_get(*args, **kwargs): + return MockResponse() + + adapter.client = type("MockClient", (), {"get": mock_get})() + + cached_path = None + + def mock_cache_image(data, ext): + nonlocal cached_path + cached_path = f"/tmp/test_image{ext}" + return cached_path + + monkeypatch.setattr( + "gateway.platforms.bluebubbles.cache_image_from_bytes", + mock_cache_image, + ) + + att_meta = {"mimeType": "image/png", "transferName": "photo.png"} + result = asyncio.get_event_loop().run_until_complete( + adapter._download_attachment("att-guid-123", att_meta) + ) + assert result == "/tmp/test_image.png" + + def test_download_audio_uses_audio_cache(self, monkeypatch): + """Audio MIME routes to cache_audio_from_bytes.""" + adapter = _make_adapter(monkeypatch) + import asyncio + + class MockResponse: + status_code = 200 + content = b"fake-audio-data" + + def raise_for_status(self): + pass + + async def mock_get(*args, **kwargs): + return MockResponse() + + adapter.client = type("MockClient", (), {"get": mock_get})() + + cached_path = None + + def mock_cache_audio(data, ext): + nonlocal cached_path + cached_path = f"/tmp/test_audio{ext}" + return cached_path + + monkeypatch.setattr( + "gateway.platforms.bluebubbles.cache_audio_from_bytes", + mock_cache_audio, + ) + + att_meta = {"mimeType": "audio/mpeg", "transferName": "voice.mp3"} + result = asyncio.get_event_loop().run_until_complete( + adapter._download_attachment("att-guid-456", att_meta) + ) + assert result == "/tmp/test_audio.mp3" + + def test_download_document_uses_document_cache(self, monkeypatch): + """Non-image/audio MIME routes to cache_document_from_bytes.""" + adapter = _make_adapter(monkeypatch) + import asyncio + + class MockResponse: + status_code = 200 + content = b"fake-doc-data" + + def raise_for_status(self): + pass + + async def mock_get(*args, **kwargs): + return MockResponse() + + adapter.client = type("MockClient", (), {"get": mock_get})() + + cached_path = None + + def mock_cache_doc(data, filename): + nonlocal cached_path + cached_path = f"/tmp/{filename}" + return cached_path + + monkeypatch.setattr( + "gateway.platforms.bluebubbles.cache_document_from_bytes", + mock_cache_doc, + ) + + att_meta = {"mimeType": "application/pdf", "transferName": "report.pdf"} + result = asyncio.get_event_loop().run_until_complete( + adapter._download_attachment("att-guid-789", att_meta) + ) + assert result == "/tmp/report.pdf" + + def test_download_returns_none_without_client(self, monkeypatch): + """No client → returns None gracefully.""" + adapter = _make_adapter(monkeypatch) + adapter.client = None + import asyncio + + result = asyncio.get_event_loop().run_until_complete( + adapter._download_attachment("att-guid", {"mimeType": "image/png"}) + ) + assert result is None diff --git a/tools/cronjob_tools.py b/tools/cronjob_tools.py index 595ad8bc71a..ccb8bc6f63d 100644 --- a/tools/cronjob_tools.py +++ b/tools/cronjob_tools.py @@ -455,7 +455,7 @@ Important safety rule: cron-run sessions should not recursively schedule more cr }, "deliver": { "type": "string", - "description": "Delivery target: origin, local, telegram, discord, slack, whatsapp, signal, matrix, mattermost, homeassistant, dingtalk, feishu, wecom, email, sms, or platform:chat_id or platform:chat_id:thread_id for Telegram topics. Examples: 'origin', 'local', 'telegram', 'telegram:-1001234567890:17585', 'discord:#engineering'" + "description": "Delivery target: origin, local, telegram, discord, slack, whatsapp, signal, matrix, mattermost, homeassistant, dingtalk, feishu, wecom, email, sms, bluebubbles, or platform:chat_id or platform:chat_id:thread_id for Telegram topics. Examples: 'origin', 'local', 'telegram', 'telegram:-1001234567890:17585', 'discord:#engineering'" }, "skills": { "type": "array", diff --git a/tools/send_message_tool.py b/tools/send_message_tool.py index 164b8a2f47e..76b3e158205 100644 --- a/tools/send_message_tool.py +++ b/tools/send_message_tool.py @@ -148,6 +148,7 @@ def _handle_send(args): "slack": Platform.SLACK, "whatsapp": Platform.WHATSAPP, "signal": Platform.SIGNAL, + "bluebubbles": Platform.BLUEBUBBLES, "matrix": Platform.MATRIX, "mattermost": Platform.MATTERMOST, "homeassistant": Platform.HOMEASSISTANT, @@ -396,6 +397,8 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, result = await _send_feishu(pconfig, chat_id, chunk, thread_id=thread_id) elif platform == Platform.WECOM: result = await _send_wecom(pconfig.extra, chat_id, chunk) + elif platform == Platform.BLUEBUBBLES: + result = await _send_bluebubbles(pconfig.extra, chat_id, chunk) else: result = {"error": f"Direct sending not yet implemented for {platform.value}"} @@ -870,6 +873,33 @@ async def _send_wecom(extra, chat_id, message): return _error(f"WeCom send failed: {e}") +async def _send_bluebubbles(extra, chat_id, message): + """Send via BlueBubbles iMessage server using the adapter's REST API.""" + try: + from gateway.platforms.bluebubbles import BlueBubblesAdapter, check_bluebubbles_requirements + if not check_bluebubbles_requirements(): + return {"error": "BlueBubbles requirements not met (need aiohttp + httpx)."} + except ImportError: + return {"error": "BlueBubbles adapter not available."} + + try: + from gateway.config import PlatformConfig + pconfig = PlatformConfig(extra=extra) + adapter = BlueBubblesAdapter(pconfig) + connected = await adapter.connect() + if not connected: + return _error("BlueBubbles: failed to connect to server") + try: + result = await adapter.send(chat_id, message) + if not result.success: + return _error(f"BlueBubbles send failed: {result.error}") + return {"success": True, "platform": "bluebubbles", "chat_id": chat_id, "message_id": result.message_id} + finally: + await adapter.disconnect() + except Exception as e: + return _error(f"BlueBubbles send failed: {e}") + + async def _send_feishu(pconfig, chat_id, message, media_files=None, thread_id=None): """Send via Feishu/Lark using the adapter's send pipeline.""" try: diff --git a/toolsets.py b/toolsets.py index 2a359b60a75..a786ee7c663 100644 --- a/toolsets.py +++ b/toolsets.py @@ -311,6 +311,12 @@ TOOLSETS = { "includes": [] }, + "hermes-bluebubbles": { + "description": "BlueBubbles iMessage bot toolset - Apple iMessage via local BlueBubbles server", + "tools": _HERMES_CORE_TOOLS, + "includes": [] + }, + "hermes-homeassistant": { "description": "Home Assistant bot toolset - smart home event monitoring and control", "tools": _HERMES_CORE_TOOLS, @@ -368,7 +374,7 @@ TOOLSETS = { "hermes-gateway": { "description": "Gateway toolset - union of all messaging platform tools", "tools": [], - "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-homeassistant", "hermes-email", "hermes-sms", "hermes-mattermost", "hermes-matrix", "hermes-dingtalk", "hermes-feishu", "hermes-wecom", "hermes-webhook"] + "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-bluebubbles", "hermes-homeassistant", "hermes-email", "hermes-sms", "hermes-mattermost", "hermes-matrix", "hermes-dingtalk", "hermes-feishu", "hermes-wecom", "hermes-webhook"] } } From 25757d631b493381c22efe45984655b06ae97651 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Boschi?= Date: Thu, 9 Apr 2026 07:27:31 +0200 Subject: [PATCH 22/49] feat(hindsight): feature parity, setup wizard, and config improvements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Port missing features from the hindsight-hermes external integration package into the native plugin. Only touches plugin files — no core changes. Features: - Tags on retain/recall (tags, recall_tags, recall_tags_match) - Recall config (recall_max_tokens, recall_max_input_chars, recall_types, recall_prompt_preamble) - Retain controls (retain_every_n_turns, auto_retain, auto_recall, retain_async via aretain_batch, retain_context) - Bank config via Banks API (bank_mission, bank_retain_mission) - Structured JSON retain with per-message timestamps - Full session accumulation with document_id for dedup - Custom post_setup() wizard with curses picker - Mode-aware dep install (hindsight-client for cloud, hindsight-all for local) - local_external mode and openai_compatible LLM provider - OpenRouter support with auto base URL - Auto-upgrade of hindsight-client to >=0.4.22 on session start - Comprehensive debug logging across all operations - 46 unit tests - Updated README and website docs --- plugins/memory/hindsight/README.md | 74 ++- plugins/memory/hindsight/__init__.py | 449 +++++++++++-- plugins/memory/hindsight/plugin.yaml | 6 +- .../plugins/memory/test_hindsight_provider.py | 598 ++++++++++++++++++ .../user-guide/features/memory-providers.md | 18 +- 5 files changed, 1072 insertions(+), 73 deletions(-) create mode 100644 tests/plugins/memory/test_hindsight_provider.py diff --git a/plugins/memory/hindsight/README.md b/plugins/memory/hindsight/README.md index 3a1df59e4d2..024a9930312 100644 --- a/plugins/memory/hindsight/README.md +++ b/plugins/memory/hindsight/README.md @@ -1,11 +1,12 @@ # Hindsight Memory Provider -Long-term memory with knowledge graph, entity resolution, and multi-strategy retrieval. Supports cloud and local (embedded) modes. +Long-term memory with knowledge graph, entity resolution, and multi-strategy retrieval. Supports cloud, local embedded, and local external modes. ## Requirements - **Cloud:** API key from [ui.hindsight.vectorize.io](https://ui.hindsight.vectorize.io) -- **Local:** API key for a supported LLM provider (OpenAI, Anthropic, Gemini, Groq, MiniMax, or Ollama). Embeddings and reranking run locally — no additional API keys needed. +- **Local Embedded:** API key for a supported LLM provider (OpenAI, Anthropic, Gemini, Groq, OpenRouter, MiniMax, Ollama, or any OpenAI-compatible endpoint). Embeddings and reranking run locally — no additional API keys needed. +- **Local External:** A running Hindsight instance (Docker or self-hosted) reachable over HTTP. ## Setup @@ -21,17 +22,28 @@ hermes config set memory.provider hindsight echo "HINDSIGHT_API_KEY=your-key" >> ~/.hermes/.env ``` -### Cloud Mode +### Cloud Connects to the Hindsight Cloud API. Requires an API key from [ui.hindsight.vectorize.io](https://ui.hindsight.vectorize.io). -### Local Mode +### Local Embedded -Runs an embedded Hindsight server with built-in PostgreSQL. Requires an LLM API key (e.g. Groq, OpenAI, Anthropic) for memory extraction and synthesis. The daemon starts automatically in the background on first use and stops after 5 minutes of inactivity. +Hermes spins up a local Hindsight daemon with built-in PostgreSQL. Requires an LLM API key for memory extraction and synthesis. The daemon starts automatically in the background on first use and stops after 5 minutes of inactivity. + +Supports any OpenAI-compatible LLM endpoint (llama.cpp, vLLM, LM Studio, etc.) — pick `openai_compatible` as the provider and enter the base URL. Daemon startup logs: `~/.hermes/logs/hindsight-embed.log` Daemon runtime logs: `~/.hindsight/profiles/.log` +To open the Hindsight web UI (local embedded mode only): +```bash +hindsight-embed -p hermes ui start +``` + +### Local External + +Points the plugin at an existing Hindsight instance you're already running (Docker, self-hosted, etc.). No daemon management — just a URL and an optional API key. + ## Config Config file: `~/.hermes/hindsight/config.json` @@ -40,40 +52,58 @@ Config file: `~/.hermes/hindsight/config.json` | Key | Default | Description | |-----|---------|-------------| -| `mode` | `cloud` | `cloud` or `local` | -| `api_url` | `https://api.hindsight.vectorize.io` | API URL (cloud mode) | -| `api_url` | `http://localhost:8888` | API URL (local mode, unused — daemon manages its own port) | +| `mode` | `cloud` | `cloud`, `local_embedded`, or `local_external` | +| `api_url` | `https://api.hindsight.vectorize.io` | API URL (cloud and local_external modes) | -### Memory +### Memory Bank | Key | Default | Description | |-----|---------|-------------| | `bank_id` | `hermes` | Memory bank name | -| `budget` | `mid` | Recall thoroughness: `low` / `mid` / `high` | +| `bank_mission` | — | Reflect mission (identity/framing for reflect reasoning). Applied via Banks API. | +| `bank_retain_mission` | — | Retain mission (steers what gets extracted). Applied via Banks API. | + +### Recall + +| Key | Default | Description | +|-----|---------|-------------| +| `recall_budget` | `mid` | Recall thoroughness: `low` / `mid` / `high` | +| `recall_prefetch_method` | `recall` | Auto-recall method: `recall` (raw facts) or `reflect` (LLM synthesis) | +| `recall_max_tokens` | `4096` | Maximum tokens for recall results | +| `recall_max_input_chars` | `800` | Maximum input query length for auto-recall | +| `recall_prompt_preamble` | — | Custom preamble for recalled memories in context | +| `recall_tags` | — | Tags to filter when searching memories | +| `recall_tags_match` | `any` | Tag matching mode: `any` / `all` / `any_strict` / `all_strict` | +| `auto_recall` | `true` | Automatically recall memories before each turn | + +### Retain + +| Key | Default | Description | +|-----|---------|-------------| +| `auto_retain` | `true` | Automatically retain conversation turns | +| `retain_async` | `true` | Process retain asynchronously on the Hindsight server | +| `retain_every_n_turns` | `1` | Retain every N turns (1 = every turn) | +| `retain_context` | `conversation between Hermes Agent and the User` | Context label for retained memories | +| `tags` | — | Tags applied when storing memories | ### Integration | Key | Default | Description | |-----|---------|-------------| | `memory_mode` | `hybrid` | How memories are integrated into the agent | -| `prefetch_method` | `recall` | Method for automatic context injection | **memory_mode:** - `hybrid` — automatic context injection + tools available to the LLM - `context` — automatic injection only, no tools exposed - `tools` — tools only, no automatic injection -**prefetch_method:** -- `recall` — injects raw memory facts (fast) -- `reflect` — injects LLM-synthesized summary (slower, more coherent) - -### Local Mode LLM +### Local Embedded LLM | Key | Default | Description | |-----|---------|-------------| -| `llm_provider` | `openai` | LLM provider: `openai`, `anthropic`, `gemini`, `groq`, `minimax`, `ollama` | -| `llm_model` | per-provider | Model name (e.g. `gpt-4o-mini`, `openai/gpt-oss-120b`) | -| `llm_base_url` | — | LLM Base URL override (e.g. `https://openrouter.ai/api/v1`) | +| `llm_provider` | `openai` | `openai`, `anthropic`, `gemini`, `groq`, `openrouter`, `minimax`, `ollama`, `lmstudio`, `openai_compatible` | +| `llm_model` | per-provider | Model name (e.g. `gpt-4o-mini`, `qwen/qwen3.5-9b`) | +| `llm_base_url` | — | Endpoint URL for `openai_compatible` (e.g. `http://192.168.1.10:8080/v1`) | The LLM API key is stored in `~/.hermes/.env` as `HINDSIGHT_LLM_API_KEY`. @@ -97,4 +127,8 @@ Available in `hybrid` and `tools` memory modes: | `HINDSIGHT_API_URL` | Override API endpoint | | `HINDSIGHT_BANK_ID` | Override bank name | | `HINDSIGHT_BUDGET` | Override recall budget | -| `HINDSIGHT_MODE` | Override mode (`cloud` / `local`) | +| `HINDSIGHT_MODE` | Override mode (`cloud`, `local_embedded`, `local_external`) | + +## Client Version + +Requires `hindsight-client >= 0.4.22`. The plugin auto-upgrades on session start if an older version is detected. diff --git a/plugins/memory/hindsight/__init__.py b/plugins/memory/hindsight/__init__.py index c87497745e8..c39679b73c8 100644 --- a/plugins/memory/hindsight/__init__.py +++ b/plugins/memory/hindsight/__init__.py @@ -28,21 +28,25 @@ from hermes_constants import get_hermes_home from typing import Any, Dict, List from agent.memory_provider import MemoryProvider +from hermes_constants import get_hermes_home from tools.registry import tool_error logger = logging.getLogger(__name__) _DEFAULT_API_URL = "https://api.hindsight.vectorize.io" _DEFAULT_LOCAL_URL = "http://localhost:8888" +_MIN_CLIENT_VERSION = "0.4.22" _VALID_BUDGETS = {"low", "mid", "high"} _PROVIDER_DEFAULT_MODELS = { "openai": "gpt-4o-mini", "anthropic": "claude-haiku-4-5", "gemini": "gemini-2.5-flash", "groq": "openai/gpt-oss-120b", + "openrouter": "qwen/qwen3.5-9b", "minimax": "MiniMax-M2.7", "ollama": "gemma3:12b", "lmstudio": "local-model", + "openai_compatible": "your-model-name", } @@ -188,6 +192,7 @@ class HindsightMemoryProvider(MemoryProvider): self._bank_id = "hermes" self._budget = "mid" self._mode = "cloud" + self._llm_base_url = "" self._memory_mode = "hybrid" # "context", "tools", or "hybrid" self._prefetch_method = "recall" # "recall" or "reflect" self._client = None @@ -195,6 +200,31 @@ class HindsightMemoryProvider(MemoryProvider): self._prefetch_lock = threading.Lock() self._prefetch_thread = None self._sync_thread = None + self._session_id = "" + + # Tags + self._tags: list[str] | None = None + self._recall_tags: list[str] | None = None + self._recall_tags_match = "any" + + # Retain controls + self._auto_retain = True + self._retain_every_n_turns = 1 + self._retain_context = "conversation between Hermes Agent and the User" + self._turn_counter = 0 + self._session_turns: list[str] = [] # accumulates ALL turns for the session + + # Recall controls + self._auto_recall = True + self._recall_max_tokens = 4096 + self._recall_types: list[str] | None = None + self._recall_prompt_preamble = "" + self._recall_max_input_chars = 800 + + # Bank + self._bank_mission = "" + self._bank_retain_mission: str | None = None + self._retain_async = True @property def name(self) -> str: @@ -204,7 +234,7 @@ class HindsightMemoryProvider(MemoryProvider): try: cfg = _load_config() mode = cfg.get("mode", "cloud") - if mode == "local": + if mode in ("local", "local_embedded", "local_external"): return True has_key = bool(cfg.get("apiKey") or os.environ.get("HINDSIGHT_API_KEY", "")) has_url = bool(cfg.get("api_url") or os.environ.get("HINDSIGHT_API_URL", "")) @@ -228,73 +258,306 @@ class HindsightMemoryProvider(MemoryProvider): existing.update(values) config_path.write_text(json.dumps(existing, indent=2)) + def post_setup(self, hermes_home: str, config: dict) -> None: + """Custom setup wizard — installs only the deps needed for the selected mode.""" + import getpass + import subprocess + import shutil + import sys + from pathlib import Path + + from hermes_cli.config import save_config + + from hermes_cli.memory_setup import _curses_select + + print("\n Configuring Hindsight memory:\n") + + # Step 1: Mode selection + mode_items = [ + ("Cloud", "Hindsight Cloud API (lightweight, just needs an API key)"), + ("Local Embedded", "Run Hindsight locally (downloads ~200MB, needs LLM key)"), + ("Local External", "Connect to an existing Hindsight instance"), + ] + mode_idx = _curses_select(" Select mode", mode_items, default=0) + mode = ["cloud", "local_embedded", "local_external"][mode_idx] + + provider_config: dict = {"mode": mode} + env_writes: dict = {} + + # Step 2: Install/upgrade deps for selected mode + _MIN_CLIENT_VERSION = "0.4.22" + cloud_dep = f"hindsight-client>={_MIN_CLIENT_VERSION}" + local_dep = "hindsight-all" + if mode == "local_embedded": + deps_to_install = [local_dep] + elif mode == "local_external": + deps_to_install = [cloud_dep] + else: + deps_to_install = [cloud_dep] + + print(f"\n Checking dependencies...") + uv_path = shutil.which("uv") + if not uv_path: + print(" ⚠ uv not found — install it: curl -LsSf https://astral.sh/uv/install.sh | sh") + print(f" Then run manually: uv pip install --python {sys.executable} {' '.join(deps_to_install)}") + else: + try: + subprocess.run( + [uv_path, "pip", "install", "--python", sys.executable, "--quiet", "--upgrade"] + deps_to_install, + check=True, timeout=120, capture_output=True, + ) + print(f" ✓ Dependencies up to date") + except Exception as e: + print(f" ⚠ Install failed: {e}") + print(f" Run manually: uv pip install --python {sys.executable} {' '.join(deps_to_install)}") + + # Step 3: Mode-specific config + if mode == "cloud": + print(f"\n Get your API key at https://ui.hindsight.vectorize.io\n") + existing_key = os.environ.get("HINDSIGHT_API_KEY", "") + if existing_key: + masked = f"...{existing_key[-4:]}" if len(existing_key) > 4 else "set" + sys.stdout.write(f" API key (current: {masked}, blank to keep): ") + sys.stdout.flush() + api_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip() + else: + sys.stdout.write(" API key: ") + sys.stdout.flush() + api_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip() + if api_key: + env_writes["HINDSIGHT_API_KEY"] = api_key + + val = input(f" API URL [{_DEFAULT_API_URL}]: ").strip() + if val: + provider_config["api_url"] = val + + elif mode == "local_external": + val = input(f" Hindsight API URL [{_DEFAULT_LOCAL_URL}]: ").strip() + provider_config["api_url"] = val or _DEFAULT_LOCAL_URL + + sys.stdout.write(" API key (optional, blank to skip): ") + sys.stdout.flush() + api_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip() + if api_key: + env_writes["HINDSIGHT_API_KEY"] = api_key + + else: # local_embedded + providers_list = list(_PROVIDER_DEFAULT_MODELS.keys()) + llm_items = [ + (p, f"default model: {_PROVIDER_DEFAULT_MODELS[p]}") + for p in providers_list + ] + llm_idx = _curses_select(" Select LLM provider", llm_items, default=0) + llm_provider = providers_list[llm_idx] + + provider_config["llm_provider"] = llm_provider + + if llm_provider == "openai_compatible": + val = input(" LLM endpoint URL (e.g. http://192.168.1.10:8080/v1): ").strip() + if val: + provider_config["llm_base_url"] = val + elif llm_provider == "openrouter": + provider_config["llm_base_url"] = "https://openrouter.ai/api/v1" + + default_model = _PROVIDER_DEFAULT_MODELS.get(llm_provider, "gpt-4o-mini") + val = input(f" LLM model [{default_model}]: ").strip() + provider_config["llm_model"] = val or default_model + + sys.stdout.write(" LLM API key: ") + sys.stdout.flush() + llm_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip() + if llm_key: + env_writes["HINDSIGHT_LLM_API_KEY"] = llm_key + + # Step 4: Save everything + provider_config["bank_id"] = "hermes" + provider_config["recall_budget"] = "mid" + bank_id = "hermes" + config["memory"]["provider"] = "hindsight" + save_config(config) + + self.save_config(provider_config, hermes_home) + + if env_writes: + env_path = Path(hermes_home) / ".env" + env_path.parent.mkdir(parents=True, exist_ok=True) + existing_lines = [] + if env_path.exists(): + existing_lines = env_path.read_text().splitlines() + updated_keys = set() + new_lines = [] + for line in existing_lines: + key_match = line.split("=", 1)[0].strip() if "=" in line and not line.startswith("#") else None + if key_match and key_match in env_writes: + new_lines.append(f"{key_match}={env_writes[key_match]}") + updated_keys.add(key_match) + else: + new_lines.append(line) + for k, v in env_writes.items(): + if k not in updated_keys: + new_lines.append(f"{k}={v}") + env_path.write_text("\n".join(new_lines) + "\n") + + print(f"\n ✓ Hindsight memory configured ({mode} mode)") + if env_writes: + print(f" API keys saved to .env") + print(f"\n Start a new session to activate.\n") + def get_config_schema(self): return [ - {"key": "mode", "description": "Cloud API or local embedded mode", "default": "cloud", "choices": ["cloud", "local"]}, - {"key": "api_url", "description": "Hindsight API URL", "default": _DEFAULT_API_URL, "when": {"mode": "cloud"}}, + {"key": "mode", "description": "Connection mode", "default": "cloud", "choices": ["cloud", "local_embedded", "local_external"]}, + # Cloud mode + {"key": "api_url", "description": "Hindsight Cloud API URL", "default": _DEFAULT_API_URL, "when": {"mode": "cloud"}}, {"key": "api_key", "description": "Hindsight Cloud API key", "secret": True, "env_var": "HINDSIGHT_API_KEY", "url": "https://ui.hindsight.vectorize.io", "when": {"mode": "cloud"}}, - {"key": "llm_provider", "description": "LLM provider for local mode", "default": "openai", "choices": ["openai", "anthropic", "gemini", "groq", "minimax", "ollama"], "when": {"mode": "local"}}, - {"key": "llm_api_key", "description": "LLM API key for local Hindsight", "secret": True, "env_var": "HINDSIGHT_LLM_API_KEY", "when": {"mode": "local"}}, - {"key": "llm_base_url", "description": "LLM Base URL (e.g. for OpenRouter)", "default": "", "env_var": "HINDSIGHT_API_LLM_BASE_URL", "when": {"mode": "local"}}, - {"key": "llm_model", "description": "LLM model for local mode", "default": "gpt-4o-mini", "default_from": {"field": "llm_provider", "map": _PROVIDER_DEFAULT_MODELS}, "when": {"mode": "local"}}, + # Local external mode + {"key": "api_url", "description": "Hindsight API URL", "default": _DEFAULT_LOCAL_URL, "when": {"mode": "local_external"}}, + {"key": "api_key", "description": "API key (optional)", "secret": True, "env_var": "HINDSIGHT_API_KEY", "when": {"mode": "local_external"}}, + # Local embedded mode + {"key": "llm_provider", "description": "LLM provider", "default": "openai", "choices": ["openai", "anthropic", "gemini", "groq", "openrouter", "minimax", "ollama", "lmstudio", "openai_compatible"], "when": {"mode": "local_embedded"}}, + {"key": "llm_base_url", "description": "Endpoint URL (e.g. http://192.168.1.10:8080/v1)", "default": "", "when": {"mode": "local_embedded", "llm_provider": "openai_compatible"}}, + {"key": "llm_api_key", "description": "LLM API key (optional for openai_compatible)", "secret": True, "env_var": "HINDSIGHT_LLM_API_KEY", "when": {"mode": "local_embedded"}}, + {"key": "llm_model", "description": "LLM model", "default": "gpt-4o-mini", "default_from": {"field": "llm_provider", "map": _PROVIDER_DEFAULT_MODELS}, "when": {"mode": "local_embedded"}}, {"key": "bank_id", "description": "Memory bank name", "default": "hermes"}, - {"key": "budget", "description": "Recall thoroughness", "default": "mid", "choices": ["low", "mid", "high"]}, + {"key": "bank_mission", "description": "Mission/purpose description for the memory bank"}, + {"key": "bank_retain_mission", "description": "Custom extraction prompt for memory retention"}, + {"key": "recall_budget", "description": "Recall thoroughness", "default": "mid", "choices": ["low", "mid", "high"]}, {"key": "memory_mode", "description": "Memory integration mode", "default": "hybrid", "choices": ["hybrid", "context", "tools"]}, - {"key": "prefetch_method", "description": "Auto-recall method", "default": "recall", "choices": ["recall", "reflect"]}, + {"key": "recall_prefetch_method", "description": "Auto-recall method", "default": "recall", "choices": ["recall", "reflect"]}, + {"key": "tags", "description": "Tags applied when storing memories (comma-separated)", "default": ""}, + {"key": "recall_tags", "description": "Tags to filter when searching memories (comma-separated)", "default": ""}, + {"key": "recall_tags_match", "description": "Tag matching mode for recall", "default": "any", "choices": ["any", "all", "any_strict", "all_strict"]}, + {"key": "auto_recall", "description": "Automatically recall memories before each turn", "default": True}, + {"key": "auto_retain", "description": "Automatically retain conversation turns", "default": True}, + {"key": "retain_every_n_turns", "description": "Retain every N turns (1 = every turn)", "default": 1}, + {"key": "retain_async","description": "Process retain asynchronously on the Hindsight server", "default": True}, + {"key": "retain_context", "description": "Context label for retained memories", "default": "conversation between Hermes Agent and the User"}, + {"key": "recall_max_tokens", "description": "Maximum tokens for recall results", "default": 4096}, + {"key": "recall_max_input_chars", "description": "Maximum input query length for auto-recall", "default": 800}, + {"key": "recall_prompt_preamble", "description": "Custom preamble for recalled memories in context"}, ] def _get_client(self): """Return the cached Hindsight client (created once, reused).""" if self._client is None: - if self._mode == "local": + if self._mode == "local_embedded": from hindsight import HindsightEmbedded - # Disable __del__ on the class to prevent "attached to a - # different loop" errors during GC — we handle cleanup in - # shutdown() instead. HindsightEmbedded.__del__ = lambda self: None + llm_provider = self._config.get("llm_provider", "") + if llm_provider in ("openai_compatible", "openrouter"): + llm_provider = "openai" + logger.debug("Creating HindsightEmbedded client (profile=%s, provider=%s)", + self._config.get("profile", "hermes"), llm_provider) kwargs = dict( profile=self._config.get("profile", "hermes"), - llm_provider=self._config.get("llm_provider", ""), - llm_api_key=self._config.get("llm_api_key") or os.environ.get("HINDSIGHT_LLM_API_KEY", ""), + llm_provider=llm_provider, + llm_api_key=self._config.get("llmApiKey") or self._config.get("llm_api_key") or os.environ.get("HINDSIGHT_LLM_API_KEY", ""), llm_model=self._config.get("llm_model", ""), ) - base_url = self._config.get("llm_base_url") or os.environ.get("HINDSIGHT_API_LLM_BASE_URL", "") - if base_url: - kwargs["llm_base_url"] = base_url + if self._llm_base_url: + kwargs["llm_base_url"] = self._llm_base_url self._client = HindsightEmbedded(**kwargs) else: from hindsight_client import Hindsight kwargs = {"base_url": self._api_url, "timeout": 30.0} if self._api_key: kwargs["api_key"] = self._api_key + logger.debug("Creating Hindsight cloud client (url=%s, has_key=%s)", + self._api_url, bool(self._api_key)) self._client = Hindsight(**kwargs) return self._client def initialize(self, session_id: str, **kwargs) -> None: + self._session_id = session_id + + # Check client version and auto-upgrade if needed + try: + from importlib.metadata import version as pkg_version + from packaging.version import Version + installed = pkg_version("hindsight-client") + if Version(installed) < Version(_MIN_CLIENT_VERSION): + logger.warning("hindsight-client %s is outdated (need >=%s), attempting upgrade...", + installed, _MIN_CLIENT_VERSION) + import shutil, subprocess, sys + uv_path = shutil.which("uv") + if uv_path: + try: + subprocess.run( + [uv_path, "pip", "install", "--python", sys.executable, + "--quiet", "--upgrade", f"hindsight-client>={_MIN_CLIENT_VERSION}"], + check=True, timeout=120, capture_output=True, + ) + logger.info("hindsight-client upgraded to >=%s", _MIN_CLIENT_VERSION) + except Exception as e: + logger.warning("Auto-upgrade failed: %s. Run: uv pip install 'hindsight-client>=%s'", + e, _MIN_CLIENT_VERSION) + else: + logger.warning("uv not found. Run: pip install 'hindsight-client>=%s'", _MIN_CLIENT_VERSION) + except Exception: + pass # packaging not available or other issue — proceed anyway + self._config = _load_config() self._mode = self._config.get("mode", "cloud") - self._api_key = self._config.get("apiKey") or os.environ.get("HINDSIGHT_API_KEY", "") - default_url = _DEFAULT_LOCAL_URL if self._mode == "local" else _DEFAULT_API_URL + # "local" is a legacy alias for "local_embedded" + if self._mode == "local": + self._mode = "local_embedded" + self._api_key = self._config.get("apiKey") or self._config.get("api_key") or os.environ.get("HINDSIGHT_API_KEY", "") + default_url = _DEFAULT_LOCAL_URL if self._mode in ("local_embedded", "local_external") else _DEFAULT_API_URL self._api_url = self._config.get("api_url") or os.environ.get("HINDSIGHT_API_URL", default_url) + self._llm_base_url = self._config.get("llm_base_url", "") banks = self._config.get("banks", {}).get("hermes", {}) self._bank_id = self._config.get("bank_id") or banks.get("bankId", "hermes") - budget = self._config.get("budget") or banks.get("budget", "mid") + budget = self._config.get("recall_budget") or self._config.get("budget") or banks.get("budget", "mid") self._budget = budget if budget in _VALID_BUDGETS else "mid" memory_mode = self._config.get("memory_mode", "hybrid") self._memory_mode = memory_mode if memory_mode in ("context", "tools", "hybrid") else "hybrid" - prefetch_method = self._config.get("prefetch_method", "recall") + prefetch_method = self._config.get("recall_prefetch_method", "recall") self._prefetch_method = prefetch_method if prefetch_method in ("recall", "reflect") else "recall" - logger.info("Hindsight initialized: mode=%s, api_url=%s, bank=%s, budget=%s, memory_mode=%s, prefetch_method=%s", - self._mode, self._api_url, self._bank_id, self._budget, self._memory_mode, self._prefetch_method) + # Bank options + self._bank_mission = self._config.get("bank_mission", "") + self._bank_retain_mission = self._config.get("bank_retain_mission") or None + + # Tags + self._tags = self._config.get("tags") or None + self._recall_tags = self._config.get("recall_tags") or None + self._recall_tags_match = self._config.get("recall_tags_match", "any") + + # Retain controls + self._auto_retain = self._config.get("auto_retain", True) + self._retain_every_n_turns = max(1, int(self._config.get("retain_every_n_turns", 1))) + self._retain_context = self._config.get("retain_context", "conversation between Hermes Agent and the User") + + # Recall controls + self._auto_recall = self._config.get("auto_recall", True) + self._recall_max_tokens = int(self._config.get("recall_max_tokens", 4096)) + self._recall_types = self._config.get("recall_types") or None + self._recall_prompt_preamble = self._config.get("recall_prompt_preamble", "") + self._recall_max_input_chars = int(self._config.get("recall_max_input_chars", 800)) + self._retain_async = self._config.get("retain_async", True) + + _client_version = "unknown" + try: + from importlib.metadata import version as pkg_version + _client_version = pkg_version("hindsight-client") + except Exception: + pass + logger.info("Hindsight initialized: mode=%s, api_url=%s, bank=%s, budget=%s, memory_mode=%s, prefetch_method=%s, client=%s", + self._mode, self._api_url, self._bank_id, self._budget, self._memory_mode, self._prefetch_method, _client_version) + logger.debug("Hindsight config: auto_retain=%s, auto_recall=%s, retain_every_n=%d, " + "retain_async=%s, retain_context=%s, " + "recall_max_tokens=%d, recall_max_input_chars=%d, tags=%s, recall_tags=%s", + self._auto_retain, self._auto_recall, self._retain_every_n_turns, + self._retain_async, self._retain_context, + self._recall_max_tokens, self._recall_max_input_chars, + self._tags, self._recall_tags) # For local mode, start the embedded daemon in the background so it # doesn't block the chat. Redirect stdout/stderr to a log file to # prevent rich startup output from spamming the terminal. - if self._mode == "local": + if self._mode == "local_embedded": def _start_daemon(): import traceback log_dir = get_hermes_home() / "logs" @@ -320,6 +583,8 @@ class HindsightMemoryProvider(MemoryProvider): current_provider = self._config.get("llm_provider", "") current_model = self._config.get("llm_model", "") current_base_url = self._config.get("llm_base_url") or os.environ.get("HINDSIGHT_API_LLM_BASE_URL", "") + # Map openai_compatible/openrouter → openai for the daemon (OpenAI wire format) + daemon_provider = "openai" if current_provider in ("openai_compatible", "openrouter") else current_provider # Read saved profile config saved = {} @@ -330,7 +595,7 @@ class HindsightMemoryProvider(MemoryProvider): saved[k.strip()] = v.strip() config_changed = ( - saved.get("HINDSIGHT_API_LLM_PROVIDER") != current_provider or + saved.get("HINDSIGHT_API_LLM_PROVIDER") != daemon_provider or saved.get("HINDSIGHT_API_LLM_MODEL") != current_model or saved.get("HINDSIGHT_API_LLM_API_KEY") != current_key or saved.get("HINDSIGHT_API_LLM_BASE_URL", "") != current_base_url @@ -340,7 +605,7 @@ class HindsightMemoryProvider(MemoryProvider): # Write updated profile .env profile_env.parent.mkdir(parents=True, exist_ok=True) env_lines = ( - f"HINDSIGHT_API_LLM_PROVIDER={current_provider}\n" + f"HINDSIGHT_API_LLM_PROVIDER={daemon_provider}\n" f"HINDSIGHT_API_LLM_API_KEY={current_key}\n" f"HINDSIGHT_API_LLM_MODEL={current_model}\n" f"HINDSIGHT_API_LOG_LEVEL=info\n" @@ -388,47 +653,118 @@ class HindsightMemoryProvider(MemoryProvider): def prefetch(self, query: str, *, session_id: str = "") -> str: if self._prefetch_thread and self._prefetch_thread.is_alive(): + logger.debug("Prefetch: waiting for background thread to complete") self._prefetch_thread.join(timeout=3.0) with self._prefetch_lock: result = self._prefetch_result self._prefetch_result = "" if not result: + logger.debug("Prefetch: no results available") return "" - return f"## Hindsight Memory\n{result}" + logger.debug("Prefetch: returning %d chars of context", len(result)) + header = self._recall_prompt_preamble or ( + "# Hindsight Memory (persistent cross-session context)\n" + "Use this to answer questions about the user and prior sessions. " + "Do not call tools to look up information that is already present here." + ) + return f"{header}\n\n{result}" def queue_prefetch(self, query: str, *, session_id: str = "") -> None: if self._memory_mode == "tools": + logger.debug("Prefetch: skipped (tools-only mode)") return + if not self._auto_recall: + logger.debug("Prefetch: skipped (auto_recall disabled)") + return + # Truncate query to max chars + if self._recall_max_input_chars and len(query) > self._recall_max_input_chars: + query = query[:self._recall_max_input_chars] + def _run(): try: client = self._get_client() if self._prefetch_method == "reflect": + logger.debug("Prefetch: calling reflect (bank=%s, query_len=%d)", self._bank_id, len(query)) resp = _run_sync(client.areflect(bank_id=self._bank_id, query=query, budget=self._budget)) text = resp.text or "" else: - resp = _run_sync(client.arecall(bank_id=self._bank_id, query=query, budget=self._budget)) - text = "\n".join(r.text for r in resp.results if r.text) if resp.results else "" + recall_kwargs: dict = { + "bank_id": self._bank_id, "query": query, + "budget": self._budget, "max_tokens": self._recall_max_tokens, + } + if self._recall_tags: + recall_kwargs["tags"] = self._recall_tags + recall_kwargs["tags_match"] = self._recall_tags_match + if self._recall_types: + recall_kwargs["types"] = self._recall_types + logger.debug("Prefetch: calling recall (bank=%s, query_len=%d, budget=%s)", + self._bank_id, len(query), self._budget) + resp = _run_sync(client.arecall(**recall_kwargs)) + num_results = len(resp.results) if resp.results else 0 + logger.debug("Prefetch: recall returned %d results", num_results) + text = "\n".join(f"- {r.text}" for r in resp.results if r.text) if resp.results else "" if text: with self._prefetch_lock: self._prefetch_result = text except Exception as e: - logger.debug("Hindsight prefetch failed: %s", e) + logger.debug("Hindsight prefetch failed: %s", e, exc_info=True) self._prefetch_thread = threading.Thread(target=_run, daemon=True, name="hindsight-prefetch") self._prefetch_thread.start() def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None: - """Retain conversation turn in background (non-blocking).""" - combined = f"User: {user_content}\nAssistant: {assistant_content}" + """Retain conversation turn in background (non-blocking). + + Respects retain_every_n_turns for batching. + """ + if not self._auto_retain: + logger.debug("sync_turn: skipped (auto_retain disabled)") + return + + from datetime import datetime, timezone + now = datetime.now(timezone.utc).isoformat() + + messages = [ + {"role": "user", "content": user_content, "timestamp": now}, + {"role": "assistant", "content": assistant_content, "timestamp": now}, + ] + + turn = json.dumps(messages) + self._session_turns.append(turn) + self._turn_counter += 1 + + # Only retain every N turns + if self._turn_counter % self._retain_every_n_turns != 0: + logger.debug("sync_turn: buffered turn %d (will retain at turn %d)", + self._turn_counter, self._turn_counter + (self._retain_every_n_turns - self._turn_counter % self._retain_every_n_turns)) + return + + logger.debug("sync_turn: retaining %d turns, total session content %d chars", + len(self._session_turns), sum(len(t) for t in self._session_turns)) + # Send the ENTIRE session as a single JSON array (document_id deduplicates). + # Each element in _session_turns is a JSON string of that turn's messages. + content = "[" + ",".join(self._session_turns) + "]" def _sync(): try: client = self._get_client() - _run_sync(client.aretain( - bank_id=self._bank_id, content=combined, context="conversation" + item: dict = { + "content": content, + "context": self._retain_context, + } + if self._tags: + item["tags"] = self._tags + logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d", + self._bank_id, self._session_id, self._retain_async, len(content), len(self._session_turns)) + _run_sync(client.aretain_batch( + bank_id=self._bank_id, + items=[item], + document_id=self._session_id, + retain_async=self._retain_async, )) + logger.debug("Hindsight retain succeeded") except Exception as e: - logger.warning("Hindsight sync failed: %s", e) + logger.warning("Hindsight sync failed: %s", e, exc_info=True) if self._sync_thread and self._sync_thread.is_alive(): self._sync_thread.join(timeout=5.0) @@ -453,12 +789,18 @@ class HindsightMemoryProvider(MemoryProvider): return tool_error("Missing required parameter: content") context = args.get("context") try: - _run_sync(client.aretain( - bank_id=self._bank_id, content=content, context=context - )) + retain_kwargs: dict = { + "bank_id": self._bank_id, "content": content, "context": context, + } + if self._tags: + retain_kwargs["tags"] = self._tags + logger.debug("Tool hindsight_retain: bank=%s, content_len=%d, context=%s", + self._bank_id, len(content), context) + _run_sync(client.aretain(**retain_kwargs)) + logger.debug("Tool hindsight_retain: success") return json.dumps({"result": "Memory stored successfully."}) except Exception as e: - logger.warning("hindsight_retain failed: %s", e) + logger.warning("hindsight_retain failed: %s", e, exc_info=True) return tool_error(f"Failed to store memory: {e}") elif tool_name == "hindsight_recall": @@ -466,15 +808,26 @@ class HindsightMemoryProvider(MemoryProvider): if not query: return tool_error("Missing required parameter: query") try: - resp = _run_sync(client.arecall( - bank_id=self._bank_id, query=query, budget=self._budget - )) + recall_kwargs: dict = { + "bank_id": self._bank_id, "query": query, "budget": self._budget, + "max_tokens": self._recall_max_tokens, + } + if self._recall_tags: + recall_kwargs["tags"] = self._recall_tags + recall_kwargs["tags_match"] = self._recall_tags_match + if self._recall_types: + recall_kwargs["types"] = self._recall_types + logger.debug("Tool hindsight_recall: bank=%s, query_len=%d, budget=%s", + self._bank_id, len(query), self._budget) + resp = _run_sync(client.arecall(**recall_kwargs)) + num_results = len(resp.results) if resp.results else 0 + logger.debug("Tool hindsight_recall: %d results", num_results) if not resp.results: return json.dumps({"result": "No relevant memories found."}) lines = [f"{i}. {r.text}" for i, r in enumerate(resp.results, 1)] return json.dumps({"result": "\n".join(lines)}) except Exception as e: - logger.warning("hindsight_recall failed: %s", e) + logger.warning("hindsight_recall failed: %s", e, exc_info=True) return tool_error(f"Failed to search memory: {e}") elif tool_name == "hindsight_reflect": @@ -482,24 +835,28 @@ class HindsightMemoryProvider(MemoryProvider): if not query: return tool_error("Missing required parameter: query") try: + logger.debug("Tool hindsight_reflect: bank=%s, query_len=%d, budget=%s", + self._bank_id, len(query), self._budget) resp = _run_sync(client.areflect( bank_id=self._bank_id, query=query, budget=self._budget )) + logger.debug("Tool hindsight_reflect: response_len=%d", len(resp.text or "")) return json.dumps({"result": resp.text or "No relevant memories found."}) except Exception as e: - logger.warning("hindsight_reflect failed: %s", e) + logger.warning("hindsight_reflect failed: %s", e, exc_info=True) return tool_error(f"Failed to reflect: {e}") return tool_error(f"Unknown tool: {tool_name}") def shutdown(self) -> None: + logger.debug("Hindsight shutdown: waiting for background threads") global _loop, _loop_thread for t in (self._prefetch_thread, self._sync_thread): if t and t.is_alive(): t.join(timeout=5.0) if self._client is not None: try: - if self._mode == "local": + if self._mode == "local_embedded": # Use the public close() API. The RuntimeError from # aiohttp's "attached to a different loop" is expected # and harmless — the daemon keeps running independently. diff --git a/plugins/memory/hindsight/plugin.yaml b/plugins/memory/hindsight/plugin.yaml index 79851899202..b12c09142bb 100644 --- a/plugins/memory/hindsight/plugin.yaml +++ b/plugins/memory/hindsight/plugin.yaml @@ -2,9 +2,7 @@ name: hindsight version: 1.0.0 description: "Hindsight — long-term memory with knowledge graph, entity resolution, and multi-strategy retrieval." pip_dependencies: - - hindsight-client - - hindsight-all -requires_env: - - HINDSIGHT_API_KEY + - "hindsight-client>=0.4.22" +requires_env: [] hooks: - on_session_end diff --git a/tests/plugins/memory/test_hindsight_provider.py b/tests/plugins/memory/test_hindsight_provider.py new file mode 100644 index 00000000000..5548a29ad41 --- /dev/null +++ b/tests/plugins/memory/test_hindsight_provider.py @@ -0,0 +1,598 @@ +"""Tests for the Hindsight memory provider plugin. + +Tests cover config loading, tool handlers (tags, max_tokens, types), +prefetch (auto_recall, preamble, query truncation), sync_turn (auto_retain, +turn counting, tags), and schema completeness. +""" + +import json +import threading +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from plugins.memory.hindsight import ( + HindsightMemoryProvider, + RECALL_SCHEMA, + REFLECT_SCHEMA, + RETAIN_SCHEMA, + _load_config, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _clean_env(monkeypatch): + """Ensure no stale env vars leak between tests.""" + for key in ( + "HINDSIGHT_API_KEY", "HINDSIGHT_API_URL", "HINDSIGHT_BANK_ID", + "HINDSIGHT_BUDGET", "HINDSIGHT_MODE", "HINDSIGHT_LLM_API_KEY", + ): + monkeypatch.delenv(key, raising=False) + + +def _make_mock_client(): + """Create a mock Hindsight client with async methods.""" + client = MagicMock() + client.aretain = AsyncMock() + client.arecall = AsyncMock( + return_value=SimpleNamespace( + results=[ + SimpleNamespace(text="Memory 1"), + SimpleNamespace(text="Memory 2"), + ] + ) + ) + client.areflect = AsyncMock( + return_value=SimpleNamespace(text="Synthesized answer") + ) + client.aretain_batch = AsyncMock() + client.aclose = AsyncMock() + return client + + +@pytest.fixture() +def provider(tmp_path, monkeypatch): + """Create an initialized HindsightMemoryProvider with a mock client.""" + config = { + "mode": "cloud", + "apiKey": "test-key", + "api_url": "http://localhost:9999", + "bank_id": "test-bank", + "budget": "mid", + "memory_mode": "hybrid", + } + config_path = tmp_path / "hindsight" / "config.json" + config_path.parent.mkdir(parents=True, exist_ok=True) + config_path.write_text(json.dumps(config)) + + monkeypatch.setattr( + "plugins.memory.hindsight.get_hermes_home", lambda: tmp_path + ) + + p = HindsightMemoryProvider() + p.initialize(session_id="test-session", hermes_home=str(tmp_path), platform="cli") + p._client = _make_mock_client() + return p + + +@pytest.fixture() +def provider_with_config(tmp_path, monkeypatch): + """Create a provider factory that accepts custom config overrides.""" + def _make(**overrides): + config = { + "mode": "cloud", + "apiKey": "test-key", + "api_url": "http://localhost:9999", + "bank_id": "test-bank", + "budget": "mid", + "memory_mode": "hybrid", + } + config.update(overrides) + config_path = tmp_path / "hindsight" / "config.json" + config_path.parent.mkdir(parents=True, exist_ok=True) + config_path.write_text(json.dumps(config)) + + monkeypatch.setattr( + "plugins.memory.hindsight.get_hermes_home", lambda: tmp_path + ) + + p = HindsightMemoryProvider() + p.initialize(session_id="test-session", hermes_home=str(tmp_path), platform="cli") + p._client = _make_mock_client() + return p + return _make + + +# --------------------------------------------------------------------------- +# Schema tests +# --------------------------------------------------------------------------- + + +class TestSchemas: + def test_retain_schema_has_content(self): + assert RETAIN_SCHEMA["name"] == "hindsight_retain" + assert "content" in RETAIN_SCHEMA["parameters"]["properties"] + assert "content" in RETAIN_SCHEMA["parameters"]["required"] + + def test_recall_schema_has_query(self): + assert RECALL_SCHEMA["name"] == "hindsight_recall" + assert "query" in RECALL_SCHEMA["parameters"]["properties"] + assert "query" in RECALL_SCHEMA["parameters"]["required"] + + def test_reflect_schema_has_query(self): + assert REFLECT_SCHEMA["name"] == "hindsight_reflect" + assert "query" in REFLECT_SCHEMA["parameters"]["properties"] + + def test_get_tool_schemas_returns_three(self, provider): + schemas = provider.get_tool_schemas() + assert len(schemas) == 3 + names = {s["name"] for s in schemas} + assert names == {"hindsight_retain", "hindsight_recall", "hindsight_reflect"} + + def test_context_mode_returns_no_tools(self, provider_with_config): + p = provider_with_config(memory_mode="context") + assert p.get_tool_schemas() == [] + + +# --------------------------------------------------------------------------- +# Config tests +# --------------------------------------------------------------------------- + + +class TestConfig: + def test_default_values(self, provider): + assert provider._auto_retain is True + assert provider._auto_recall is True + assert provider._retain_every_n_turns == 1 + assert provider._recall_max_tokens == 4096 + assert provider._recall_max_input_chars == 800 + assert provider._tags is None + assert provider._recall_tags is None + assert provider._bank_mission == "" + assert provider._bank_retain_mission is None + assert provider._retain_context == "conversation between Hermes Agent and the User" + + def test_custom_config_values(self, provider_with_config): + p = provider_with_config( + tags=["tag1", "tag2"], + recall_tags=["recall-tag"], + recall_tags_match="all", + auto_retain=False, + auto_recall=False, + retain_every_n_turns=3, + retain_context="custom-ctx", + bank_retain_mission="Extract key facts", + recall_max_tokens=2048, + recall_types=["world", "experience"], + recall_prompt_preamble="Custom preamble:", + recall_max_input_chars=500, + bank_mission="Test agent mission", + ) + assert p._tags == ["tag1", "tag2"] + assert p._recall_tags == ["recall-tag"] + assert p._recall_tags_match == "all" + assert p._auto_retain is False + assert p._auto_recall is False + assert p._retain_every_n_turns == 3 + assert p._retain_context == "custom-ctx" + assert p._bank_retain_mission == "Extract key facts" + assert p._recall_max_tokens == 2048 + assert p._recall_types == ["world", "experience"] + assert p._recall_prompt_preamble == "Custom preamble:" + assert p._recall_max_input_chars == 500 + assert p._bank_mission == "Test agent mission" + + def test_config_from_env_fallback(self, tmp_path, monkeypatch): + """When no config file exists, falls back to env vars.""" + monkeypatch.setattr( + "plugins.memory.hindsight.get_hermes_home", + lambda: tmp_path / "nonexistent", + ) + monkeypatch.setenv("HINDSIGHT_MODE", "cloud") + monkeypatch.setenv("HINDSIGHT_API_KEY", "env-key") + monkeypatch.setenv("HINDSIGHT_BANK_ID", "env-bank") + monkeypatch.setenv("HINDSIGHT_BUDGET", "high") + + cfg = _load_config() + assert cfg["apiKey"] == "env-key" + assert cfg["banks"]["hermes"]["bankId"] == "env-bank" + assert cfg["banks"]["hermes"]["budget"] == "high" + + +# --------------------------------------------------------------------------- +# Tool handler tests +# --------------------------------------------------------------------------- + + +class TestToolHandlers: + def test_retain_success(self, provider): + result = json.loads(provider.handle_tool_call( + "hindsight_retain", {"content": "user likes dark mode"} + )) + assert result["result"] == "Memory stored successfully." + provider._client.aretain.assert_called_once() + call_kwargs = provider._client.aretain.call_args.kwargs + assert call_kwargs["bank_id"] == "test-bank" + assert call_kwargs["content"] == "user likes dark mode" + + def test_retain_with_tags(self, provider_with_config): + p = provider_with_config(tags=["pref", "ui"]) + p.handle_tool_call("hindsight_retain", {"content": "likes dark mode"}) + call_kwargs = p._client.aretain.call_args.kwargs + assert call_kwargs["tags"] == ["pref", "ui"] + + def test_retain_without_tags(self, provider): + provider.handle_tool_call("hindsight_retain", {"content": "hello"}) + call_kwargs = provider._client.aretain.call_args.kwargs + assert "tags" not in call_kwargs + + def test_retain_missing_content(self, provider): + result = json.loads(provider.handle_tool_call( + "hindsight_retain", {} + )) + assert "error" in result + + def test_recall_success(self, provider): + result = json.loads(provider.handle_tool_call( + "hindsight_recall", {"query": "dark mode"} + )) + assert "Memory 1" in result["result"] + assert "Memory 2" in result["result"] + + def test_recall_passes_max_tokens(self, provider_with_config): + p = provider_with_config(recall_max_tokens=2048) + p.handle_tool_call("hindsight_recall", {"query": "test"}) + call_kwargs = p._client.arecall.call_args.kwargs + assert call_kwargs["max_tokens"] == 2048 + + def test_recall_passes_tags(self, provider_with_config): + p = provider_with_config(recall_tags=["tag1"], recall_tags_match="all") + p.handle_tool_call("hindsight_recall", {"query": "test"}) + call_kwargs = p._client.arecall.call_args.kwargs + assert call_kwargs["tags"] == ["tag1"] + assert call_kwargs["tags_match"] == "all" + + def test_recall_passes_types(self, provider_with_config): + p = provider_with_config(recall_types=["world", "experience"]) + p.handle_tool_call("hindsight_recall", {"query": "test"}) + call_kwargs = p._client.arecall.call_args.kwargs + assert call_kwargs["types"] == ["world", "experience"] + + def test_recall_no_results(self, provider): + provider._client.arecall.return_value = SimpleNamespace(results=[]) + result = json.loads(provider.handle_tool_call( + "hindsight_recall", {"query": "test"} + )) + assert result["result"] == "No relevant memories found." + + def test_recall_missing_query(self, provider): + result = json.loads(provider.handle_tool_call( + "hindsight_recall", {} + )) + assert "error" in result + + def test_reflect_success(self, provider): + result = json.loads(provider.handle_tool_call( + "hindsight_reflect", {"query": "summarize"} + )) + assert result["result"] == "Synthesized answer" + + def test_reflect_missing_query(self, provider): + result = json.loads(provider.handle_tool_call( + "hindsight_reflect", {} + )) + assert "error" in result + + def test_unknown_tool(self, provider): + result = json.loads(provider.handle_tool_call( + "hindsight_unknown", {} + )) + assert "error" in result + + def test_retain_error_handling(self, provider): + provider._client.aretain.side_effect = RuntimeError("connection failed") + result = json.loads(provider.handle_tool_call( + "hindsight_retain", {"content": "test"} + )) + assert "error" in result + assert "connection failed" in result["error"] + + def test_recall_error_handling(self, provider): + provider._client.arecall.side_effect = RuntimeError("timeout") + result = json.loads(provider.handle_tool_call( + "hindsight_recall", {"query": "test"} + )) + assert "error" in result + + +# --------------------------------------------------------------------------- +# Prefetch tests +# --------------------------------------------------------------------------- + + +class TestPrefetch: + def test_prefetch_returns_empty_when_no_result(self, provider): + assert provider.prefetch("test") == "" + + def test_prefetch_default_preamble(self, provider): + provider._prefetch_result = "- some memory" + result = provider.prefetch("test") + assert "Hindsight Memory" in result + assert "- some memory" in result + + def test_prefetch_custom_preamble(self, provider_with_config): + p = provider_with_config(recall_prompt_preamble="Custom header:") + p._prefetch_result = "- memory line" + result = p.prefetch("test") + assert result.startswith("Custom header:") + assert "- memory line" in result + + def test_queue_prefetch_skipped_in_tools_mode(self, provider_with_config): + p = provider_with_config(memory_mode="tools") + p.queue_prefetch("test") + # Should not start a thread + assert p._prefetch_thread is None + + def test_queue_prefetch_skipped_when_auto_recall_off(self, provider_with_config): + p = provider_with_config(auto_recall=False) + p.queue_prefetch("test") + assert p._prefetch_thread is None + + def test_queue_prefetch_truncates_query(self, provider_with_config): + p = provider_with_config(recall_max_input_chars=10) + # Mock _run_sync to capture the query + original_query = None + + def _capture_recall(**kwargs): + nonlocal original_query + original_query = kwargs.get("query", "") + return SimpleNamespace(results=[]) + + p._client.arecall = AsyncMock(side_effect=_capture_recall) + + long_query = "a" * 100 + p.queue_prefetch(long_query) + if p._prefetch_thread: + p._prefetch_thread.join(timeout=5.0) + + # The query passed to arecall should be truncated + if original_query is not None: + assert len(original_query) <= 10 + + def test_queue_prefetch_passes_recall_params(self, provider_with_config): + p = provider_with_config( + recall_tags=["t1"], + recall_tags_match="all", + recall_max_tokens=1024, + recall_types=["world"], + ) + p.queue_prefetch("test query") + if p._prefetch_thread: + p._prefetch_thread.join(timeout=5.0) + + call_kwargs = p._client.arecall.call_args.kwargs + assert call_kwargs["max_tokens"] == 1024 + assert call_kwargs["tags"] == ["t1"] + assert call_kwargs["tags_match"] == "all" + assert call_kwargs["types"] == ["world"] + + +# --------------------------------------------------------------------------- +# sync_turn tests +# --------------------------------------------------------------------------- + + +class TestSyncTurn: + def _get_retain_kwargs(self, provider): + """Helper to get the kwargs from the aretain_batch call.""" + return provider._client.aretain_batch.call_args.kwargs + + def _get_retain_content(self, provider): + """Helper to get the raw content string from the first item.""" + kwargs = self._get_retain_kwargs(provider) + return kwargs["items"][0]["content"] + + def _get_retain_messages(self, provider): + """Helper to parse the first turn's messages from retained content. + + Content is a JSON array of turns: [[msgs...], [msgs...], ...] + For single-turn tests, returns the first turn's messages. + """ + content = self._get_retain_content(provider) + turns = json.loads(content) + return turns[0] if len(turns) == 1 else turns + + def test_sync_turn_retains(self, provider): + provider.sync_turn("hello", "hi there") + if provider._sync_thread: + provider._sync_thread.join(timeout=5.0) + provider._client.aretain_batch.assert_called_once() + messages = self._get_retain_messages(provider) + assert len(messages) == 2 + assert messages[0]["role"] == "user" + assert messages[0]["content"] == "hello" + assert "timestamp" in messages[0] + assert messages[1]["role"] == "assistant" + assert messages[1]["content"] == "hi there" + assert "timestamp" in messages[1] + + def test_sync_turn_skipped_when_auto_retain_off(self, provider_with_config): + p = provider_with_config(auto_retain=False) + p.sync_turn("hello", "hi") + assert p._sync_thread is None + p._client.aretain_batch.assert_not_called() + + def test_sync_turn_with_tags(self, provider_with_config): + p = provider_with_config(tags=["conv", "session1"]) + p.sync_turn("hello", "hi") + if p._sync_thread: + p._sync_thread.join(timeout=5.0) + item = p._client.aretain_batch.call_args.kwargs["items"][0] + assert item["tags"] == ["conv", "session1"] + + def test_sync_turn_uses_aretain_batch(self, provider): + """sync_turn should use aretain_batch with retain_async.""" + provider.sync_turn("hello", "hi") + if provider._sync_thread: + provider._sync_thread.join(timeout=5.0) + provider._client.aretain_batch.assert_called_once() + call_kwargs = provider._client.aretain_batch.call_args.kwargs + assert call_kwargs["document_id"] == "test-session" + assert call_kwargs["retain_async"] is True + assert len(call_kwargs["items"]) == 1 + assert call_kwargs["items"][0]["context"] == "conversation between Hermes Agent and the User" + + def test_sync_turn_custom_context(self, provider_with_config): + p = provider_with_config(retain_context="my-agent") + p.sync_turn("hello", "hi") + if p._sync_thread: + p._sync_thread.join(timeout=5.0) + item = p._client.aretain_batch.call_args.kwargs["items"][0] + assert item["context"] == "my-agent" + + def test_sync_turn_every_n_turns(self, provider_with_config): + """With retain_every_n_turns=3, only retains on every 3rd turn.""" + p = provider_with_config(retain_every_n_turns=3) + + p.sync_turn("turn1-user", "turn1-asst") + assert p._sync_thread is None # not retained yet + + p.sync_turn("turn2-user", "turn2-asst") + assert p._sync_thread is None # not retained yet + + p.sync_turn("turn3-user", "turn3-asst") + assert p._sync_thread is not None # retained! + p._sync_thread.join(timeout=5.0) + + p._client.aretain_batch.assert_called_once() + content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"] + # Should contain all 3 turns + assert "turn1-user" in content + assert "turn2-user" in content + assert "turn3-user" in content + + def test_sync_turn_accumulates_full_session(self, provider_with_config): + """Each retain sends the ENTIRE session, not just the latest batch.""" + p = provider_with_config(retain_every_n_turns=2) + + p.sync_turn("turn1-user", "turn1-asst") + p.sync_turn("turn2-user", "turn2-asst") + if p._sync_thread: + p._sync_thread.join(timeout=5.0) + + p._client.aretain_batch.reset_mock() + + p.sync_turn("turn3-user", "turn3-asst") + p.sync_turn("turn4-user", "turn4-asst") + if p._sync_thread: + p._sync_thread.join(timeout=5.0) + + content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"] + # Should contain ALL turns from the session + assert "turn1-user" in content + assert "turn2-user" in content + assert "turn3-user" in content + assert "turn4-user" in content + + def test_sync_turn_passes_document_id(self, provider): + """sync_turn should pass session_id as document_id for dedup.""" + provider.sync_turn("hello", "hi") + if provider._sync_thread: + provider._sync_thread.join(timeout=5.0) + call_kwargs = provider._client.aretain_batch.call_args.kwargs + assert call_kwargs["document_id"] == "test-session" + + def test_sync_turn_error_does_not_raise(self, provider): + """Errors in sync_turn should be swallowed (non-blocking).""" + provider._client.aretain_batch.side_effect = RuntimeError("network error") + provider.sync_turn("hello", "hi") + if provider._sync_thread: + provider._sync_thread.join(timeout=5.0) + # Should not raise + + +# --------------------------------------------------------------------------- +# System prompt tests +# --------------------------------------------------------------------------- + + +class TestSystemPrompt: + def test_hybrid_mode_prompt(self, provider): + block = provider.system_prompt_block() + assert "Hindsight Memory" in block + assert "hindsight_recall" in block + assert "automatically injected" in block + + def test_context_mode_prompt(self, provider_with_config): + p = provider_with_config(memory_mode="context") + block = p.system_prompt_block() + assert "context mode" in block + assert "hindsight_recall" not in block + + def test_tools_mode_prompt(self, provider_with_config): + p = provider_with_config(memory_mode="tools") + block = p.system_prompt_block() + assert "tools mode" in block + assert "hindsight_recall" in block + + +# --------------------------------------------------------------------------- +# Config schema tests +# --------------------------------------------------------------------------- + + +class TestConfigSchema: + def test_schema_has_all_new_fields(self, provider): + schema = provider.get_config_schema() + keys = {f["key"] for f in schema} + expected_keys = { + "mode", "api_url", "api_key", "llm_provider", "llm_api_key", + "llm_model", "bank_id", "bank_mission", "bank_retain_mission", + "recall_budget", "memory_mode", "recall_prefetch_method", + "tags", "recall_tags", "recall_tags_match", + "auto_recall", "auto_retain", + "retain_every_n_turns", "retain_async", + "retain_context", + "recall_max_tokens", "recall_max_input_chars", + "recall_prompt_preamble", + } + assert expected_keys.issubset(keys), f"Missing: {expected_keys - keys}" + + +# --------------------------------------------------------------------------- +# Availability tests +# --------------------------------------------------------------------------- + + +class TestAvailability: + def test_available_with_api_key(self, tmp_path, monkeypatch): + monkeypatch.setattr( + "plugins.memory.hindsight.get_hermes_home", + lambda: tmp_path / "nonexistent", + ) + monkeypatch.setenv("HINDSIGHT_API_KEY", "test-key") + p = HindsightMemoryProvider() + assert p.is_available() + + def test_not_available_without_config(self, tmp_path, monkeypatch): + monkeypatch.setattr( + "plugins.memory.hindsight.get_hermes_home", + lambda: tmp_path / "nonexistent", + ) + p = HindsightMemoryProvider() + assert not p.is_available() + + def test_available_in_local_mode(self, tmp_path, monkeypatch): + monkeypatch.setattr( + "plugins.memory.hindsight.get_hermes_home", + lambda: tmp_path / "nonexistent", + ) + monkeypatch.setenv("HINDSIGHT_MODE", "local") + p = HindsightMemoryProvider() + assert p.is_available() diff --git a/website/docs/user-guide/features/memory-providers.md b/website/docs/user-guide/features/memory-providers.md index ad0a17ae46f..e76a05414ff 100644 --- a/website/docs/user-guide/features/memory-providers.md +++ b/website/docs/user-guide/features/memory-providers.md @@ -263,12 +263,12 @@ echo "MEM0_API_KEY=your-key" >> ~/.hermes/.env ### Hindsight -Long-term memory with knowledge graph, entity resolution, and multi-strategy retrieval. The `hindsight_reflect` tool provides cross-memory synthesis that no other provider offers. +Long-term memory with knowledge graph, entity resolution, and multi-strategy retrieval. The `hindsight_reflect` tool provides cross-memory synthesis that no other provider offers. Automatically retains full conversation turns (including tool calls) with session-level document tracking. | | | |---|---| | **Best for** | Knowledge graph-based recall with entity relationships | -| **Requires** | Cloud: `pip install hindsight-client` + API key. Local: `pip install hindsight` + LLM key | +| **Requires** | Cloud: API key from [ui.hindsight.vectorize.io](https://ui.hindsight.vectorize.io). Local: LLM API key (OpenAI, Groq, OpenRouter, etc.) | | **Data storage** | Hindsight Cloud or local embedded PostgreSQL | | **Cost** | Hindsight pricing (cloud) or free (local) | @@ -282,13 +282,25 @@ hermes config set memory.provider hindsight echo "HINDSIGHT_API_KEY=your-key" >> ~/.hermes/.env ``` +The setup wizard installs dependencies automatically and only installs what's needed for the selected mode (`hindsight-client` for cloud, `hindsight-all` for local). Requires `hindsight-client >= 0.4.22` (auto-upgraded on session start if outdated). + +**Local mode UI:** `hindsight-embed -p hermes ui start` + **Config:** `$HERMES_HOME/hindsight/config.json` | Key | Default | Description | |-----|---------|-------------| | `mode` | `cloud` | `cloud` or `local` | | `bank_id` | `hermes` | Memory bank identifier | -| `budget` | `mid` | Recall thoroughness: `low` / `mid` / `high` | +| `recall_budget` | `mid` | Recall thoroughness: `low` / `mid` / `high` | +| `memory_mode` | `hybrid` | `hybrid` (context + tools), `context` (auto-inject only), `tools` (tools only) | +| `auto_retain` | `true` | Automatically retain conversation turns | +| `auto_recall` | `true` | Automatically recall memories before each turn | +| `retain_async` | `true` | Process retain asynchronously on the server | +| `tags` | — | Tags applied when storing memories | +| `recall_tags` | — | Tags to filter on recall | + +See [plugin README](https://github.com/NousResearch/hermes-agent/blob/main/plugins/memory/hindsight/README.md) for the full configuration reference. --- From d12f8db0b8c2c1df9b2239b00d5a37b026e85ec7 Mon Sep 17 00:00:00 2001 From: BongSuCHOI Date: Wed, 8 Apr 2026 17:36:32 +0000 Subject: [PATCH 23/49] fix(compaction): token-budget primary tail protection Tail protection was effectively message-count based despite having a token budget, because protect_last_n=20 acted as a hard floor. A single 50K-token tool output would cause all 20 recent messages to be preserved regardless of budget, leaving little room for summarization. Changes: - _find_tail_cut_by_tokens: min_tail reduced from protect_last_n (20) to 3; token budget is now the primary criterion - Soft ceiling at 1.5x budget to avoid cutting mid-oversized-message - _prune_old_tool_results: accepts optional protect_tail_tokens so pruning also respects the token budget instead of a fixed count - compress() minimum message check relaxed from protect_first_n + protect_last_n + 1 to protect_first_n + 3 + 1 - Tool group alignment (no splitting tool_call/result) preserved --- agent/context_compressor.py | 66 ++++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 16 deletions(-) diff --git a/agent/context_compressor.py b/agent/context_compressor.py index 0cd51b06eff..c61cf2c5a78 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -154,12 +154,15 @@ class ContextCompressor: def _prune_old_tool_results( self, messages: List[Dict[str, Any]], protect_tail_count: int, + protect_tail_tokens: int | None = None, ) -> tuple[List[Dict[str, Any]], int]: """Replace old tool result contents with a short placeholder. - Walks backward from the end, protecting the most recent - ``protect_tail_count`` messages. Older tool results get their - content replaced with a placeholder string. + Walks backward from the end, protecting the most recent messages that + fall within ``protect_tail_tokens`` (when provided) OR the last + ``protect_tail_count`` messages (backward-compatible default). + When both are given, the token budget takes priority and the message + count acts as a hard minimum floor. Returns (pruned_messages, pruned_count). """ @@ -168,7 +171,29 @@ class ContextCompressor: result = [m.copy() for m in messages] pruned = 0 - prune_boundary = len(result) - protect_tail_count + + # Determine the prune boundary + if protect_tail_tokens is not None and protect_tail_tokens > 0: + # Token-budget approach: walk backward accumulating tokens + accumulated = 0 + boundary = len(result) + min_protect = min(protect_tail_count, len(result) - 1) + for i in range(len(result) - 1, -1, -1): + msg = result[i] + content_len = len(msg.get("content") or "") + msg_tokens = content_len // _CHARS_PER_TOKEN + 10 + for tc in msg.get("tool_calls") or []: + if isinstance(tc, dict): + args = tc.get("function", {}).get("arguments", "") + msg_tokens += len(args) // _CHARS_PER_TOKEN + if accumulated + msg_tokens > protect_tail_tokens and (len(result) - i) >= min_protect: + boundary = i + break + accumulated += msg_tokens + boundary = i + prune_boundary = max(boundary, len(result) - min_protect) + else: + prune_boundary = len(result) - protect_tail_count for i in range(prune_boundary): msg = result[i] @@ -533,13 +558,20 @@ Write only the summary body. Do not include any preamble or prefix.""" derived from ``summary_target_ratio * context_length``, so it scales automatically with the model's context window. - Never cuts inside a tool_call/result group. Falls back to the old - ``protect_last_n`` if the budget would protect fewer messages. + Token budget is the primary criterion. A hard minimum of 3 messages + is always protected, but the budget is allowed to exceed by up to + 1.5x to avoid cutting inside an oversized message (tool output, file + read, etc.). If even the minimum 3 messages exceed 1.5x the budget + the cut is placed right after the head so compression still runs. + + Never cuts inside a tool_call/result group. """ if token_budget is None: token_budget = self.tail_token_budget n = len(messages) - min_tail = self.protect_last_n + # Hard minimum: always keep at least 3 messages in the tail + min_tail = min(3, n - head_end - 1) if n - head_end > 1 else 0 + soft_ceiling = int(token_budget * 1.5) accumulated = 0 cut_idx = n # start from beyond the end @@ -552,21 +584,21 @@ Write only the summary body. Do not include any preamble or prefix.""" if isinstance(tc, dict): args = tc.get("function", {}).get("arguments", "") msg_tokens += len(args) // _CHARS_PER_TOKEN - if accumulated + msg_tokens > token_budget and (n - i) >= min_tail: + # Stop once we exceed the soft ceiling (unless we haven't hit min_tail yet) + if accumulated + msg_tokens > soft_ceiling and (n - i) >= min_tail: break accumulated += msg_tokens cut_idx = i - # Ensure we protect at least protect_last_n messages + # Ensure we protect at least min_tail messages fallback_cut = n - min_tail if cut_idx > fallback_cut: cut_idx = fallback_cut # If the token budget would protect everything (small conversations), - # fall back to the fixed protect_last_n approach so compression can - # still remove middle turns. + # force a cut after the head so compression can still remove middle turns. if cut_idx <= head_end: - cut_idx = fallback_cut + cut_idx = max(fallback_cut, head_end + 1) # Align to avoid splitting tool groups cut_idx = self._align_boundary_backward(messages, cut_idx) @@ -591,12 +623,13 @@ Write only the summary body. Do not include any preamble or prefix.""" up so the API never receives mismatched IDs. """ n_messages = len(messages) - if n_messages <= self.protect_first_n + self.protect_last_n + 1: + # Only need head + 3 tail messages minimum (token budget decides the real tail size) + _min_for_compress = self.protect_first_n + 3 + 1 + if n_messages <= _min_for_compress: if not self.quiet_mode: logger.warning( "Cannot compress: only %d messages (need > %d)", - n_messages, - self.protect_first_n + self.protect_last_n + 1, + n_messages, _min_for_compress, ) return messages @@ -604,7 +637,8 @@ Write only the summary body. Do not include any preamble or prefix.""" # Phase 1: Prune old tool results (cheap, no LLM call) messages, pruned_count = self._prune_old_tool_results( - messages, protect_tail_count=self.protect_last_n * 3, + messages, protect_tail_count=self.protect_last_n, + protect_tail_tokens=self.tail_token_budget, ) if pruned_count and not self.quiet_mode: logger.info("Pre-compression: pruned %d old tool result(s)", pruned_count) From c506126123508bb097dd5bc3d35dbc335e729e3e Mon Sep 17 00:00:00 2001 From: BongSuCHOI Date: Wed, 8 Apr 2026 18:25:30 +0000 Subject: [PATCH 24/49] fix(tests): update context_compressor tests for min_tail=3 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #6240 changed tail protection from protect_last_n to min(3, ...) which increased the minimum compressible message count and shifted tail boundaries. Three tests broke: - test_summary_role_avoids_consecutive_user_messages: 6→8 msgs - test_double_collision_user_head_assistant_tail: 7→8 msgs - test_no_collision_scenarios_still_work: 6→8 msgs All tests now exceed the new min_for_compress threshold (6) and maintain proper role alternation in both head and tail sections. --- tests/agent/test_context_compressor.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/agent/test_context_compressor.py b/tests/agent/test_context_compressor.py index 257cf90395e..8a72d5fefc0 100644 --- a/tests/agent/test_context_compressor.py +++ b/tests/agent/test_context_compressor.py @@ -324,7 +324,10 @@ class TestCompressWithClient: with patch("agent.context_compressor.get_model_context_length", return_value=100000): c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2) - # Last head message (index 1) is "assistant" → summary should be "user" + # Last head message (index 1) is "assistant" → summary should be "user". + # With min_tail=3, tail = last 3 messages (indices 5-7). + # head_last=assistant, tail_first=assistant → summary_role="user", no collision. + # Need 8 messages: min_for_compress = 2+3+1 = 6, must have > 6. msgs = [ {"role": "user", "content": "msg 0"}, {"role": "assistant", "content": "msg 1"}, @@ -332,6 +335,8 @@ class TestCompressWithClient: {"role": "assistant", "content": "msg 3"}, {"role": "user", "content": "msg 4"}, {"role": "assistant", "content": "msg 5"}, + {"role": "user", "content": "msg 6"}, + {"role": "assistant", "content": "msg 7"}, ] with patch("agent.context_compressor.call_llm", return_value=mock_response): result = c.compress(msgs) @@ -460,8 +465,10 @@ class TestCompressWithClient: c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2) # Head: [system, user] → last head = user - # Tail: [assistant, user] → first tail = assistant + # Tail: [assistant, user, assistant] → first tail = assistant # summary_role="assistant" collides with tail, "user" collides with head → merge + # With min_tail=3, tail = last 3 messages (indices 5-7). + # Need 8 messages: min_for_compress = 2+3+1 = 6, must have > 6. msgs = [ {"role": "system", "content": "system prompt"}, {"role": "user", "content": "msg 1"}, @@ -470,6 +477,7 @@ class TestCompressWithClient: {"role": "assistant", "content": "msg 4"}, # compressed {"role": "assistant", "content": "msg 5"}, # tail start {"role": "user", "content": "msg 6"}, + {"role": "assistant", "content": "msg 7"}, ] with patch("agent.context_compressor.call_llm", return_value=mock_response): result = c.compress(msgs) @@ -481,7 +489,7 @@ class TestCompressWithClient: if r1 in ("user", "assistant") and r2 in ("user", "assistant"): assert r1 != r2, f"consecutive {r1} at indices {i-1},{i}" - # The summary should be merged into the first tail message (assistant) + # The summary should be merged into the first tail message (assistant at index 5) first_tail = [m for m in result if "msg 5" in (m.get("content") or "")] assert len(first_tail) == 1 assert "summary text" in first_tail[0]["content"] @@ -496,14 +504,18 @@ class TestCompressWithClient: with patch("agent.context_compressor.get_model_context_length", return_value=100000): c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2) - # Head=assistant, Tail=assistant → summary_role="user", no collision + # Head=assistant, Tail=assistant → summary_role="user", no collision. + # With min_tail=3, tail = last 3 messages (indices 5-7). + # Need 8 messages: min_for_compress = 2+3+1 = 6, must have > 6. msgs = [ {"role": "user", "content": "msg 0"}, {"role": "assistant", "content": "msg 1"}, {"role": "user", "content": "msg 2"}, {"role": "assistant", "content": "msg 3"}, - {"role": "assistant", "content": "msg 4"}, - {"role": "user", "content": "msg 5"}, + {"role": "user", "content": "msg 4"}, + {"role": "assistant", "content": "msg 5"}, + {"role": "user", "content": "msg 6"}, + {"role": "assistant", "content": "msg 7"}, ] with patch("agent.context_compressor.call_llm", return_value=mock_response): result = c.compress(msgs) From d40264d53b5fe88313367d8554a75efdc07a8d9f Mon Sep 17 00:00:00 2001 From: Teknium Date: Wed, 8 Apr 2026 23:35:46 -0700 Subject: [PATCH 25/49] test: add coverage for token-budget tail protection Tests for the new behavior paths: - Large tool outputs no longer block compaction (motivating scenario) - Hard minimum of 3 tail messages always protected - 1.5x soft ceiling for oversized messages - Small conversations still compress (min 8 messages) - Token-budget prune path in _prune_old_tool_results - Fallback to message-count when no token budget --- tests/agent/test_context_compressor.py | 155 +++++++++++++++++++++++++ 1 file changed, 155 insertions(+) diff --git a/tests/agent/test_context_compressor.py b/tests/agent/test_context_compressor.py index 8a72d5fefc0..42f6de0fd33 100644 --- a/tests/agent/test_context_compressor.py +++ b/tests/agent/test_context_compressor.py @@ -612,3 +612,158 @@ class TestSummaryTargetRatio: with patch("agent.context_compressor.get_model_context_length", return_value=100_000): c = ContextCompressor(model="test", quiet_mode=True) assert c.protect_last_n == 20 + + +class TestTokenBudgetTailProtection: + """Tests for token-budget-based tail protection (PR #6240). + + The core change: tail protection is now based on a token budget rather + than a fixed message count. This prevents large tool outputs from + blocking compaction. + """ + + @pytest.fixture() + def budget_compressor(self): + """Compressor with known token budget for tail protection tests.""" + with patch("agent.context_compressor.get_model_context_length", return_value=200_000): + c = ContextCompressor( + model="test/model", + threshold_percent=0.50, # 100K threshold + protect_first_n=2, + protect_last_n=20, + quiet_mode=True, + ) + return c + + def test_large_tool_outputs_no_longer_block_compaction(self, budget_compressor): + """The motivating scenario: 20 messages with large tool outputs should + NOT prevent compaction. With message-count tail protection they would + all be protected, leaving nothing to summarize.""" + c = budget_compressor + messages = [ + {"role": "user", "content": "Start task"}, + {"role": "assistant", "content": "On it"}, + ] + # Add 20 messages with large tool outputs (~5K chars each ≈ 1250 tokens) + for i in range(10): + messages.append({ + "role": "assistant", "content": None, + "tool_calls": [{"function": {"name": f"tool_{i}", "arguments": "{}"}}], + }) + messages.append({ + "role": "tool", "content": "x" * 5000, + "tool_call_id": f"call_{i}", + }) + # Add 3 recent small messages + messages.append({"role": "user", "content": "What's the status?"}) + messages.append({"role": "assistant", "content": "Here's what I found..."}) + messages.append({"role": "user", "content": "Continue"}) + + # The tail cut should NOT protect all 20 tool messages + head_end = c.protect_first_n + cut = c._find_tail_cut_by_tokens(messages, head_end) + tail_size = len(messages) - cut + # With token budget, the tail should be much smaller than 20+ + assert tail_size < 20, f"Tail {tail_size} messages — large tool outputs are blocking compaction" + # But at least 3 (hard minimum) + assert tail_size >= 3 + + def test_min_tail_always_3_messages(self, budget_compressor): + """Even with a tiny token budget, at least 3 messages are protected.""" + c = budget_compressor + # Override to a tiny budget + c.tail_token_budget = 10 + messages = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + {"role": "user", "content": "do something"}, + {"role": "assistant", "content": "working on it"}, + {"role": "user", "content": "more work"}, + {"role": "assistant", "content": "done"}, + {"role": "user", "content": "thanks"}, + ] + head_end = 2 + cut = c._find_tail_cut_by_tokens(messages, head_end) + tail_size = len(messages) - cut + assert tail_size >= 3, f"Tail is only {tail_size} messages, min should be 3" + + def test_soft_ceiling_allows_oversized_message(self, budget_compressor): + """The 1.5x soft ceiling allows an oversized message to be included + rather than splitting it.""" + c = budget_compressor + # Set a small budget — 500 tokens + c.tail_token_budget = 500 + messages = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + {"role": "user", "content": "read the file"}, + # This message is ~600 tokens (> budget of 500, but < 1.5x = 750) + {"role": "assistant", "content": "a" * 2400}, + {"role": "user", "content": "short"}, + {"role": "assistant", "content": "short reply"}, + {"role": "user", "content": "continue"}, + ] + head_end = 2 + cut = c._find_tail_cut_by_tokens(messages, head_end) + # The oversized message at index 3 should NOT be the cut point + # because 1.5x ceiling = 750 tokens and accumulated would be ~610 + # (short msgs + oversized msg) which is < 750 + tail_size = len(messages) - cut + assert tail_size >= 3 + + def test_small_conversation_still_compresses(self, budget_compressor): + """With the new min of 8 messages (head=2 + 3 + 1 guard + 2 middle), + a small but compressible conversation should still compress.""" + c = budget_compressor + # 9 messages: head(2) + 4 middle + 3 tail = compressible + messages = [] + for i in range(9): + role = "user" if i % 2 == 0 else "assistant" + messages.append({"role": role, "content": f"Message {i}"}) + + # Should not early-return (needs > protect_first_n + 3 + 1 = 6) + # Mock the summary generation to avoid real API call + with patch.object(c, "_generate_summary", return_value="Summary of conversation"): + result = c.compress(messages, current_tokens=90_000) + # Should have compressed (fewer messages than original) + assert len(result) < len(messages) + + def test_prune_with_token_budget(self, budget_compressor): + """_prune_old_tool_results with protect_tail_tokens respects the budget.""" + c = budget_compressor + messages = [ + {"role": "user", "content": "start"}, + {"role": "assistant", "content": None, + "tool_calls": [{"function": {"name": "read_file", "arguments": '{"path": "big.txt"}'}}]}, + {"role": "tool", "content": "x" * 10000, "tool_call_id": "c1"}, # ~2500 tokens + {"role": "assistant", "content": None, + "tool_calls": [{"function": {"name": "read_file", "arguments": '{"path": "small.txt"}'}}]}, + {"role": "tool", "content": "y" * 10000, "tool_call_id": "c2"}, # ~2500 tokens + {"role": "user", "content": "short recent message"}, + {"role": "assistant", "content": "short reply"}, + ] + # With a 1000-token budget, only the last couple messages should be protected + result, pruned = c._prune_old_tool_results( + messages, protect_tail_count=2, protect_tail_tokens=1000, + ) + # At least one old tool result should have been pruned + assert pruned >= 1 + + def test_prune_without_token_budget_uses_message_count(self, budget_compressor): + """Without protect_tail_tokens, falls back to message-count behavior.""" + c = budget_compressor + messages = [ + {"role": "user", "content": "start"}, + {"role": "assistant", "content": None, + "tool_calls": [{"function": {"name": "tool", "arguments": "{}"}}]}, + {"role": "tool", "content": "x" * 5000, "tool_call_id": "c1"}, + {"role": "user", "content": "recent"}, + {"role": "assistant", "content": "reply"}, + ] + # protect_tail_count=3 means last 3 messages protected + result, pruned = c._prune_old_tool_results( + messages, protect_tail_count=3, + ) + # Tool at index 2 is outside the protected tail (last 3 = indices 2,3,4) + # so it might or might not be pruned depending on boundary + assert isinstance(pruned, int) From 7120d6cdd6a6e3d0559185caabaef203dceca622 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:19:05 -0700 Subject: [PATCH 26/49] fix(bluebubbles): add missing integration points and documentation (#6460) - hermes_cli/skills_config.py: add platform label for per-platform skill config - gateway/session.py: add to PII-safe platforms (no mention system) - website/docs/user-guide/messaging/bluebubbles.md: full setup guide - website/sidebars.ts: sidebar navigation entry - 10 docs pages: add BlueBubbles to all platform enumerations (env vars, toolsets, cron delivery, gateway internals, etc.) --- gateway/session.py | 1 + hermes_cli/skills_config.py | 1 + website/docs/developer-guide/architecture.md | 4 +- .../docs/developer-guide/cron-internals.md | 1 + .../docs/developer-guide/gateway-internals.md | 1 + website/docs/index.md | 2 +- website/docs/integrations/index.md | 4 +- .../docs/reference/environment-variables.md | 7 + website/docs/reference/toolsets-reference.md | 1 + website/docs/user-guide/configuration.md | 2 +- website/docs/user-guide/features/cron.md | 1 + .../docs/user-guide/messaging/bluebubbles.md | 141 ++++++++++++++++++ website/docs/user-guide/messaging/index.md | 6 +- website/docs/user-guide/sessions.md | 1 + website/sidebars.ts | 1 + 15 files changed, 167 insertions(+), 7 deletions(-) create mode 100644 website/docs/user-guide/messaging/bluebubbles.md diff --git a/gateway/session.py b/gateway/session.py index 64f04ad9c95..72c3eb16188 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -193,6 +193,7 @@ _PII_SAFE_PLATFORMS = frozenset({ Platform.WHATSAPP, Platform.SIGNAL, Platform.TELEGRAM, + Platform.BLUEBUBBLES, }) """Platforms where user IDs can be safely redacted (no in-message mention system that requires raw IDs). Discord is excluded because mentions use ``<@user_id>`` diff --git a/hermes_cli/skills_config.py b/hermes_cli/skills_config.py index 7b44014ea59..d7e47ca5f28 100644 --- a/hermes_cli/skills_config.py +++ b/hermes_cli/skills_config.py @@ -23,6 +23,7 @@ PLATFORMS = { "slack": "💼 Slack", "whatsapp": "📱 WhatsApp", "signal": "📡 Signal", + "bluebubbles": "💬 BlueBubbles", "email": "📧 Email", "homeassistant": "🏠 Home Assistant", "mattermost": "💬 Mattermost", diff --git a/website/docs/developer-guide/architecture.md b/website/docs/developer-guide/architecture.md index c08161b32f1..38fbfb138ca 100644 --- a/website/docs/developer-guide/architecture.md +++ b/website/docs/developer-guide/architecture.md @@ -116,9 +116,9 @@ hermes-agent/ │ ├── mirror.py # Cross-session message mirroring │ ├── status.py # Token locks, profile-scoped process tracking │ ├── builtin_hooks/ # Always-registered hooks -│ └── platforms/ # 14 adapters: telegram, discord, slack, whatsapp, +│ └── platforms/ # 15 adapters: telegram, discord, slack, whatsapp, │ # signal, matrix, mattermost, email, sms, -│ # dingtalk, feishu, wecom, homeassistant, webhook +│ # dingtalk, feishu, wecom, bluebubbles, homeassistant, webhook │ ├── acp_adapter/ # ACP server (VS Code / Zed / JetBrains) ├── cron/ # Scheduler (jobs.py, scheduler.py) diff --git a/website/docs/developer-guide/cron-internals.md b/website/docs/developer-guide/cron-internals.md index cc8435dbee2..2f14d4e1a5c 100644 --- a/website/docs/developer-guide/cron-internals.md +++ b/website/docs/developer-guide/cron-internals.md @@ -153,6 +153,7 @@ Cron job results can be delivered to any supported platform: | DingTalk | `dingtalk` | Deliver to DingTalk | | Feishu | `feishu` | Deliver to Feishu | | WeCom | `wecom` | Deliver to WeCom | +| BlueBubbles | `bluebubbles` | Deliver to iMessage via BlueBubbles | For Telegram topics, use the format `telegram::` (e.g., `telegram:-1001234567890:17585`). diff --git a/website/docs/developer-guide/gateway-internals.md b/website/docs/developer-guide/gateway-internals.md index 1371bdd3409..cf25cecd9a8 100644 --- a/website/docs/developer-guide/gateway-internals.md +++ b/website/docs/developer-guide/gateway-internals.md @@ -160,6 +160,7 @@ gateway/platforms/ ├── dingtalk.py # DingTalk WebSocket ├── feishu.py # Feishu/Lark WebSocket or webhook ├── wecom.py # WeCom (WeChat Work) callback +├── bluebubbles.py # Apple iMessage via BlueBubbles macOS server ├── webhook.py # Inbound/outbound webhook adapter ├── api_server.py # REST API server adapter └── homeassistant.py # Home Assistant conversation integration diff --git a/website/docs/index.md b/website/docs/index.md index f4b5378f4cf..0f180673ac4 100644 --- a/website/docs/index.md +++ b/website/docs/index.md @@ -46,7 +46,7 @@ It's not a coding copilot tethered to an IDE or a chatbot wrapper around a singl - **A closed learning loop** — Agent-curated memory with periodic nudges, autonomous skill creation, skill self-improvement during use, FTS5 cross-session recall with LLM summarization, and [Honcho](https://github.com/plastic-labs/honcho) dialectic user modeling - **Runs anywhere, not just your laptop** — 6 terminal backends: local, Docker, SSH, Daytona, Singularity, Modal. Daytona and Modal offer serverless persistence — your environment hibernates when idle, costing nearly nothing -- **Lives where you do** — CLI, Telegram, Discord, Slack, WhatsApp, Signal, Matrix, Mattermost, Email, SMS, DingTalk, Feishu, WeCom, Home Assistant — 14+ platforms from one gateway +- **Lives where you do** — CLI, Telegram, Discord, Slack, WhatsApp, Signal, Matrix, Mattermost, Email, SMS, DingTalk, Feishu, WeCom, BlueBubbles, Home Assistant — 15+ platforms from one gateway - **Built by model trainers** — Created by [Nous Research](https://nousresearch.com), the lab behind Hermes, Nomos, and Psyche. Works with [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai), OpenAI, or any endpoint - **Scheduled automations** — Built-in cron with delivery to any platform - **Delegates & parallelizes** — Spawn isolated subagents for parallel workstreams. Programmatic Tool Calling via `execute_code` collapses multi-step pipelines into single inference calls diff --git a/website/docs/integrations/index.md b/website/docs/integrations/index.md index ce103f1cc80..e6fe54f7765 100644 --- a/website/docs/integrations/index.md +++ b/website/docs/integrations/index.md @@ -80,9 +80,9 @@ Speech-to-text supports three providers: local Whisper (free, runs on-device), G ## Messaging Platforms -Hermes runs as a gateway bot on 14+ messaging platforms, all configured through the same `gateway` subsystem: +Hermes runs as a gateway bot on 15+ messaging platforms, all configured through the same `gateway` subsystem: -- **[Telegram](/docs/user-guide/messaging/telegram)**, **[Discord](/docs/user-guide/messaging/discord)**, **[Slack](/docs/user-guide/messaging/slack)**, **[WhatsApp](/docs/user-guide/messaging/whatsapp)**, **[Signal](/docs/user-guide/messaging/signal)**, **[Matrix](/docs/user-guide/messaging/matrix)**, **[Mattermost](/docs/user-guide/messaging/mattermost)**, **[Email](/docs/user-guide/messaging/email)**, **[SMS](/docs/user-guide/messaging/sms)**, **[DingTalk](/docs/user-guide/messaging/dingtalk)**, **[Feishu/Lark](/docs/user-guide/messaging/feishu)**, **[WeCom](/docs/user-guide/messaging/wecom)**, **[Home Assistant](/docs/user-guide/messaging/homeassistant)**, **[Webhooks](/docs/user-guide/messaging/webhooks)** +- **[Telegram](/docs/user-guide/messaging/telegram)**, **[Discord](/docs/user-guide/messaging/discord)**, **[Slack](/docs/user-guide/messaging/slack)**, **[WhatsApp](/docs/user-guide/messaging/whatsapp)**, **[Signal](/docs/user-guide/messaging/signal)**, **[Matrix](/docs/user-guide/messaging/matrix)**, **[Mattermost](/docs/user-guide/messaging/mattermost)**, **[Email](/docs/user-guide/messaging/email)**, **[SMS](/docs/user-guide/messaging/sms)**, **[DingTalk](/docs/user-guide/messaging/dingtalk)**, **[Feishu/Lark](/docs/user-guide/messaging/feishu)**, **[WeCom](/docs/user-guide/messaging/wecom)**, **[BlueBubbles](/docs/user-guide/messaging/bluebubbles)**, **[Home Assistant](/docs/user-guide/messaging/homeassistant)**, **[Webhooks](/docs/user-guide/messaging/webhooks)** See the [Messaging Gateway overview](/docs/user-guide/messaging) for the platform comparison table and setup guide. diff --git a/website/docs/reference/environment-variables.md b/website/docs/reference/environment-variables.md index 00b428697fc..e8f2e8aee6b 100644 --- a/website/docs/reference/environment-variables.md +++ b/website/docs/reference/environment-variables.md @@ -228,6 +228,13 @@ For cloud sandbox backends, persistence is filesystem-oriented. `TERMINAL_LIFETI | `WECOM_WEBSOCKET_URL` | Custom WebSocket URL (default: `wss://openws.work.weixin.qq.com`) | | `WECOM_ALLOWED_USERS` | Comma-separated WeCom user IDs allowed to message the bot | | `WECOM_HOME_CHANNEL` | WeCom chat ID for cron delivery and notifications | +| `BLUEBUBBLES_SERVER_URL` | BlueBubbles server URL (e.g. `http://192.168.1.10:1234`) | +| `BLUEBUBBLES_PASSWORD` | BlueBubbles server password | +| `BLUEBUBBLES_WEBHOOK_HOST` | Webhook listener bind address (default: `127.0.0.1`) | +| `BLUEBUBBLES_WEBHOOK_PORT` | Webhook listener port (default: `8645`) | +| `BLUEBUBBLES_HOME_CHANNEL` | Phone/email for cron/notification delivery | +| `BLUEBUBBLES_ALLOWED_USERS` | Comma-separated authorized users | +| `BLUEBUBBLES_ALLOW_ALL_USERS` | Allow all users (`true`/`false`) | | `MATTERMOST_URL` | Mattermost server URL (e.g. `https://mm.example.com`) | | `MATTERMOST_TOKEN` | Bot token or personal access token for Mattermost | | `MATTERMOST_ALLOWED_USERS` | Comma-separated Mattermost user IDs allowed to message the bot | diff --git a/website/docs/reference/toolsets-reference.md b/website/docs/reference/toolsets-reference.md index 1c225b233dd..ba04d5c7777 100644 --- a/website/docs/reference/toolsets-reference.md +++ b/website/docs/reference/toolsets-reference.md @@ -103,6 +103,7 @@ Platform toolsets define the complete tool configuration for a deployment target | `hermes-dingtalk` | Same as `hermes-cli`. | | `hermes-feishu` | Same as `hermes-cli`. | | `hermes-wecom` | Same as `hermes-cli`. | +| `hermes-bluebubbles` | Same as `hermes-cli`. | | `hermes-homeassistant` | Same as `hermes-cli`. | | `hermes-webhook` | Same as `hermes-cli`. | | `hermes-gateway` | Union of all messaging platform toolsets. Used internally when the gateway needs the broadest possible tool set. | diff --git a/website/docs/user-guide/configuration.md b/website/docs/user-guide/configuration.md index 4431e068285..0ac24db1847 100644 --- a/website/docs/user-guide/configuration.md +++ b/website/docs/user-guide/configuration.md @@ -843,7 +843,7 @@ display: slack: 'off' # quiet in shared Slack workspace ``` -Platforms without an override fall back to the global `tool_progress` value. Valid platform keys: `telegram`, `discord`, `slack`, `signal`, `whatsapp`, `matrix`, `mattermost`, `email`, `sms`, `homeassistant`, `dingtalk`, `feishu`, `wecom`. +Platforms without an override fall back to the global `tool_progress` value. Valid platform keys: `telegram`, `discord`, `slack`, `signal`, `whatsapp`, `matrix`, `mattermost`, `email`, `sms`, `homeassistant`, `dingtalk`, `feishu`, `wecom`, `bluebubbles`. ## Privacy diff --git a/website/docs/user-guide/features/cron.md b/website/docs/user-guide/features/cron.md index ff63848d8a3..b463d5a7bed 100644 --- a/website/docs/user-guide/features/cron.md +++ b/website/docs/user-guide/features/cron.md @@ -202,6 +202,7 @@ When scheduling jobs, you specify where the output goes: | `"dingtalk"` | DingTalk | | | `"feishu"` | Feishu/Lark | | | `"wecom"` | WeCom | | +| `"bluebubbles"` | BlueBubbles (iMessage) | | The agent's final response is automatically delivered. You do not need to call `send_message` in the cron prompt. diff --git a/website/docs/user-guide/messaging/bluebubbles.md b/website/docs/user-guide/messaging/bluebubbles.md new file mode 100644 index 00000000000..3f023d31792 --- /dev/null +++ b/website/docs/user-guide/messaging/bluebubbles.md @@ -0,0 +1,141 @@ +# BlueBubbles (iMessage) + +Connect Hermes to Apple iMessage via [BlueBubbles](https://bluebubbles.app/) — a free, open-source macOS server that bridges iMessage to any device. + +## Prerequisites + +- A **Mac** (always on) running [BlueBubbles Server](https://bluebubbles.app/) +- Apple ID signed into Messages.app on that Mac +- BlueBubbles Server v1.0.0+ (webhooks require this version) +- Network connectivity between Hermes and the BlueBubbles server + +## Setup + +### 1. Install BlueBubbles Server + +Download and install from [bluebubbles.app](https://bluebubbles.app/). Complete the setup wizard — sign in with your Apple ID and configure a connection method (local network, Ngrok, Cloudflare, or Dynamic DNS). + +### 2. Get your Server URL and Password + +In BlueBubbles Server → **Settings → API**, note: +- **Server URL** (e.g., `http://192.168.1.10:1234`) +- **Server Password** + +### 3. Configure Hermes + +Run the setup wizard: + +```bash +hermes gateway setup +``` + +Select **BlueBubbles (iMessage)** and enter your server URL and password. + +Or set environment variables directly in `~/.hermes/.env`: + +```bash +BLUEBUBBLES_SERVER_URL=http://192.168.1.10:1234 +BLUEBUBBLES_PASSWORD=your-server-password +``` + +### 4. Authorize Users + +Choose one approach: + +**DM Pairing (recommended):** +```bash +hermes pairing generate bluebubbles +``` +Share the pairing code — the user sends it via iMessage to get approved. + +**Pre-authorize specific users:** +```bash +BLUEBUBBLES_ALLOWED_USERS=user@icloud.com,+15551234567 +``` + +**Open access:** +```bash +BLUEBUBBLES_ALLOW_ALL_USERS=true +``` + +### 5. Start the Gateway + +```bash +hermes gateway run +``` + +Hermes will connect to your BlueBubbles server, register a webhook, and start listening for iMessage messages. + +## How It Works + +``` +iMessage → Messages.app → BlueBubbles Server → Webhook → Hermes +Hermes → BlueBubbles REST API → Messages.app → iMessage +``` + +- **Inbound:** BlueBubbles sends webhook events to a local listener when new messages arrive. No polling — instant delivery. +- **Outbound:** Hermes sends messages via the BlueBubbles REST API. +- **Media:** Images, voice messages, videos, and documents are supported in both directions. Inbound attachments are downloaded and cached locally for the agent to process. + +## Environment Variables + +| Variable | Required | Default | Description | +|----------|----------|---------|-------------| +| `BLUEBUBBLES_SERVER_URL` | Yes | — | BlueBubbles server URL | +| `BLUEBUBBLES_PASSWORD` | Yes | — | Server password | +| `BLUEBUBBLES_WEBHOOK_HOST` | No | `127.0.0.1` | Webhook listener bind address | +| `BLUEBUBBLES_WEBHOOK_PORT` | No | `8645` | Webhook listener port | +| `BLUEBUBBLES_WEBHOOK_PATH` | No | `/bluebubbles-webhook` | Webhook URL path | +| `BLUEBUBBLES_HOME_CHANNEL` | No | — | Phone/email for cron delivery | +| `BLUEBUBBLES_ALLOWED_USERS` | No | — | Comma-separated authorized users | +| `BLUEBUBBLES_ALLOW_ALL_USERS` | No | `false` | Allow all users | +| `BLUEBUBBLES_SEND_READ_RECEIPTS` | No | `true` | Auto-mark messages as read | + +## Features + +### Text Messaging +Send and receive iMessages. Markdown is automatically stripped for clean plain-text delivery. + +### Rich Media +- **Images:** Photos appear natively in the iMessage conversation +- **Voice messages:** Audio files sent as iMessage voice messages +- **Videos:** Video attachments +- **Documents:** Files sent as iMessage attachments + +### Tapback Reactions +Love, like, dislike, laugh, emphasize, and question reactions. Requires the BlueBubbles [Private API helper](https://docs.bluebubbles.app/helper-bundle/installation). + +### Typing Indicators +Shows "typing..." in the iMessage conversation while the agent is processing. Requires Private API. + +### Read Receipts +Automatically marks messages as read after processing. Requires Private API. + +### Chat Addressing +You can address chats by email or phone number — Hermes resolves them to BlueBubbles chat GUIDs automatically. No need to use raw GUID format. + +## Private API + +Some features require the BlueBubbles [Private API helper](https://docs.bluebubbles.app/helper-bundle/installation): +- Tapback reactions +- Typing indicators +- Read receipts +- Creating new chats by address + +Without the Private API, basic text messaging and media still work. + +## Troubleshooting + +### "Cannot reach server" +- Verify the server URL is correct and the Mac is on +- Check that BlueBubbles Server is running +- Ensure network connectivity (firewall, port forwarding) + +### Messages not arriving +- Check that the webhook is registered in BlueBubbles Server → Settings → API → Webhooks +- Verify the webhook URL is reachable from the Mac +- Check `hermes gateway logs` for webhook errors + +### "Private API helper not connected" +- Install the Private API helper: [docs.bluebubbles.app](https://docs.bluebubbles.app/helper-bundle/installation) +- Basic messaging works without it — only reactions, typing, and read receipts require it diff --git a/website/docs/user-guide/messaging/index.md b/website/docs/user-guide/messaging/index.md index fa662305bef..4e7d3514f9e 100644 --- a/website/docs/user-guide/messaging/index.md +++ b/website/docs/user-guide/messaging/index.md @@ -6,7 +6,7 @@ description: "Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, # Messaging Gateway -Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, SMS, Email, Home Assistant, Mattermost, Matrix, DingTalk, Feishu/Lark, WeCom, or your browser. The gateway is a single background process that connects to all your configured platforms, handles sessions, runs cron jobs, and delivers voice messages. +Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, SMS, Email, Home Assistant, Mattermost, Matrix, DingTalk, Feishu/Lark, WeCom, BlueBubbles (iMessage), or your browser. The gateway is a single background process that connects to all your configured platforms, handles sessions, runs cron jobs, and delivers voice messages. For the full voice feature set — including CLI microphone mode, spoken replies in messaging, and Discord voice-channel conversations — see [Voice Mode](/docs/user-guide/features/voice-mode) and [Use Voice Mode with Hermes](/docs/guides/use-voice-mode-with-hermes). @@ -27,6 +27,7 @@ For the full voice feature set — including CLI microphone mode, spoken replies | DingTalk | — | — | — | — | — | ✅ | ✅ | | Feishu/Lark | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | WeCom | ✅ | ✅ | ✅ | — | — | ✅ | ✅ | +| BlueBubbles | — | ✅ | ✅ | — | ✅ | ✅ | — | **Voice** = TTS audio replies and/or voice message transcription. **Images** = send/receive images. **Files** = send/receive file attachments. **Threads** = threaded conversations. **Reactions** = emoji reactions on messages. **Typing** = typing indicator while processing. **Streaming** = progressive message updates via editing. @@ -49,6 +50,7 @@ flowchart TB dt[DingTalk] fs[Feishu/Lark] wc[WeCom] + bb[BlueBubbles] api["API Server
(OpenAI-compatible)"] wh[Webhooks] end @@ -352,6 +354,7 @@ Each platform has its own toolset: | DingTalk | `hermes-dingtalk` | Full tools including terminal | | Feishu/Lark | `hermes-feishu` | Full tools including terminal | | WeCom | `hermes-wecom` | Full tools including terminal | +| BlueBubbles | `hermes-bluebubbles` | Full tools including terminal | | API Server | `hermes` (default) | Full tools including terminal | | Webhooks | `hermes-webhook` | Full tools including terminal | @@ -370,5 +373,6 @@ Each platform has its own toolset: - [DingTalk Setup](dingtalk.md) - [Feishu/Lark Setup](feishu.md) - [WeCom Setup](wecom.md) +- [BlueBubbles Setup (iMessage)](bluebubbles.md) - [Open WebUI + API Server](open-webui.md) - [Webhooks](webhooks.md) diff --git a/website/docs/user-guide/sessions.md b/website/docs/user-guide/sessions.md index a84e1064db0..358574030a7 100644 --- a/website/docs/user-guide/sessions.md +++ b/website/docs/user-guide/sessions.md @@ -44,6 +44,7 @@ Each session is tagged with its source platform: | `dingtalk` | DingTalk messenger | | `feishu` | Feishu/Lark messenger | | `wecom` | WeCom (WeChat Work) | +| `bluebubbles` | Apple iMessage via BlueBubbles macOS server | | `homeassistant` | Home Assistant conversation | | `webhook` | Incoming webhooks | | `api-server` | API server requests | diff --git a/website/sidebars.ts b/website/sidebars.ts index 5e1ebf2d6bd..39b60d88e9e 100644 --- a/website/sidebars.ts +++ b/website/sidebars.ts @@ -107,6 +107,7 @@ const sidebars: SidebarsConfig = { 'user-guide/messaging/dingtalk', 'user-guide/messaging/feishu', 'user-guide/messaging/wecom', + 'user-guide/messaging/bluebubbles', 'user-guide/messaging/open-webui', 'user-guide/messaging/webhooks', ], From 18140199c3a1cbb658a2eeadf692ffb8b5d1626f Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:29:45 -0700 Subject: [PATCH 27/49] fix(ci): build and push multi-arch Docker image (amd64 + arm64) (#6124) Add QEMU cross-compilation and multi-arch manifest support so Apple Silicon (M1/M2/M3) and other ARM-based systems get native images. - Add docker/setup-qemu-action for arm64 emulation on amd64 runners - Smoke test stays amd64-only (load:true can't export multi-arch) - Both push steps (main + release) now build linux/amd64,linux/arm64 - Bump timeout 30->60min for QEMU cross-compilation overhead - Add permissions: contents: read (least-privilege hardening) Salvaged from PR #3998 by Mibayy. Also addresses #5005 and #3913. Co-authored-by: Mibayy --- .github/workflows/docker-publish.yml | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index 6c1bb6eaa56..eec35fd62f2 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -8,6 +8,9 @@ on: release: types: [published] +permissions: + contents: read + concurrency: group: docker-${{ github.ref }} cancel-in-progress: true @@ -17,22 +20,29 @@ jobs: # Only run on the upstream repository, not on forks if: github.repository == 'NousResearch/hermes-agent' runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 60 steps: - name: Checkout code uses: actions/checkout@v4 with: submodules: recursive + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - - name: Build image + # Build amd64 only so we can `load` the image for smoke testing. + # `load: true` cannot export a multi-arch manifest to the local daemon. + # The multi-arch build follows on push to main / release. + - name: Build image (amd64, smoke test) uses: docker/build-push-action@v6 with: context: . file: Dockerfile load: true + platforms: linux/amd64 tags: nousresearch/hermes-agent:test cache-from: type=gha cache-to: type=gha,mode=max @@ -51,26 +61,28 @@ jobs: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Push image (main branch) + - name: Push multi-arch image (main branch) if: github.event_name == 'push' && github.ref == 'refs/heads/main' uses: docker/build-push-action@v6 with: context: . file: Dockerfile push: true + platforms: linux/amd64,linux/arm64 tags: | nousresearch/hermes-agent:latest nousresearch/hermes-agent:${{ github.sha }} cache-from: type=gha cache-to: type=gha,mode=max - - name: Push image (release) + - name: Push multi-arch image (release) if: github.event_name == 'release' uses: docker/build-push-action@v6 with: context: . file: Dockerfile push: true + platforms: linux/amd64,linux/arm64 tags: | nousresearch/hermes-agent:latest nousresearch/hermes-agent:${{ github.event.release.tag_name }} From 894e8c8a8f505c863e4a1c2365feb6607e22072e Mon Sep 17 00:00:00 2001 From: Hunter B Date: Thu, 2 Apr 2026 19:59:19 -0500 Subject: [PATCH 28/49] fix: resolve opencode.ai context window to 1M and clean up display formatting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two issues resolved: 1. Add opencode.ai to _URL_TO_PROVIDER mapping so base_url routes through models.dev lookup (which has mimo-v2-pro at 1M context) instead of falling back to probing /models (404) and defaulting to 128K. 2. Fix _format_context_length to round cleanly: 1048576 → '1M' instead of '1.048576M'. Applies same rounding logic to K values. --- agent/model_metadata.py | 1 + hermes_cli/banner.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/agent/model_metadata.py b/agent/model_metadata.py index 5b1d3376afa..9282586fead 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -197,6 +197,7 @@ _URL_TO_PROVIDER: Dict[str, str] = { "api.githubcopilot.com": "copilot", "models.github.ai": "copilot", "api.fireworks.ai": "fireworks", + "opencode.ai": "opencode-go", } diff --git a/hermes_cli/banner.py b/hermes_cli/banner.py index 03712c272de..b29805872d2 100644 --- a/hermes_cli/banner.py +++ b/hermes_cli/banner.py @@ -295,10 +295,16 @@ def _format_context_length(tokens: int) -> str: """Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M').""" if tokens >= 1_000_000: val = tokens / 1_000_000 - return f"{val:g}M" + rounded = round(val) + if abs(val - rounded) < 0.05: + return f"{rounded}M" + return f"{val:.1f}M" elif tokens >= 1_000: val = tokens / 1_000 - return f"{val:g}K" + rounded = round(val) + if abs(val - rounded) < 0.05: + return f"{rounded}K" + return f"{val:.1f}K" return str(tokens) From 5cf4fac2aae0fb73ebe8760cd099924e8b4b996d Mon Sep 17 00:00:00 2001 From: Cherif Yaya Date: Thu, 9 Apr 2026 00:04:30 -0700 Subject: [PATCH 29/49] fix: restore codex fallback auth-store lookup --- agent/auxiliary_client.py | 22 ++++++++++++++----- tests/agent/test_auxiliary_client.py | 32 ++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index f743a64eeb6..27c67c10a36 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -629,11 +629,19 @@ def _nous_base_url() -> str: def _read_codex_access_token() -> Optional[str]: - """Read a valid, non-expired Codex OAuth access token from Hermes auth store.""" + """Read a valid, non-expired Codex OAuth access token from Hermes auth store. + + If a credential pool exists but currently has no selectable runtime entry + (for example all pool slots are marked exhausted), fall back to the + profile's auth.json token instead of hard-failing. This keeps explicit + fallback-to-Codex working when the pool state is stale but the stored OAuth + token is still valid. + """ pool_present, entry = _select_pool_entry("openai-codex") if pool_present: token = _pool_runtime_api_key(entry) - return token or None + if token: + return token try: from hermes_cli.auth import _read_codex_tokens @@ -894,9 +902,13 @@ def _try_codex() -> Tuple[Optional[Any], Optional[str]]: pool_present, entry = _select_pool_entry("openai-codex") if pool_present: codex_token = _pool_runtime_api_key(entry) - if not codex_token: - return None, None - base_url = _pool_runtime_base_url(entry, _CODEX_AUX_BASE_URL) or _CODEX_AUX_BASE_URL + if codex_token: + base_url = _pool_runtime_base_url(entry, _CODEX_AUX_BASE_URL) or _CODEX_AUX_BASE_URL + else: + codex_token = _read_codex_access_token() + if not codex_token: + return None, None + base_url = _CODEX_AUX_BASE_URL else: codex_token = _read_codex_access_token() if not codex_token: diff --git a/tests/agent/test_auxiliary_client.py b/tests/agent/test_auxiliary_client.py index dd02ad23abd..3723378998c 100644 --- a/tests/agent/test_auxiliary_client.py +++ b/tests/agent/test_auxiliary_client.py @@ -77,6 +77,20 @@ class TestReadCodexAccessToken: result = _read_codex_access_token() assert result == "tok-123" + def test_pool_without_selected_entry_falls_back_to_auth_store(self, tmp_path, monkeypatch): + hermes_home = tmp_path / "hermes" + hermes_home.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + + valid_jwt = "eyJhbGciOiJSUzI1NiJ9.eyJleHAiOjk5OTk5OTk5OTl9.sig" + with patch("agent.auxiliary_client._select_pool_entry", return_value=(True, None)), \ + patch("hermes_cli.auth._read_codex_tokens", return_value={ + "tokens": {"access_token": valid_jwt, "refresh_token": "refresh"} + }): + result = _read_codex_access_token() + + assert result == valid_jwt + def test_missing_returns_none(self, tmp_path, monkeypatch): hermes_home = tmp_path / "hermes" hermes_home.mkdir(parents=True, exist_ok=True) @@ -238,6 +252,24 @@ class TestAnthropicOAuthFlag: assert mock_build.call_args.args[0] == "sk-ant-oat01-pooled" +class TestTryCodex: + def test_pool_without_selected_entry_falls_back_to_auth_store(self): + with ( + patch("agent.auxiliary_client._select_pool_entry", return_value=(True, None)), + patch("agent.auxiliary_client._read_codex_access_token", return_value="codex-auth-token"), + patch("agent.auxiliary_client.OpenAI") as mock_openai, + ): + mock_openai.return_value = MagicMock() + from agent.auxiliary_client import _try_codex + + client, model = _try_codex() + + assert client is not None + assert model == "gpt-5.2-codex" + assert mock_openai.call_args.kwargs["api_key"] == "codex-auth-token" + assert mock_openai.call_args.kwargs["base_url"] == "https://chatgpt.com/backend-api/codex" + + class TestExpiredCodexFallback: """Test that expired Codex tokens don't block the auto chain.""" From b962801f6ae9451725f56d7bcfd53ae8b491313b Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Thu, 9 Apr 2026 02:05:41 -0700 Subject: [PATCH 30/49] fix(bluebubbles): add setup wizard integration and OPTIONAL_ENV_VARS (#6494) The BlueBubbles adapter was merged but missing setup wizard support: - Add _setup_bluebubbles() guided setup (server URL, password, allowlist, home channel, webhook port) - Add to _GATEWAY_PLATFORMS registry so it appears in 'hermes setup gateway' - Add to any_messaging check and home channel missing warning - Add to gateway status display in 'hermes setup' - Add BLUEBUBBLES_SERVER_URL, BLUEBUBBLES_PASSWORD, BLUEBUBBLES_ALLOWED_USERS to OPTIONAL_ENV_VARS with descriptions and categories Previously the only way to configure BlueBubbles was manually editing .env. --- hermes_cli/config.py | 21 +++++++++++++ hermes_cli/setup.py | 71 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 4357119a2ad..cf988967956 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -1125,6 +1125,27 @@ OPTIONAL_ENV_VARS = { "category": "messaging", "advanced": True, }, + "BLUEBUBBLES_SERVER_URL": { + "description": "BlueBubbles server URL for iMessage integration (e.g. http://192.168.1.10:1234)", + "prompt": "BlueBubbles server URL", + "url": "https://bluebubbles.app/", + "password": False, + "category": "messaging", + }, + "BLUEBUBBLES_PASSWORD": { + "description": "BlueBubbles server password (from BlueBubbles Server → Settings → API)", + "prompt": "BlueBubbles server password", + "url": None, + "password": True, + "category": "messaging", + }, + "BLUEBUBBLES_ALLOWED_USERS": { + "description": "Comma-separated iMessage addresses (email or phone) allowed to use the bot", + "prompt": "Allowed iMessage addresses (comma-separated)", + "url": None, + "password": False, + "category": "messaging", + }, "GATEWAY_ALLOW_ALL_USERS": { "description": "Allow all users to interact with messaging bots (true/false). Default: false.", "prompt": "Allow all users (true/false)", diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index 43c3b086d97..95c9fa6228e 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -2167,6 +2167,71 @@ def _setup_whatsapp(): print_info("or personal self-chat) and pair via QR code.") +def _setup_bluebubbles(): + """Configure BlueBubbles iMessage gateway.""" + print_header("BlueBubbles (iMessage)") + existing = get_env_value("BLUEBUBBLES_SERVER_URL") + if existing: + print_info("BlueBubbles: already configured") + if not prompt_yes_no("Reconfigure BlueBubbles?", False): + return + + print_info("Connects Hermes to iMessage via BlueBubbles — a free, open-source") + print_info("macOS server that bridges iMessage to any device.") + print_info(" Requires a Mac running BlueBubbles Server v1.0.0+") + print_info(" Download: https://bluebubbles.app/") + print() + print_info("In BlueBubbles Server → Settings → API, note your Server URL and Password.") + print() + + server_url = prompt("BlueBubbles server URL (e.g. http://192.168.1.10:1234)") + if not server_url: + print_warning("Server URL is required — skipping BlueBubbles setup") + return + save_env_value("BLUEBUBBLES_SERVER_URL", server_url.rstrip("/")) + + password = prompt("BlueBubbles server password", password=True) + if not password: + print_warning("Password is required — skipping BlueBubbles setup") + return + save_env_value("BLUEBUBBLES_PASSWORD", password) + print_success("BlueBubbles credentials saved") + + print() + print_info("🔒 Security: Restrict who can message your bot") + print_info(" Use iMessage addresses: email (user@icloud.com) or phone (+15551234567)") + print() + allowed_users = prompt("Allowed iMessage addresses (comma-separated, leave empty for open access)") + if allowed_users: + save_env_value("BLUEBUBBLES_ALLOWED_USERS", allowed_users.replace(" ", "")) + print_success("BlueBubbles allowlist configured") + else: + print_info("⚠️ No allowlist set — anyone who can iMessage you can use the bot!") + + print() + print_info("📬 Home Channel: phone or email for cron job delivery and notifications.") + print_info(" You can also set this later with /set-home in your iMessage chat.") + home_channel = prompt("Home channel address (leave empty to set later)") + if home_channel: + save_env_value("BLUEBUBBLES_HOME_CHANNEL", home_channel) + + print() + print_info("Advanced settings (defaults are fine for most setups):") + if prompt_yes_no("Configure webhook listener settings?", False): + webhook_port = prompt("Webhook listener port (default: 8645)") + if webhook_port: + try: + save_env_value("BLUEBUBBLES_WEBHOOK_PORT", str(int(webhook_port))) + print_success(f"Webhook port set to {webhook_port}") + except ValueError: + print_warning("Invalid port number, using default 8645") + + print() + print_info("Requires the BlueBubbles Private API helper for typing indicators,") + print_info("read receipts, and tapback reactions. Basic messaging works without it.") + print_info(" Install: https://docs.bluebubbles.app/helper-bundle/installation") + + def _setup_webhooks(): """Configure webhook integration.""" print_header("Webhooks") @@ -2221,6 +2286,7 @@ _GATEWAY_PLATFORMS = [ ("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix), ("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost), ("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp), + ("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles), ("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks), ] @@ -2264,6 +2330,7 @@ def setup_gateway(config: dict): or get_env_value("MATRIX_ACCESS_TOKEN") or get_env_value("MATRIX_PASSWORD") or get_env_value("WHATSAPP_ENABLED") + or get_env_value("BLUEBUBBLES_SERVER_URL") or get_env_value("WEBHOOK_ENABLED") ) if any_messaging: @@ -2283,6 +2350,8 @@ def setup_gateway(config: dict): missing_home.append("Discord") if get_env_value("SLACK_BOT_TOKEN") and not get_env_value("SLACK_HOME_CHANNEL"): missing_home.append("Slack") + if get_env_value("BLUEBUBBLES_SERVER_URL") and not get_env_value("BLUEBUBBLES_HOME_CHANNEL"): + missing_home.append("BlueBubbles") if missing_home: print() @@ -2453,6 +2522,8 @@ def _get_section_config_summary(config: dict, section_key: str) -> Optional[str] platforms.append("WhatsApp") if get_env_value("SIGNAL_ACCOUNT"): platforms.append("Signal") + if get_env_value("BLUEBUBBLES_SERVER_URL"): + platforms.append("BlueBubbles") if platforms: return ", ".join(platforms) return None # No platforms configured — section must run From 1eabbe905e86bfadcdfbc417044decb9bc4f93c8 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Thu, 9 Apr 2026 02:06:12 -0700 Subject: [PATCH 31/49] fix: retry 3 times when model returns truly empty response (#6488) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a model returns no content, no structured reasoning, and no tool calls (common with open models), the agent now silently retries up to 3 times before falling through to (empty). Silent retry (no synthetic messages) keeps the conversation history clean, preserves prompt caching, and respects the no-synthetic-user- injection invariant. Most empty responses from open models are transient (provider hiccups, rate limits, sampling flukes) so a simple retry is sufficient. This fills the last gap in the empty-response recovery chain: 1. _last_content_with_tools fallback (prior tool turn had content) 2. Thinking-only prefill continuation (#5931 — structured reasoning) 3. Empty response silent retry (NEW — truly empty, no reasoning) 4. (empty) terminal (last resort after all retries exhausted) Inline blocks are excluded — the model chose to reason, it just produced no visible text. That differs from truly empty. Tests: - Updated test_truly_empty to expect 4 API calls (1 + 3 retries) - Added test_truly_empty_response_succeeds_on_nudge --- run_agent.py | 25 +++++++++++++++++++++--- tests/run_agent/test_run_agent.py | 32 +++++++++++++++++++++++++++---- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/run_agent.py b/run_agent.py index 793ddd6750f..3c5661a112d 100644 --- a/run_agent.py +++ b/run_agent.py @@ -9109,8 +9109,27 @@ class AIAgent: self._save_session_log(messages) continue - # Exhausted prefill attempts or no structured - # reasoning — fall through to "(empty)" terminal. + # ── Empty response retry (no reasoning) ────── + # Model returned nothing — no content, no + # structured reasoning, no tool calls. Common + # with open models (transient provider issues, + # rate limits, sampling flukes). Silently retry + # up to 3 times before giving up. Skip when + # content has inline tags (model chose + # to reason, just no visible text). + _truly_empty = not final_response.strip() + if _truly_empty and not _has_structured and self._empty_content_retries < 3: + self._empty_content_retries += 1 + self._vprint( + f"{self.log_prefix}↻ Empty response (no content or reasoning) " + f"— retrying ({self._empty_content_retries}/3)", + force=True, + ) + continue + + # Exhausted prefill attempts, empty retries, or + # structured reasoning with no content — + # fall through to "(empty)" terminal. reasoning_text = self._extract_reasoning(assistant_message) assistant_msg = self._build_assistant_message(assistant_message, finish_reason) assistant_msg["content"] = "(empty)" @@ -9120,7 +9139,7 @@ class AIAgent: reasoning_preview = reasoning_text[:500] + "..." if len(reasoning_text) > 500 else reasoning_text self._vprint(f"{self.log_prefix}ℹ️ Reasoning-only response (no visible content). Reasoning: {reasoning_preview}") else: - self._vprint(f"{self.log_prefix}ℹ️ Empty response (no content or reasoning).") + self._vprint(f"{self.log_prefix}ℹ️ Empty response (no content or reasoning) after 3 retries.") final_response = "(empty)" break diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py index 59f88601c5e..98d799ae433 100644 --- a/tests/run_agent/test_run_agent.py +++ b/tests/run_agent/test_run_agent.py @@ -1668,12 +1668,15 @@ class TestRunConversation: if roles[i] == "assistant" and roles[i + 1] == "assistant": raise AssertionError("Consecutive assistant messages found in history") - def test_truly_empty_response_accepted_without_retry(self, agent): - """Truly empty response (no content, no reasoning) should still complete with (empty).""" + def test_truly_empty_response_retries_3_times_then_empty(self, agent): + """Truly empty response (no content, no reasoning) retries 3 times then falls through to (empty).""" self._setup_agent(agent) agent.base_url = "http://127.0.0.1:1234/v1" empty_resp = _mock_response(content=None, finish_reason="stop") - agent.client.chat.completions.create.side_effect = [empty_resp] + # 4 responses: 1 original + 3 nudge retries, all empty + agent.client.chat.completions.create.side_effect = [ + empty_resp, empty_resp, empty_resp, empty_resp, + ] with ( patch.object(agent, "_persist_session"), patch.object(agent, "_save_trajectory"), @@ -1682,7 +1685,28 @@ class TestRunConversation: result = agent.run_conversation("answer me") assert result["completed"] is True assert result["final_response"] == "(empty)" - assert result["api_calls"] == 1 # no retries + assert result["api_calls"] == 4 # 1 original + 3 retries + + def test_truly_empty_response_succeeds_on_nudge(self, agent): + """Model produces content after being nudged for empty response.""" + self._setup_agent(agent) + agent.base_url = "http://127.0.0.1:1234/v1" + empty_resp = _mock_response(content=None, finish_reason="stop") + content_resp = _mock_response( + content="Here is the actual answer.", + finish_reason="stop", + ) + # 1 empty response, then model produces content on nudge + agent.client.chat.completions.create.side_effect = [empty_resp, content_resp] + with ( + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + result = agent.run_conversation("answer me") + assert result["completed"] is True + assert result["final_response"] == "Here is the actual answer." + assert result["api_calls"] == 2 # 1 original + 1 nudge retry def test_nous_401_refreshes_after_remint_and_retries(self, agent): self._setup_agent(agent) From e1b0b135cbb71142e68e9a8b3c27b2b5188634ec Mon Sep 17 00:00:00 2001 From: Kira Date: Thu, 9 Apr 2026 03:15:09 -0400 Subject: [PATCH 32/49] fix(discord): accept .log attachments and raise document size limit --- gateway/platforms/base.py | 1 + gateway/platforms/discord.py | 6 +-- .../gateway/test_discord_document_handling.py | 39 ++++++++++++++++++- 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index c72fa513bb3..bd07459ac8d 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -298,6 +298,7 @@ SUPPORTED_DOCUMENT_TYPES = { ".pdf": "application/pdf", ".md": "text/markdown", ".txt": "text/plain", + ".log": "text/plain", ".zip": "application/zip", ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index 36984202e58..2ace06e7798 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -2382,7 +2382,7 @@ class DiscordAdapter(BasePlatformAdapter): ext or "unknown", content_type, ) else: - MAX_DOC_BYTES = 20 * 1024 * 1024 + MAX_DOC_BYTES = 32 * 1024 * 1024 if att.size and att.size > MAX_DOC_BYTES: logger.warning( "[Discord] Document too large (%s bytes), skipping: %s", @@ -2406,9 +2406,9 @@ class DiscordAdapter(BasePlatformAdapter): media_urls.append(cached_path) media_types.append(doc_mime) logger.info("[Discord] Cached user document: %s", cached_path) - # Inject text content for .txt/.md files (capped at 100 KB) + # Inject text content for plain-text documents (capped at 100 KB) MAX_TEXT_INJECT_BYTES = 100 * 1024 - if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES: + if ext in (".md", ".txt", ".log") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES: try: text_content = raw_bytes.decode("utf-8") display_name = att.filename or f"document{ext}" diff --git a/tests/gateway/test_discord_document_handling.py b/tests/gateway/test_discord_document_handling.py index 7f918d1c738..a22e0f0d669 100644 --- a/tests/gateway/test_discord_document_handling.py +++ b/tests/gateway/test_discord_document_handling.py @@ -209,14 +209,31 @@ class TestIncomingDocumentHandling: assert "[Content of readme.md]:" in event.text assert "# Title" in event.text + @pytest.mark.asyncio + async def test_log_content_injected(self, adapter): + """.log file under 100KB should be treated as text/plain and injected.""" + file_content = b"BLE trace line 1\nBLE trace line 2" + + with _mock_aiohttp_download(file_content): + msg = make_message( + attachments=[make_attachment(filename="btsnoop_hci.log", content_type="text/plain")], + content="please inspect this", + ) + await adapter._handle_message(msg) + + event = adapter.handle_message.call_args[0][0] + assert "[Content of btsnoop_hci.log]:" in event.text + assert "BLE trace line 1" in event.text + assert "please inspect this" in event.text + @pytest.mark.asyncio async def test_oversized_document_skipped(self, adapter): - """A document over 20MB should be skipped — media_urls stays empty.""" + """A document over 32MB should be skipped — media_urls stays empty.""" msg = make_message([ make_attachment( filename="huge.pdf", content_type="application/pdf", - size=25 * 1024 * 1024, + size=33 * 1024 * 1024, ) ]) await adapter._handle_message(msg) @@ -226,6 +243,24 @@ class TestIncomingDocumentHandling: # handler must still be called adapter.handle_message.assert_called_once() + @pytest.mark.asyncio + async def test_mid_sized_zip_under_32mb_is_cached(self, adapter): + """A 25MB .zip should be accepted now that Discord documents allow up to 32MB.""" + msg = make_message([ + make_attachment( + filename="bugreport.zip", + content_type="application/zip", + size=25 * 1024 * 1024, + ) + ]) + + with _mock_aiohttp_download(b"PK\x03\x04test"): + await adapter._handle_message(msg) + + event = adapter.handle_message.call_args[0][0] + assert len(event.media_urls) == 1 + assert event.media_types == ["application/zip"] + @pytest.mark.asyncio async def test_zip_document_cached(self, adapter): """A .zip file should be cached as a supported document.""" From b408379e9d44bae4fb366d19183840dd52c39a16 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Thu, 9 Apr 2026 02:37:23 -0700 Subject: [PATCH 33/49] fix: reduce credential exhaustion TTL from 24 hours to 1 hour (#6504) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 24-hour default cooldown for 402-exhausted credentials was far too aggressive — if a user tops up credits or the 402 was caused by an oversized max_tokens request rather than true billing exhaustion, they shouldn't have to wait a full day. Reduce to 1 hour (matching the existing 429 TTL). Inspired by PR #6493 (michalkomar). --- agent/credential_pool.py | 6 ++--- tests/agent/test_credential_pool.py | 36 +++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/agent/credential_pool.py b/agent/credential_pool.py index a47901c8470..dd2c9abc5ee 100644 --- a/agent/credential_pool.py +++ b/agent/credential_pool.py @@ -64,10 +64,10 @@ SUPPORTED_POOL_STRATEGIES = { } # Cooldown before retrying an exhausted credential. -# 429 (rate-limited) cools down faster since quotas reset frequently. -# 402 (billing/quota) and other codes use a longer default. +# 429 (rate-limited) and 402 (billing/quota) both cool down after 1 hour. +# Provider-supplied reset_at timestamps override these defaults. EXHAUSTED_TTL_429_SECONDS = 60 * 60 # 1 hour -EXHAUSTED_TTL_DEFAULT_SECONDS = 24 * 60 * 60 # 24 hours +EXHAUSTED_TTL_DEFAULT_SECONDS = 60 * 60 # 1 hour # Pool key prefix for custom OpenAI-compatible endpoints. # Custom endpoints all share provider='custom' but are keyed by their diff --git a/tests/agent/test_credential_pool.py b/tests/agent/test_credential_pool.py index 891ab68a825..c3bde951565 100644 --- a/tests/agent/test_credential_pool.py +++ b/tests/agent/test_credential_pool.py @@ -214,6 +214,42 @@ def test_exhausted_entry_resets_after_ttl(tmp_path, monkeypatch): assert entry.last_status == "ok" +def test_exhausted_402_entry_resets_after_one_hour(tmp_path, monkeypatch): + """402-exhausted credentials recover after 1 hour, not 24.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes")) + _write_auth_store( + tmp_path, + { + "version": 1, + "credential_pool": { + "openrouter": [ + { + "id": "cred-1", + "label": "primary", + "auth_type": "api_key", + "priority": 0, + "source": "manual", + "access_token": "***", + "base_url": "https://openrouter.ai/api/v1", + "last_status": "exhausted", + "last_status_at": time.time() - 3700, # ~1h2m ago + "last_error_code": 402, + } + ] + }, + }, + ) + + from agent.credential_pool import load_pool + + pool = load_pool("openrouter") + entry = pool.select() + + assert entry is not None + assert entry.id == "cred-1" + assert entry.last_status == "ok" + + def test_explicit_reset_timestamp_overrides_default_429_ttl(tmp_path, monkeypatch): monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes")) _write_auth_store( From 851857e413a6acf591cf4f00ad58abc48fe6316a Mon Sep 17 00:00:00 2001 From: cokemine Date: Thu, 9 Apr 2026 16:15:37 +0900 Subject: [PATCH 34/49] fix(models): correct probed_url selection logic Updated the logic for determining the probed_url in the probe_api_models function to use the first tried URL instead of the last. This change ensures that the most relevant URL is returned when probing for models. Additionally, improved the output message in the _model_flow_custom function to provide clearer guidance based on the suggested_base_url. --- hermes_cli/main.py | 6 +++++- hermes_cli/models.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 5b180fc29d1..96345c48504 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -1474,7 +1474,11 @@ def _model_flow_custom(config): f"Hermes will still save it." ) if probe.get("suggested_base_url"): - print(f" If this server expects /v1, try base URL: {probe['suggested_base_url']}") + suggested = probe["suggested_base_url"] + if suggested.endswith("/v1"): + print(f" If this server expects /v1 in the path, try base URL: {suggested}") + else: + print(f" If /v1 should not be in the base URL, try: {suggested}") # Select model — use probe results when available, fall back to manual input model_name = "" diff --git a/hermes_cli/models.py b/hermes_cli/models.py index ce89bdeac03..b55249a70cb 100644 --- a/hermes_cli/models.py +++ b/hermes_cli/models.py @@ -1532,7 +1532,7 @@ def probe_api_models( return { "models": None, - "probed_url": tried[-1] if tried else normalized.rstrip("/") + "/models", + "probed_url": tried[0] if tried else normalized.rstrip("/") + "/models", "resolved_base_url": normalized, "suggested_base_url": alternate_base if alternate_base != normalized else None, "used_fallback": False, From a94099908aeb0d9bb948d72f738bcd7c2b83d57b Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Thu, 9 Apr 2026 02:41:56 -0700 Subject: [PATCH 35/49] fix(state): orphan children instead of cascade-deleting in prune/delete (#6513) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit prune_sessions and delete_session only handled direct children when satisfying the parent_session_id FK constraint. Multi-level chains (A -> B -> C) caused IntegrityError because deleting B while C still referenced it was blocked by the FK. Fix: NULL out parent_session_id for any session whose parent is about to be deleted. This orphans children instead of cascade-deleting them, which also respects the prune retention window — newer child sessions are no longer deleted just because an ancestor is old. Reported by Aaryan2304 in PR #6463. --- hermes_state.py | 46 ++++++++++------------ tests/test_hermes_state.py | 78 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 25 deletions(-) diff --git a/hermes_state.py b/hermes_state.py index da632a9e118..a845dbb9f90 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -1235,10 +1235,10 @@ class SessionDB: self._execute_write(_do) def delete_session(self, session_id: str) -> bool: - """Delete a session, its child sessions, and all their messages. + """Delete a session and all its messages. - Child sessions (subagent runs, compression continuations) are deleted - first to satisfy the ``parent_session_id`` foreign key constraint. + Child sessions are orphaned (parent_session_id set to NULL) rather + than cascade-deleted, so they remain accessible independently. Returns True if the session was found and deleted. """ def _do(conn): @@ -1247,15 +1247,12 @@ class SessionDB: ) if cursor.fetchone()[0] == 0: return False - # Delete child sessions first (FK constraint) - child_ids = [r[0] for r in conn.execute( - "SELECT id FROM sessions WHERE parent_session_id = ?", + # Orphan child sessions so FK constraint is satisfied + conn.execute( + "UPDATE sessions SET parent_session_id = NULL " + "WHERE parent_session_id = ?", (session_id,), - ).fetchall()] - for cid in child_ids: - conn.execute("DELETE FROM messages WHERE session_id = ?", (cid,)) - conn.execute("DELETE FROM sessions WHERE id = ?", (cid,)) - # Delete the session itself + ) conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,)) conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,)) return True @@ -1264,9 +1261,9 @@ class SessionDB: def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int: """Delete sessions older than N days. Returns count of deleted sessions. - Only prunes ended sessions (not active ones). Child sessions whose - parents are being pruned are deleted first to satisfy the - ``parent_session_id`` foreign key constraint. + Only prunes ended sessions (not active ones). Child sessions outside + the prune window are orphaned (parent_session_id set to NULL) rather + than cascade-deleted. """ cutoff = time.time() - (older_than_days * 86400) @@ -1284,17 +1281,16 @@ class SessionDB: ) session_ids = set(row["id"] for row in cursor.fetchall()) - # Delete children first whose parents are in the prune set - # (avoids FK constraint errors) - for sid in list(session_ids): - child_ids = [r[0] for r in conn.execute( - "SELECT id FROM sessions WHERE parent_session_id = ?", - (sid,), - ).fetchall()] - for cid in child_ids: - conn.execute("DELETE FROM messages WHERE session_id = ?", (cid,)) - conn.execute("DELETE FROM sessions WHERE id = ?", (cid,)) - session_ids.discard(cid) # don't double-delete + if not session_ids: + return 0 + + # Orphan any sessions whose parent is about to be deleted + placeholders = ",".join("?" * len(session_ids)) + conn.execute( + f"UPDATE sessions SET parent_session_id = NULL " + f"WHERE parent_session_id IN ({placeholders})", + list(session_ids), + ) for sid in session_ids: conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,)) diff --git a/tests/test_hermes_state.py b/tests/test_hermes_state.py index a0630858c86..5f9a16a529c 100644 --- a/tests/test_hermes_state.py +++ b/tests/test_hermes_state.py @@ -663,6 +663,84 @@ class TestPruneSessions: assert db.get_session("old_cli") is None assert db.get_session("old_tg") is not None + def test_prune_with_multilevel_chain(self, db): + """Pruning old sessions orphans newer children instead of crashing on FK.""" + old_ts = time.time() - 200 * 86400 + recent_ts = time.time() - 10 * 86400 + + # Chain: A (old) -> B (old) -> C (recent) -> D (recent) + db.create_session(session_id="A", source="cli") + db.end_session("A", end_reason="compressed") + db.create_session(session_id="B", source="cli", parent_session_id="A") + db.end_session("B", end_reason="compressed") + db.create_session(session_id="C", source="cli", parent_session_id="B") + db.end_session("C", end_reason="compressed") + db.create_session(session_id="D", source="cli", parent_session_id="C") + db.end_session("D", end_reason="done") + + # Backdate A and B to be old; C and D stay recent + for sid, ts in [("A", old_ts), ("B", old_ts), ("C", recent_ts), ("D", recent_ts)]: + db._conn.execute( + "UPDATE sessions SET started_at = ? WHERE id = ?", (ts, sid) + ) + db._conn.commit() + + # Should not raise IntegrityError + pruned = db.prune_sessions(older_than_days=90) + assert pruned == 2 # only A and B + assert db.get_session("A") is None + assert db.get_session("B") is None + # C and D survive, C is orphaned (parent_session_id NULL) + c = db.get_session("C") + assert c is not None + assert c["parent_session_id"] is None + d = db.get_session("D") + assert d is not None + assert d["parent_session_id"] == "C" + + def test_prune_entire_old_chain(self, db): + """All sessions in a chain are old — entire chain is pruned.""" + old_ts = time.time() - 200 * 86400 + + db.create_session(session_id="X", source="cli") + db.end_session("X", end_reason="compressed") + db.create_session(session_id="Y", source="cli", parent_session_id="X") + db.end_session("Y", end_reason="compressed") + db.create_session(session_id="Z", source="cli", parent_session_id="Y") + db.end_session("Z", end_reason="done") + + for sid in ("X", "Y", "Z"): + db._conn.execute( + "UPDATE sessions SET started_at = ? WHERE id = ?", (old_ts, sid) + ) + db._conn.commit() + + pruned = db.prune_sessions(older_than_days=90) + assert pruned == 3 + for sid in ("X", "Y", "Z"): + assert db.get_session(sid) is None + + +class TestDeleteSessionOrphansChildren: + def test_delete_orphans_children(self, db): + """Deleting a parent session orphans its children.""" + db.create_session(session_id="parent", source="cli") + db.create_session(session_id="child", source="cli", parent_session_id="parent") + db.create_session(session_id="grandchild", source="cli", parent_session_id="child") + + # Should not raise IntegrityError + result = db.delete_session("parent") + assert result is True + assert db.get_session("parent") is None + # Child is orphaned, not deleted + child = db.get_session("child") + assert child is not None + assert child["parent_session_id"] is None + # Grandchild is untouched + grandchild = db.get_session("grandchild") + assert grandchild is not None + assert grandchild["parent_session_id"] == "child" + # ========================================================================= # Schema and WAL mode From e22416dd9b47cf69cf339ec00a6a515b18d8ce5f Mon Sep 17 00:00:00 2001 From: Lumen Radley Date: Tue, 7 Apr 2026 23:44:12 +0200 Subject: [PATCH 36/49] fix: handle empty sudo password and false prompts --- cli-config.yaml.example | 10 ++- cli.py | 34 ++++++- hermes_cli/config.py | 2 +- tests/cli/test_cli_approval_ui.py | 47 +++++++++- tests/tools/test_terminal_tool.py | 90 +++++++++++++++++++ tools/terminal_tool.py | 145 ++++++++++++++++++++++++------ 6 files changed, 293 insertions(+), 35 deletions(-) create mode 100644 tests/tools/test_terminal_tool.py diff --git a/cli-config.yaml.example b/cli-config.yaml.example index af0917dedc6..d75284443f5 100644 --- a/cli-config.yaml.example +++ b/cli-config.yaml.example @@ -117,7 +117,8 @@ terminal: timeout: 180 docker_mount_cwd_to_workspace: false # SECURITY: off by default. Opt in to mount the launch cwd into Docker /workspace. lifetime_seconds: 300 - # sudo_password: "" # Enable sudo commands (pipes via sudo -S) - SECURITY WARNING: plaintext! + # sudo_password: "hunter2" # Optional: pipe a sudo password via sudo -S. SECURITY WARNING: plaintext. + # sudo_password: "" # Explicit empty password: try empty and never open the interactive sudo prompt. # ----------------------------------------------------------------------------- # OPTION 2: SSH remote execution @@ -208,13 +209,18 @@ terminal: # # SECURITY WARNING: Password stored in plaintext! # -# INTERACTIVE PROMPT: If no sudo_password is set and the CLI is running, +# INTERACTIVE PROMPT: If sudo_password is unset and the CLI is running, # you'll be prompted to enter your password when sudo is needed: # - 45-second timeout (auto-skips if no input) # - Press Enter to skip (command fails gracefully) # - Password is hidden while typing # - Password is cached for the session # +# EMPTY PASSWORDS: Setting sudo_password to an explicit empty string is different +# from leaving it unset. Hermes will try an empty password via `sudo -S` and +# will not open the interactive prompt. This is useful for passwordless sudo, +# Touch ID sudo setups, and environments where prompting is just noise. +# # ALTERNATIVES: # - SSH backend: Configure passwordless sudo on the remote server # - Containers: Run as root inside the container (no sudo needed) diff --git a/cli.py b/cli.py index f0edf67ee29..324bb056901 100644 --- a/cli.py +++ b/cli.py @@ -1546,6 +1546,7 @@ class HermesCLI: self._clarify_deadline = 0 self._sudo_state = None self._sudo_deadline = 0 + self._modal_input_snapshot = None self._approval_state = None self._approval_deadline = 0 self._approval_lock = threading.Lock() @@ -6205,6 +6206,7 @@ class HermesCLI: timeout = 45 response_queue = queue.Queue() + self._capture_modal_input_snapshot() self._sudo_state = { "response_queue": response_queue, } @@ -6217,6 +6219,7 @@ class HermesCLI: result = response_queue.get(timeout=1) self._sudo_state = None self._sudo_deadline = 0 + self._restore_modal_input_snapshot() self._invalidate() if result: _cprint(f"\n{_DIM} ✓ Password received (cached for session){_RST}") @@ -6231,6 +6234,7 @@ class HermesCLI: self._sudo_state = None self._sudo_deadline = 0 + self._restore_modal_input_snapshot() self._invalidate() _cprint(f"\n{_DIM} ⏱ Timeout — continuing without sudo{_RST}") return "" @@ -6403,6 +6407,33 @@ class HermesCLI: def _secret_capture_callback(self, var_name: str, prompt: str, metadata=None) -> dict: return prompt_for_secret(self, var_name, prompt, metadata) + def _capture_modal_input_snapshot(self) -> None: + """Temporarily clear the input buffer and save the user's in-progress draft.""" + if self._modal_input_snapshot is not None or not getattr(self, "_app", None): + return + try: + buf = self._app.current_buffer + self._modal_input_snapshot = { + "text": buf.text, + "cursor_position": buf.cursor_position, + } + buf.reset() + except Exception: + self._modal_input_snapshot = None + + def _restore_modal_input_snapshot(self) -> None: + """Restore any draft text that was present before a modal prompt opened.""" + snapshot = self._modal_input_snapshot + self._modal_input_snapshot = None + if not snapshot or not getattr(self, "_app", None): + return + try: + buf = self._app.current_buffer + buf.text = snapshot.get("text", "") + buf.cursor_position = min(snapshot.get("cursor_position", 0), len(buf.text)) + except Exception: + pass + def _submit_secret_response(self, value: str) -> None: if not self._secret_state: return @@ -7130,6 +7161,7 @@ class HermesCLI: # Sudo password prompt state (similar mechanism to clarify) self._sudo_state = None # dict with response_queue when active self._sudo_deadline = 0 + self._modal_input_snapshot = None # Dangerous command approval state (similar mechanism to clarify) self._approval_state = None # dict with command, description, choices, selected, response_queue @@ -7201,7 +7233,6 @@ class HermesCLI: text = event.app.current_buffer.text self._sudo_state["response_queue"].put(text) self._sudo_state = None - event.app.current_buffer.reset() event.app.invalidate() return @@ -7406,7 +7437,6 @@ class HermesCLI: if self._sudo_state: self._sudo_state["response_queue"].put("") self._sudo_state = None - event.app.current_buffer.reset() event.app.invalidate() return diff --git a/hermes_cli/config.py b/hermes_cli/config.py index cf988967956..387bef667d5 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -1217,7 +1217,7 @@ OPTIONAL_ENV_VARS = { "category": "setting", }, "SUDO_PASSWORD": { - "description": "Sudo password for terminal commands requiring root access", + "description": "Sudo password for terminal commands requiring root access; set to an explicit empty string to try empty without prompting", "prompt": "Sudo password", "url": None, "password": True, diff --git a/tests/cli/test_cli_approval_ui.py b/tests/cli/test_cli_approval_ui.py index 9b2e0bbb266..63e03b9ab90 100644 --- a/tests/cli/test_cli_approval_ui.py +++ b/tests/cli/test_cli_approval_ui.py @@ -2,22 +2,65 @@ import queue import threading import time from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch +import cli as cli_module from cli import HermesCLI +class _FakeBuffer: + def __init__(self, text="", cursor_position=None): + self.text = text + self.cursor_position = len(text) if cursor_position is None else cursor_position + + def reset(self, append_to_history=False): + self.text = "" + self.cursor_position = 0 + + def _make_cli_stub(): cli = HermesCLI.__new__(HermesCLI) cli._approval_state = None cli._approval_deadline = 0 cli._approval_lock = threading.Lock() + cli._sudo_state = None + cli._sudo_deadline = 0 + cli._modal_input_snapshot = None cli._invalidate = MagicMock() - cli._app = SimpleNamespace(invalidate=MagicMock()) + cli._app = SimpleNamespace(invalidate=MagicMock(), current_buffer=_FakeBuffer()) return cli class TestCliApprovalUi: + def test_sudo_prompt_restores_existing_draft_after_response(self): + cli = _make_cli_stub() + cli._app.current_buffer = _FakeBuffer("draft command", cursor_position=5) + result = {} + + def _run_callback(): + result["value"] = cli._sudo_password_callback() + + with patch.object(cli_module, "_cprint"): + thread = threading.Thread(target=_run_callback, daemon=True) + thread.start() + + deadline = time.time() + 2 + while cli._sudo_state is None and time.time() < deadline: + time.sleep(0.01) + + assert cli._sudo_state is not None + assert cli._app.current_buffer.text == "" + + cli._app.current_buffer.text = "secret" + cli._app.current_buffer.cursor_position = len("secret") + cli._sudo_state["response_queue"].put("secret") + + thread.join(timeout=2) + + assert result["value"] == "secret" + assert cli._app.current_buffer.text == "draft command" + assert cli._app.current_buffer.cursor_position == 5 + def test_approval_callback_includes_view_for_long_commands(self): cli = _make_cli_stub() command = "sudo dd if=/tmp/githubcli-keyring.gpg of=/usr/share/keyrings/githubcli-archive-keyring.gpg bs=4M status=progress" diff --git a/tests/tools/test_terminal_tool.py b/tests/tools/test_terminal_tool.py new file mode 100644 index 00000000000..42ed693a2ea --- /dev/null +++ b/tests/tools/test_terminal_tool.py @@ -0,0 +1,90 @@ +"""Regression tests for sudo detection and sudo password handling.""" + +import tools.terminal_tool as terminal_tool + + +def setup_function(): + terminal_tool._cached_sudo_password = "" + + +def teardown_function(): + terminal_tool._cached_sudo_password = "" + + +def test_searching_for_sudo_does_not_trigger_rewrite(monkeypatch): + monkeypatch.delenv("SUDO_PASSWORD", raising=False) + monkeypatch.delenv("HERMES_INTERACTIVE", raising=False) + + command = "rg --line-number --no-heading --with-filename 'sudo' . | head -n 20" + transformed, sudo_stdin = terminal_tool._transform_sudo_command(command) + + assert transformed == command + assert sudo_stdin is None + + +def test_printf_literal_sudo_does_not_trigger_rewrite(monkeypatch): + monkeypatch.delenv("SUDO_PASSWORD", raising=False) + monkeypatch.delenv("HERMES_INTERACTIVE", raising=False) + + command = "printf '%s\\n' sudo" + transformed, sudo_stdin = terminal_tool._transform_sudo_command(command) + + assert transformed == command + assert sudo_stdin is None + + +def test_non_command_argument_named_sudo_does_not_trigger_rewrite(monkeypatch): + monkeypatch.delenv("SUDO_PASSWORD", raising=False) + monkeypatch.delenv("HERMES_INTERACTIVE", raising=False) + + command = "grep -n sudo README.md" + transformed, sudo_stdin = terminal_tool._transform_sudo_command(command) + + assert transformed == command + assert sudo_stdin is None + + +def test_actual_sudo_command_uses_configured_password(monkeypatch): + monkeypatch.setenv("SUDO_PASSWORD", "testpass") + monkeypatch.delenv("HERMES_INTERACTIVE", raising=False) + + transformed, sudo_stdin = terminal_tool._transform_sudo_command("sudo apt install -y ripgrep") + + assert transformed == "sudo -S -p '' apt install -y ripgrep" + assert sudo_stdin == "testpass\n" + + +def test_actual_sudo_after_leading_env_assignment_is_rewritten(monkeypatch): + monkeypatch.setenv("SUDO_PASSWORD", "testpass") + monkeypatch.delenv("HERMES_INTERACTIVE", raising=False) + + transformed, sudo_stdin = terminal_tool._transform_sudo_command("DEBUG=1 sudo whoami") + + assert transformed == "DEBUG=1 sudo -S -p '' whoami" + assert sudo_stdin == "testpass\n" + + +def test_explicit_empty_sudo_password_tries_empty_without_prompt(monkeypatch): + monkeypatch.setenv("SUDO_PASSWORD", "") + monkeypatch.setenv("HERMES_INTERACTIVE", "1") + + def _fail_prompt(*_args, **_kwargs): + raise AssertionError("interactive sudo prompt should not run for explicit empty password") + + monkeypatch.setattr(terminal_tool, "_prompt_for_sudo_password", _fail_prompt) + + transformed, sudo_stdin = terminal_tool._transform_sudo_command("sudo true") + + assert transformed == "sudo -S -p '' true" + assert sudo_stdin == "\n" + + +def test_cached_sudo_password_is_used_when_env_is_unset(monkeypatch): + monkeypatch.delenv("SUDO_PASSWORD", raising=False) + monkeypatch.delenv("HERMES_INTERACTIVE", raising=False) + terminal_tool._cached_sudo_password = "cached-pass" + + transformed, sudo_stdin = terminal_tool._transform_sudo_command("echo ok && sudo whoami") + + assert transformed == "echo ok && sudo -S -p '' whoami" + assert sudo_stdin == "cached-pass\n" diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index 96a1147759c..0dc0fd58729 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -326,7 +326,6 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str: if "HERMES_SPINNER_PAUSE" in os.environ: del os.environ["HERMES_SPINNER_PAUSE"] - def _safe_command_preview(command: Any, limit: int = 200) -> str: """Return a log-safe preview for possibly-invalid command values.""" if command is None: @@ -338,6 +337,110 @@ def _safe_command_preview(command: Any, limit: int = 200) -> str: except Exception: return f"<{type(command).__name__}>" +def _looks_like_env_assignment(token: str) -> bool: + """Return True when *token* is a leading shell environment assignment.""" + if "=" not in token or token.startswith("="): + return False + name, _value = token.split("=", 1) + return bool(re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", name)) + + +def _read_shell_token(command: str, start: int) -> tuple[str, int]: + """Read one shell token, preserving quotes/escapes, starting at *start*.""" + i = start + n = len(command) + + while i < n: + ch = command[i] + if ch.isspace() or ch in ";|&()": + break + if ch == "'": + i += 1 + while i < n and command[i] != "'": + i += 1 + if i < n: + i += 1 + continue + if ch == '"': + i += 1 + while i < n: + inner = command[i] + if inner == "\\" and i + 1 < n: + i += 2 + continue + if inner == '"': + i += 1 + break + i += 1 + continue + if ch == "\\" and i + 1 < n: + i += 2 + continue + i += 1 + + return command[start:i], i + + +def _rewrite_real_sudo_invocations(command: str) -> tuple[str, bool]: + """Rewrite only real unquoted sudo command words, not plain text mentions.""" + out: list[str] = [] + i = 0 + n = len(command) + command_start = True + found = False + + while i < n: + ch = command[i] + + if ch.isspace(): + out.append(ch) + if ch == "\n": + command_start = True + i += 1 + continue + + if ch == "#" and command_start: + comment_end = command.find("\n", i) + if comment_end == -1: + out.append(command[i:]) + break + out.append(command[i:comment_end]) + i = comment_end + continue + + if command.startswith("&&", i) or command.startswith("||", i) or command.startswith(";;", i): + out.append(command[i:i + 2]) + i += 2 + command_start = True + continue + + if ch in ";|&(": + out.append(ch) + i += 1 + command_start = True + continue + + if ch == ")": + out.append(ch) + i += 1 + command_start = False + continue + + token, next_i = _read_shell_token(command, i) + if command_start and token == "sudo": + out.append("sudo -S -p ''") + found = True + else: + out.append(token) + + if command_start and _looks_like_env_assignment(token): + command_start = True + else: + command_start = False + i = next_i + + return "".join(out), found + def _transform_sudo_command(command: str | None) -> tuple[str | None, str | None]: """ @@ -374,40 +477,26 @@ def _transform_sudo_command(command: str | None) -> tuple[str | None, str | None Command runs as-is (fails gracefully with "sudo: a password is required"). """ global _cached_sudo_password - import re - # Check if command even contains sudo if command is None: return None, None + transformed, has_real_sudo = _rewrite_real_sudo_invocations(command) + if not has_real_sudo: + return command, None - if not re.search(r'\bsudo\b', command): - return command, None # No sudo in command, nothing to do + has_configured_password = "SUDO_PASSWORD" in os.environ + sudo_password = os.environ.get("SUDO_PASSWORD", "") if has_configured_password else _cached_sudo_password - # Try to get password from: env var -> session cache -> interactive prompt - sudo_password = os.getenv("SUDO_PASSWORD", "") or _cached_sudo_password + if not has_configured_password and not sudo_password and os.getenv("HERMES_INTERACTIVE"): + sudo_password = _prompt_for_sudo_password(timeout_seconds=45) + if sudo_password: + _cached_sudo_password = sudo_password - if not sudo_password: - # No password configured - check if we're in interactive mode - if os.getenv("HERMES_INTERACTIVE"): - # Prompt user for password - sudo_password = _prompt_for_sudo_password(timeout_seconds=45) - if sudo_password: - _cached_sudo_password = sudo_password # Cache for session + if has_configured_password or sudo_password: + # Trailing newline is required: sudo -S reads one line for the password. + return transformed, sudo_password + "\n" - if not sudo_password: - return command, None # No password, let it fail gracefully - - def replace_sudo(match): - # Replace bare 'sudo' with 'sudo -S -p ""'. - # The password is returned as sudo_stdin and must be written to the - # process's stdin pipe by the caller — it never appears in any - # command-line argument or shell string. - return "sudo -S -p ''" - - # Match 'sudo' at word boundaries (not 'visudo' or 'sudoers') - transformed = re.sub(r'\bsudo\b', replace_sudo, command) - # Trailing newline is required: sudo -S reads one line for the password. - return transformed, sudo_password + "\n" + return command, None # Environment classes now live in tools/environments/ From 161c2c4da4339d1cc6fda62e433dc7d39508ee78 Mon Sep 17 00:00:00 2001 From: Kira Date: Thu, 9 Apr 2026 01:04:06 -0400 Subject: [PATCH 37/49] fix(skills): archive OpenClaw cron store without config --- .../scripts/openclaw_to_hermes.py | 54 +++++++++++++------ tests/skills/test_openclaw_migration.py | 41 ++++++++++++++ 2 files changed, 78 insertions(+), 17 deletions(-) diff --git a/optional-skills/migration/openclaw-migration/scripts/openclaw_to_hermes.py b/optional-skills/migration/openclaw-migration/scripts/openclaw_to_hermes.py index 74e9d7dac35..5e0f76db284 100644 --- a/optional-skills/migration/openclaw-migration/scripts/openclaw_to_hermes.py +++ b/optional-skills/migration/openclaw-migration/scripts/openclaw_to_hermes.py @@ -1803,30 +1803,34 @@ class Migrator: def migrate_cron_jobs(self, config: Optional[Dict[str, Any]] = None) -> None: config = config or self.load_openclaw_config() cron = config.get("cron") or {} - if not cron: - self.record("cron-jobs", None, None, "skipped", "No cron configuration found") - return - - # Archive the full cron config - if self.archive_dir and self.execute: - self.archive_dir.mkdir(parents=True, exist_ok=True) - dest = self.archive_dir / "cron-config.json" - dest.write_text(json.dumps(cron, indent=2, ensure_ascii=False) + "\n", encoding="utf-8") - self.record("cron-jobs", "openclaw.json cron.*", str(dest), "archived", - "Cron config archived. Use 'hermes cron' to recreate jobs manually.") - else: - self.record("cron-jobs", "openclaw.json cron.*", "archive/cron-config.json", - "archived", "Would archive cron config") - - # Also check for cron store files cron_store = self.source_root / "cron" + found_any = False + + # Archive the full cron config when present + if cron: + found_any = True + if self.archive_dir and self.execute: + self.archive_dir.mkdir(parents=True, exist_ok=True) + dest = self.archive_dir / "cron-config.json" + dest.write_text(json.dumps(cron, indent=2, ensure_ascii=False) + "\n", encoding="utf-8") + self.record("cron-jobs", "openclaw.json cron.*", str(dest), "archived", + "Cron config archived. Use 'hermes cron' to recreate jobs manually.") + else: + self.record("cron-jobs", "openclaw.json cron.*", "archive/cron-config.json", + "archived", "Would archive cron config") + + # Also check for cron store files even when config.cron is missing if cron_store.is_dir() and self.archive_dir: + found_any = True dest_cron = self.archive_dir / "cron-store" if self.execute: shutil.copytree(cron_store, dest_cron, dirs_exist_ok=True) self.record("cron-jobs", str(cron_store), str(dest_cron), "archived", "Cron job store archived") + if not found_any: + self.record("cron-jobs", None, None, "skipped", "No cron configuration found") + # ── Hooks ───────────────────────────────────────────────── def migrate_hooks_config(self, config: Optional[Dict[str, Any]] = None) -> None: config = config or self.load_openclaw_config() @@ -2454,6 +2458,15 @@ class Migrator: notes.append(f"- **{item.kind}**: {item.reason}") notes.append("") + has_cron_config_archive = any( + i.kind == "cron-jobs" and i.status == "archived" and i.destination and i.destination.endswith("cron-config.json") + for i in self.items + ) + has_cron_store_archive = any( + i.kind == "cron-jobs" and i.status == "archived" and i.destination and i.destination.endswith("cron-store") + for i in self.items + ) + notes.extend([ "## IMPORTANT: Archive the OpenClaw Directory", "", @@ -2475,7 +2488,14 @@ class Migrator: "- Run `hermes claw cleanup` to archive the OpenClaw directory (prevents state confusion)", "- Run `hermes setup` to configure any remaining settings", "- Run `hermes mcp list` to verify MCP servers were imported correctly", - "- Run `hermes cron` to recreate scheduled tasks (see archive/cron-config.json)", + ]) + + if has_cron_config_archive: + notes.append("- Run `hermes cron` to recreate scheduled tasks (see archive/cron-config.json)") + elif has_cron_store_archive: + notes.append("- Run `hermes cron` to recreate scheduled tasks (see archived cron-store)") + + notes.extend([ "- Run `hermes gateway install` if you need the gateway service", "- Review `~/.hermes/config.yaml` for any adjustments", "", diff --git a/tests/skills/test_openclaw_migration.py b/tests/skills/test_openclaw_migration.py index d4aa8f710ef..99d126bed57 100644 --- a/tests/skills/test_openclaw_migration.py +++ b/tests/skills/test_openclaw_migration.py @@ -658,6 +658,47 @@ def test_workspace_agents_records_skip_when_missing(tmp_path: Path): assert wa_items[0]["status"] == "skipped" +def test_cron_store_is_archived_without_config_cron_section(tmp_path: Path): + """Bug fix: archive cron store even when openclaw.json has no top-level cron config.""" + mod = load_module() + source = tmp_path / ".openclaw" + target = tmp_path / ".hermes" + output_dir = target / "migration-report" + source.mkdir() + target.mkdir() + + (source / "openclaw.json").write_text(json.dumps({"channels": {}}), encoding="utf-8") + (source / "cron").mkdir(parents=True) + (source / "cron" / "jobs.json").write_text( + json.dumps({"version": 1, "jobs": [{"id": "job-1", "name": "demo"}]}), + encoding="utf-8", + ) + + migrator = mod.Migrator( + source_root=source, + target_root=target, + execute=True, + workspace_target=None, + overwrite=False, + migrate_secrets=False, + output_dir=output_dir, + selected_options={"cron-jobs"}, + ) + report = migrator.migrate() + + cron_items = [item for item in report["items"] if item["kind"] == "cron-jobs"] + archived_store = next( + (item for item in cron_items if item["destination"] and item["destination"].endswith("archive/cron-store")), + None, + ) + assert archived_store is not None + assert Path(archived_store["destination"]).joinpath("jobs.json").exists() + + notes_text = (output_dir / "MIGRATION_NOTES.md").read_text(encoding="utf-8") + assert "Run `hermes cron` to recreate scheduled tasks" in notes_text + assert "archive/cron-config.json" not in notes_text + + def test_skill_installs_cleanly_under_skills_guard(): skills_guard = load_skills_guard() result = skills_guard.scan_skill( From 3c8ec7037c6aa59d04faf9c2ef5a1fee02c6fb26 Mon Sep 17 00:00:00 2001 From: konsisumer Date: Thu, 9 Apr 2026 07:23:34 +0200 Subject: [PATCH 38/49] fix(agent): catch PermissionError in subdirectory hint discovery Wrap is_dir() in _is_valid_subdir() and is_file() in _load_hints_for_directory() with OSError handlers so that inaccessible directories (e.g. /root from a non-root Daytona host user) are silently skipped instead of crashing the agent. The existing PermissionError PRs for prompt_builder.py (#6247, #6321, #6355) do not cover subdirectory_hints.py, which was identified as a separate crash path in the #6214 comments. Ref: #6214 --- agent/subdirectory_hints.py | 10 ++++-- tests/agent/test_subdirectory_hints.py | 43 ++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/agent/subdirectory_hints.py b/agent/subdirectory_hints.py index 96903e2e281..dcc514b9014 100644 --- a/agent/subdirectory_hints.py +++ b/agent/subdirectory_hints.py @@ -159,7 +159,10 @@ class SubdirectoryHintTracker: def _is_valid_subdir(self, path: Path) -> bool: """Check if path is a valid directory to scan for hints.""" - if not path.is_dir(): + try: + if not path.is_dir(): + return False + except OSError: return False if path in self._loaded_dirs: return False @@ -172,7 +175,10 @@ class SubdirectoryHintTracker: found_hints = [] for filename in _HINT_FILENAMES: hint_path = directory / filename - if not hint_path.is_file(): + try: + if not hint_path.is_file(): + continue + except OSError: continue try: content = hint_path.read_text(encoding="utf-8").strip() diff --git a/tests/agent/test_subdirectory_hints.py b/tests/agent/test_subdirectory_hints.py index 7d2bc607c85..7c1a74e66cc 100644 --- a/tests/agent/test_subdirectory_hints.py +++ b/tests/agent/test_subdirectory_hints.py @@ -3,6 +3,7 @@ import os import pytest from pathlib import Path +from unittest.mock import patch from agent.subdirectory_hints import SubdirectoryHintTracker @@ -189,3 +190,45 @@ class TestSubdirectoryHintTracker: "terminal", {"command": "curl https://example.com/frontend/api"} ) assert result is None + + +class TestPermissionErrorHandling: + """Regression tests for PermissionError in filesystem checks (ref #6214).""" + + def test_is_valid_subdir_permission_error(self, tmp_path): + """_is_valid_subdir should return False when is_dir() raises PermissionError.""" + tracker = SubdirectoryHintTracker(working_dir=str(tmp_path)) + restricted = tmp_path / "restricted" + restricted.mkdir() + with patch.object(Path, "is_dir", side_effect=PermissionError("Permission denied")): + assert tracker._is_valid_subdir(restricted) is False + + def test_load_hints_permission_error_on_is_file(self, tmp_path): + """_load_hints_for_directory should skip files when is_file() raises PermissionError.""" + tracker = SubdirectoryHintTracker(working_dir=str(tmp_path)) + restricted = tmp_path / "restricted" + restricted.mkdir() + original_is_file = Path.is_file + def patched_is_file(self): + if "restricted" in str(self): + raise PermissionError("Permission denied") + return original_is_file(self) + with patch.object(Path, "is_file", patched_is_file): + result = tracker._load_hints_for_directory(restricted) + assert result is None + + def test_check_tool_call_survives_inaccessible_path(self, project): + """Full check_tool_call should not crash when a path is inaccessible.""" + tracker = SubdirectoryHintTracker(working_dir=str(project)) + original_is_dir = Path.is_dir + def patched_is_dir(self): + if "backend" in str(self) and "src" not in str(self): + raise PermissionError("Permission denied") + return original_is_dir(self) + with patch.object(Path, "is_dir", patched_is_dir): + # Should not raise — gracefully skip the inaccessible directory + result = tracker.check_tool_call( + "read_file", {"path": str(project / "backend" / "src" / "main.py")} + ) + # Result may be None (backend skipped) — the key point is no crash + assert result is None or isinstance(result, str) From 8dfc96dbbb26625badd5607ed880d37bbaf9c672 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Thu, 9 Apr 2026 03:43:14 -0700 Subject: [PATCH 39/49] feat: capture provider rate limit headers and show in /usage (#6541) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Parse x-ratelimit-* headers from inference API responses (Nous Portal, OpenRouter, OpenAI-compatible) and display them in the /usage command. - New agent/rate_limit_tracker.py: parse 12 rate limit headers (RPM/RPH/ TPM/TPH limits, remaining, reset timers), format as progress bars (CLI) or compact one-liner (gateway) - Hook into streaming path in run_agent.py: stream.response.headers is available on the OpenAI SDK Stream object before chunks are consumed - CLI /usage: appends rate limit section with progress bars + warnings when any bucket exceeds 80% - Gateway /usage: appends compact rate limit summary - 24 unit tests covering parsing, formatting, edge cases Headers captured per response: x-ratelimit-{limit,remaining,reset}-{requests,tokens}{,-1h} Example CLI display: Nous Rate Limits (captured just now): Requests/min [░░░░░░░░░░░░░░░░░░░░] 0.1% 1/800 used (799 left, resets in 59s) Tokens/hr [░░░░░░░░░░░░░░░░░░░░] 0.0% 49/336.0M (336.0M left, resets in 52m) --- agent/rate_limit_tracker.py | 242 +++++++++++++++++++++++++ cli.py | 23 ++- gateway/run.py | 23 ++- hermes_cli/commands.py | 2 +- run_agent.py | 32 ++++ tests/agent/test_rate_limit_tracker.py | 212 ++++++++++++++++++++++ 6 files changed, 519 insertions(+), 15 deletions(-) create mode 100644 agent/rate_limit_tracker.py create mode 100644 tests/agent/test_rate_limit_tracker.py diff --git a/agent/rate_limit_tracker.py b/agent/rate_limit_tracker.py new file mode 100644 index 00000000000..c87e096a1de --- /dev/null +++ b/agent/rate_limit_tracker.py @@ -0,0 +1,242 @@ +"""Rate limit tracking for inference API responses. + +Captures x-ratelimit-* headers from provider responses and provides +formatted display for the /usage slash command. Currently supports +the Nous Portal header format (also used by OpenRouter and OpenAI-compatible +APIs that follow the same convention). + +Header schema (12 headers total): + x-ratelimit-limit-requests RPM cap + x-ratelimit-limit-requests-1h RPH cap + x-ratelimit-limit-tokens TPM cap + x-ratelimit-limit-tokens-1h TPH cap + x-ratelimit-remaining-requests requests left in minute window + x-ratelimit-remaining-requests-1h requests left in hour window + x-ratelimit-remaining-tokens tokens left in minute window + x-ratelimit-remaining-tokens-1h tokens left in hour window + x-ratelimit-reset-requests seconds until minute request window resets + x-ratelimit-reset-requests-1h seconds until hour request window resets + x-ratelimit-reset-tokens seconds until minute token window resets + x-ratelimit-reset-tokens-1h seconds until hour token window resets +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import Any, Dict, Mapping, Optional + + +@dataclass +class RateLimitBucket: + """One rate-limit window (e.g. requests per minute).""" + + limit: int = 0 + remaining: int = 0 + reset_seconds: float = 0.0 + captured_at: float = 0.0 # time.time() when this was captured + + @property + def used(self) -> int: + return max(0, self.limit - self.remaining) + + @property + def usage_pct(self) -> float: + if self.limit <= 0: + return 0.0 + return (self.used / self.limit) * 100.0 + + @property + def remaining_seconds_now(self) -> float: + """Estimated seconds remaining until reset, adjusted for elapsed time.""" + elapsed = time.time() - self.captured_at + return max(0.0, self.reset_seconds - elapsed) + + +@dataclass +class RateLimitState: + """Full rate-limit state parsed from response headers.""" + + requests_min: RateLimitBucket = field(default_factory=RateLimitBucket) + requests_hour: RateLimitBucket = field(default_factory=RateLimitBucket) + tokens_min: RateLimitBucket = field(default_factory=RateLimitBucket) + tokens_hour: RateLimitBucket = field(default_factory=RateLimitBucket) + captured_at: float = 0.0 # when the headers were captured + provider: str = "" + + @property + def has_data(self) -> bool: + return self.captured_at > 0 + + @property + def age_seconds(self) -> float: + if not self.has_data: + return float("inf") + return time.time() - self.captured_at + + +def _safe_int(value: Any, default: int = 0) -> int: + try: + return int(float(value)) + except (TypeError, ValueError): + return default + + +def _safe_float(value: Any, default: float = 0.0) -> float: + try: + return float(value) + except (TypeError, ValueError): + return default + + +def parse_rate_limit_headers( + headers: Mapping[str, str], + provider: str = "", +) -> Optional[RateLimitState]: + """Parse x-ratelimit-* headers into a RateLimitState. + + Returns None if no rate limit headers are present. + """ + # Quick check: at least one rate limit header must exist + has_any = any(k.lower().startswith("x-ratelimit-") for k in headers) + if not has_any: + return None + + now = time.time() + + def _bucket(resource: str, suffix: str = "") -> RateLimitBucket: + # e.g. resource="requests", suffix="" -> per-minute + # resource="tokens", suffix="-1h" -> per-hour + tag = f"{resource}{suffix}" + return RateLimitBucket( + limit=_safe_int(headers.get(f"x-ratelimit-limit-{tag}")), + remaining=_safe_int(headers.get(f"x-ratelimit-remaining-{tag}")), + reset_seconds=_safe_float(headers.get(f"x-ratelimit-reset-{tag}")), + captured_at=now, + ) + + return RateLimitState( + requests_min=_bucket("requests"), + requests_hour=_bucket("requests", "-1h"), + tokens_min=_bucket("tokens"), + tokens_hour=_bucket("tokens", "-1h"), + captured_at=now, + provider=provider, + ) + + +# ── Formatting ────────────────────────────────────────────────────────── + + +def _fmt_count(n: int) -> str: + """Human-friendly number: 7999856 -> '8.0M', 33599 -> '33.6K', 799 -> '799'.""" + if n >= 1_000_000: + return f"{n / 1_000_000:.1f}M" + if n >= 10_000: + return f"{n / 1_000:.1f}K" + if n >= 1_000: + return f"{n / 1_000:.1f}K" + return str(n) + + +def _fmt_seconds(seconds: float) -> str: + """Seconds -> human-friendly duration: '58s', '2m 14s', '58m 57s', '1h 2m'.""" + s = max(0, int(seconds)) + if s < 60: + return f"{s}s" + if s < 3600: + m, sec = divmod(s, 60) + return f"{m}m {sec}s" if sec else f"{m}m" + h, remainder = divmod(s, 3600) + m = remainder // 60 + return f"{h}h {m}m" if m else f"{h}h" + + +def _bar(pct: float, width: int = 20) -> str: + """ASCII progress bar: [████████░░░░░░░░░░░░] 40%.""" + filled = int(pct / 100.0 * width) + filled = max(0, min(width, filled)) + empty = width - filled + return f"[{'█' * filled}{'░' * empty}]" + + +def _bucket_line(label: str, bucket: RateLimitBucket, label_width: int = 14) -> str: + """Format one bucket as a single line.""" + if bucket.limit <= 0: + return f" {label:<{label_width}} (no data)" + + pct = bucket.usage_pct + used = _fmt_count(bucket.used) + limit = _fmt_count(bucket.limit) + remaining = _fmt_count(bucket.remaining) + reset = _fmt_seconds(bucket.remaining_seconds_now) + + bar = _bar(pct) + return f" {label:<{label_width}} {bar} {pct:5.1f}% {used}/{limit} used ({remaining} left, resets in {reset})" + + +def format_rate_limit_display(state: RateLimitState) -> str: + """Format rate limit state for terminal/chat display.""" + if not state.has_data: + return "No rate limit data yet — make an API request first." + + age = state.age_seconds + if age < 5: + freshness = "just now" + elif age < 60: + freshness = f"{int(age)}s ago" + else: + freshness = f"{_fmt_seconds(age)} ago" + + provider_label = state.provider.title() if state.provider else "Provider" + + lines = [ + f"{provider_label} Rate Limits (captured {freshness}):", + "", + _bucket_line("Requests/min", state.requests_min), + _bucket_line("Requests/hr", state.requests_hour), + "", + _bucket_line("Tokens/min", state.tokens_min), + _bucket_line("Tokens/hr", state.tokens_hour), + ] + + # Add warnings if any bucket is getting hot + warnings = [] + for label, bucket in [ + ("requests/min", state.requests_min), + ("requests/hr", state.requests_hour), + ("tokens/min", state.tokens_min), + ("tokens/hr", state.tokens_hour), + ]: + if bucket.limit > 0 and bucket.usage_pct >= 80: + reset = _fmt_seconds(bucket.remaining_seconds_now) + warnings.append(f" ⚠ {label} at {bucket.usage_pct:.0f}% — resets in {reset}") + + if warnings: + lines.append("") + lines.extend(warnings) + + return "\n".join(lines) + + +def format_rate_limit_compact(state: RateLimitState) -> str: + """One-line compact summary for status bars / gateway messages.""" + if not state.has_data: + return "No rate limit data." + + rm = state.requests_min + tm = state.tokens_min + rh = state.requests_hour + th = state.tokens_hour + + parts = [] + if rm.limit > 0: + parts.append(f"RPM: {rm.remaining}/{rm.limit}") + if rh.limit > 0: + parts.append(f"RPH: {_fmt_count(rh.remaining)}/{_fmt_count(rh.limit)} (resets {_fmt_seconds(rh.remaining_seconds_now)})") + if tm.limit > 0: + parts.append(f"TPM: {_fmt_count(tm.remaining)}/{_fmt_count(tm.limit)}") + if th.limit > 0: + parts.append(f"TPH: {_fmt_count(th.remaining)}/{_fmt_count(th.limit)} (resets {_fmt_seconds(th.remaining_seconds_now)})") + + return " | ".join(parts) diff --git a/cli.py b/cli.py index 324bb056901..fa32ae9119d 100644 --- a/cli.py +++ b/cli.py @@ -5409,12 +5409,27 @@ class HermesCLI: print(f" ❌ Compression failed: {e}") def _show_usage(self): - """Show cumulative token usage for the current session.""" + """Show rate limits (if available) and session token usage.""" if not self.agent: print("(._.) No active agent -- send a message first.") return agent = self.agent + calls = agent.session_api_calls + + if calls == 0: + print("(._.) No API calls made yet in this session.") + return + + # ── Rate limits (shown first when available) ──────────────── + rl_state = agent.get_rate_limit_state() + if rl_state and rl_state.has_data: + from agent.rate_limit_tracker import format_rate_limit_display + print() + print(format_rate_limit_display(rl_state)) + print() + + # ── Session token usage ───────────────────────────────────── input_tokens = getattr(agent, "session_input_tokens", 0) or 0 output_tokens = getattr(agent, "session_output_tokens", 0) or 0 cache_read_tokens = getattr(agent, "session_cache_read_tokens", 0) or 0 @@ -5422,13 +5437,7 @@ class HermesCLI: prompt = agent.session_prompt_tokens completion = agent.session_completion_tokens total = agent.session_total_tokens - calls = agent.session_api_calls - if calls == 0: - print("(._.) No API calls made yet in this session.") - return - - # Current context window state compressor = agent.context_compressor last_prompt = compressor.last_prompt_tokens ctx_len = compressor.context_length diff --git a/gateway/run.py b/gateway/run.py index 27703a10248..339954f5bef 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -5280,19 +5280,28 @@ class GatewayRunner: agent = self._running_agents.get(session_key) if agent and hasattr(agent, "session_total_tokens") and agent.session_api_calls > 0: - lines = [ - "📊 **Session Token Usage**", - f"Prompt (input): {agent.session_prompt_tokens:,}", - f"Completion (output): {agent.session_completion_tokens:,}", - f"Total: {agent.session_total_tokens:,}", - f"API calls: {agent.session_api_calls}", - ] + lines = [] + + # Rate limits first (when available from provider headers) + rl_state = agent.get_rate_limit_state() + if rl_state and rl_state.has_data: + from agent.rate_limit_tracker import format_rate_limit_compact + lines.append(f"⏱️ **Rate Limits:** {format_rate_limit_compact(rl_state)}") + lines.append("") + + # Session token usage + lines.append("📊 **Session Token Usage**") + lines.append(f"Prompt (input): {agent.session_prompt_tokens:,}") + lines.append(f"Completion (output): {agent.session_completion_tokens:,}") + lines.append(f"Total: {agent.session_total_tokens:,}") + lines.append(f"API calls: {agent.session_api_calls}") ctx = agent.context_compressor if ctx.last_prompt_tokens: pct = min(100, ctx.last_prompt_tokens / ctx.context_length * 100) if ctx.context_length else 0 lines.append(f"Context: {ctx.last_prompt_tokens:,} / {ctx.context_length:,} ({pct:.0f}%)") if ctx.compression_count: lines.append(f"Compressions: {ctx.compression_count}") + return "\n".join(lines) # No running agent -- check session history for a rough count diff --git a/hermes_cli/commands.py b/hermes_cli/commands.py index 39dc4569cd8..70d9cb8aa30 100644 --- a/hermes_cli/commands.py +++ b/hermes_cli/commands.py @@ -129,7 +129,7 @@ COMMAND_REGISTRY: list[CommandDef] = [ CommandDef("commands", "Browse all commands and skills (paginated)", "Info", gateway_only=True, args_hint="[page]"), CommandDef("help", "Show available commands", "Info"), - CommandDef("usage", "Show token usage for the current session", "Info"), + CommandDef("usage", "Show token usage and rate limits for the current session", "Info"), CommandDef("insights", "Show usage insights and analytics", "Info", args_hint="[days]"), CommandDef("platforms", "Show gateway/messaging platform status", "Info", diff --git a/run_agent.py b/run_agent.py index 3c5661a112d..ecd0be656fc 100644 --- a/run_agent.py +++ b/run_agent.py @@ -692,6 +692,10 @@ class AIAgent: self._current_tool: str | None = None self._api_call_count: int = 0 + # Rate limit tracking — updated from x-ratelimit-* response headers + # after each API call. Accessed by /usage slash command. + self._rate_limit_state: Optional["RateLimitState"] = None + # Centralized logging — agent.log (INFO+) and errors.log (WARNING+) # both live under ~/.hermes/logs/. Idempotent, so gateway mode # (which creates a new AIAgent per message) won't duplicate handlers. @@ -2545,6 +2549,29 @@ class AIAgent: self._last_activity_ts = time.time() self._last_activity_desc = desc + def _capture_rate_limits(self, http_response: Any) -> None: + """Parse x-ratelimit-* headers from an HTTP response and cache the state. + + Called after each streaming API call. The httpx Response object is + available on the OpenAI SDK Stream via ``stream.response``. + """ + if http_response is None: + return + headers = getattr(http_response, "headers", None) + if not headers: + return + try: + from agent.rate_limit_tracker import parse_rate_limit_headers + state = parse_rate_limit_headers(headers, provider=self.provider) + if state is not None: + self._rate_limit_state = state + except Exception: + pass # Never let header parsing break the agent loop + + def get_rate_limit_state(self): + """Return the last captured RateLimitState, or None.""" + return self._rate_limit_state + def get_activity_summary(self) -> dict: """Return a snapshot of the agent's current activity for diagnostics. @@ -4399,6 +4426,11 @@ class AIAgent: self._touch_activity("waiting for provider response (streaming)") stream = request_client_holder["client"].chat.completions.create(**stream_kwargs) + # Capture rate limit headers from the initial HTTP response. + # The OpenAI SDK Stream object exposes the underlying httpx + # response via .response before any chunks are consumed. + self._capture_rate_limits(getattr(stream, "response", None)) + content_parts: list = [] tool_calls_acc: dict = {} tool_gen_notified: set = set() diff --git a/tests/agent/test_rate_limit_tracker.py b/tests/agent/test_rate_limit_tracker.py new file mode 100644 index 00000000000..caef785678b --- /dev/null +++ b/tests/agent/test_rate_limit_tracker.py @@ -0,0 +1,212 @@ +"""Tests for agent.rate_limit_tracker — header parsing and formatting.""" + +import time +import pytest +from agent.rate_limit_tracker import ( + RateLimitBucket, + RateLimitState, + parse_rate_limit_headers, + format_rate_limit_display, + format_rate_limit_compact, + _fmt_count, + _fmt_seconds, + _bar, +) + + +# ── Sample headers from Nous inference API ────────────────────────────── + +NOUS_HEADERS = { + "x-ratelimit-limit-requests": "800", + "x-ratelimit-limit-requests-1h": "33600", + "x-ratelimit-limit-tokens": "8000000", + "x-ratelimit-limit-tokens-1h": "336000000", + "x-ratelimit-remaining-requests": "795", + "x-ratelimit-remaining-requests-1h": "33590", + "x-ratelimit-remaining-tokens": "7999500", + "x-ratelimit-remaining-tokens-1h": "335999000", + "x-ratelimit-reset-requests": "45.5", + "x-ratelimit-reset-requests-1h": "3500.0", + "x-ratelimit-reset-tokens": "42.3", + "x-ratelimit-reset-tokens-1h": "3490.0", +} + + +class TestParseHeaders: + def test_basic_parsing(self): + state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous") + assert state is not None + assert state.provider == "nous" + assert state.has_data + + assert state.requests_min.limit == 800 + assert state.requests_min.remaining == 795 + assert state.requests_min.reset_seconds == 45.5 + + assert state.requests_hour.limit == 33600 + assert state.requests_hour.remaining == 33590 + + assert state.tokens_min.limit == 8000000 + assert state.tokens_min.remaining == 7999500 + + assert state.tokens_hour.limit == 336000000 + assert state.tokens_hour.remaining == 335999000 + assert state.tokens_hour.reset_seconds == 3490.0 + + def test_no_headers(self): + state = parse_rate_limit_headers({}) + assert state is None + + def test_partial_headers(self): + headers = { + "x-ratelimit-limit-requests": "100", + "x-ratelimit-remaining-requests": "50", + } + state = parse_rate_limit_headers(headers) + assert state is not None + assert state.requests_min.limit == 100 + assert state.requests_min.remaining == 50 + # Missing fields default to 0 + assert state.tokens_min.limit == 0 + + def test_non_rate_limit_headers_ignored(self): + headers = { + "content-type": "application/json", + "server": "nginx", + } + state = parse_rate_limit_headers(headers) + assert state is None + + def test_malformed_values(self): + headers = { + "x-ratelimit-limit-requests": "not-a-number", + "x-ratelimit-remaining-requests": "", + "x-ratelimit-reset-requests": "abc", + } + state = parse_rate_limit_headers(headers) + assert state is not None + assert state.requests_min.limit == 0 + assert state.requests_min.remaining == 0 + assert state.requests_min.reset_seconds == 0.0 + + +class TestBucket: + def test_used(self): + b = RateLimitBucket(limit=800, remaining=795, reset_seconds=45.0, captured_at=time.time()) + assert b.used == 5 + + def test_usage_pct(self): + b = RateLimitBucket(limit=100, remaining=20, reset_seconds=30.0, captured_at=time.time()) + assert b.usage_pct == pytest.approx(80.0) + + def test_usage_pct_zero_limit(self): + b = RateLimitBucket(limit=0, remaining=0) + assert b.usage_pct == 0.0 + + def test_remaining_seconds_now(self): + now = time.time() + b = RateLimitBucket(limit=800, remaining=795, reset_seconds=60.0, captured_at=now - 10) + # ~50 seconds should remain + assert 49 <= b.remaining_seconds_now <= 51 + + def test_remaining_seconds_expired(self): + b = RateLimitBucket(limit=800, remaining=795, reset_seconds=30.0, captured_at=time.time() - 60) + assert b.remaining_seconds_now == 0.0 + + +class TestFormatting: + def test_fmt_count_millions(self): + assert _fmt_count(8000000) == "8.0M" + assert _fmt_count(336000000) == "336.0M" + + def test_fmt_count_thousands(self): + assert _fmt_count(33600) == "33.6K" + assert _fmt_count(1500) == "1.5K" + + def test_fmt_count_small(self): + assert _fmt_count(800) == "800" + assert _fmt_count(0) == "0" + + def test_fmt_seconds_short(self): + assert _fmt_seconds(45) == "45s" + assert _fmt_seconds(0) == "0s" + + def test_fmt_seconds_minutes(self): + assert _fmt_seconds(125) == "2m 5s" + assert _fmt_seconds(120) == "2m" + + def test_fmt_seconds_hours(self): + assert _fmt_seconds(3660) == "1h 1m" + assert _fmt_seconds(3600) == "1h" + + def test_bar(self): + bar = _bar(50.0, width=10) + assert bar == "[█████░░░░░]" + assert _bar(0.0, width=10) == "[░░░░░░░░░░]" + assert _bar(100.0, width=10) == "[██████████]" + + def test_format_display_no_data(self): + state = RateLimitState() + result = format_rate_limit_display(state) + assert "No rate limit data" in result + + def test_format_display_with_data(self): + state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous") + result = format_rate_limit_display(state) + assert "Nous" in result + assert "Requests/min" in result + assert "Requests/hr" in result + assert "Tokens/min" in result + assert "Tokens/hr" in result + assert "resets in" in result + + def test_format_display_warning_on_high_usage(self): + headers = { + **NOUS_HEADERS, + "x-ratelimit-remaining-requests": "50", # 750/800 used = 93.75% + } + state = parse_rate_limit_headers(headers) + result = format_rate_limit_display(state) + assert "⚠" in result + + def test_format_compact(self): + state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous") + result = format_rate_limit_compact(state) + assert "RPM:" in result + assert "RPH:" in result + assert "TPM:" in result + assert "TPH:" in result + assert "resets" in result + + def test_format_compact_no_data(self): + state = RateLimitState() + result = format_rate_limit_compact(state) + assert "No rate limit data" in result + + +class TestAgentIntegration: + """Test that AIAgent captures rate limit state correctly.""" + + def test_capture_rate_limits_from_headers(self): + """Simulate the header capture path without a real API call.""" + import sys + import os + # Use a mock httpx-like response + class MockResponse: + headers = NOUS_HEADERS + + # Import AIAgent minimally + from unittest.mock import MagicMock, patch + + # Test the parsing directly + state = parse_rate_limit_headers(MockResponse.headers, provider="nous") + assert state is not None + assert state.requests_min.limit == 800 + assert state.tokens_hour.limit == 336000000 + + def test_capture_rate_limits_none_response(self): + """_capture_rate_limits should handle None gracefully.""" + from agent.rate_limit_tracker import parse_rate_limit_headers + # None should not crash + result = parse_rate_limit_headers({}) + assert result is None From ad06bfccf0c22bfb2a114e40af7edc5b3a850e38 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Thu, 9 Apr 2026 03:56:40 -0700 Subject: [PATCH 40/49] =?UTF-8?q?fix:=20remove=20dead=20LLM=5FMODEL=20env?= =?UTF-8?q?=20var=20=E2=80=94=20add=20migration=20to=20clear=20stale=20.en?= =?UTF-8?q?v=20entries=20(#6543)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The old setup wizard (pre-March 2026) wrote LLM_MODEL to ~/.hermes/.env across 12 provider flows. Commit 9302690e removed the writes but never cleaned up existing .env files, leaving a dead variable that: - Nothing in the codebase reads (zero os.getenv calls) - The docs incorrectly claimed the gateway still used as fallback - Caused user confusion when debugging model resolution issues Changes: - config.py: Bump _config_version 12 → 13, add migration to clear LLM_MODEL and OPENAI_MODEL from .env (both dead since March 2026) - environment-variables.md: Remove LLM_MODEL row, fix HERMES_MODEL description to stop referencing it - providers.md: Update deprecation notice from 'deprecated' to 'removed' --- hermes_cli/config.py | 17 ++++++++++++++++- website/docs/integrations/providers.md | 2 +- website/docs/reference/environment-variables.md | 3 +-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 387bef667d5..a981b1bbbf1 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -569,7 +569,7 @@ DEFAULT_CONFIG = { }, # Config schema version - bump this when adding new required fields - "_config_version": 12, + "_config_version": 13, } # ============================================================================= @@ -1701,6 +1701,21 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A ep = providers_dict[key] print(f" → {key}: {ep.get('api', '')}") + # ── Version 12 → 13: clear dead LLM_MODEL / OPENAI_MODEL from .env ── + # These env vars were written by the old setup wizard but nothing reads + # them anymore (config.yaml is the sole source of truth since March 2026). + # Stale entries cause user confusion — see issue report. + if current_ver < 13: + for dead_var in ("LLM_MODEL", "OPENAI_MODEL"): + try: + old_val = get_env_value(dead_var) + if old_val: + save_env_value(dead_var, "") + if not quiet: + print(f" ✓ Cleared {dead_var} from .env (no longer used — config.yaml is source of truth)") + except Exception: + pass + if current_ver < latest_ver and not quiet: print(f"Config version: {current_ver} → {latest_ver}") diff --git a/website/docs/integrations/providers.md b/website/docs/integrations/providers.md index 74d4e631ae7..fbfa69ade6e 100644 --- a/website/docs/integrations/providers.md +++ b/website/docs/integrations/providers.md @@ -230,7 +230,7 @@ model: ``` :::warning Legacy env vars -`OPENAI_BASE_URL` and `LLM_MODEL` in `.env` are **deprecated**. `OPENAI_BASE_URL` is no longer consulted for endpoint resolution — `config.yaml` is the single source of truth. The CLI ignores `LLM_MODEL` entirely (only the gateway reads it as a fallback). Use `hermes model` or edit `config.yaml` directly — both persist correctly across restarts and Docker containers. +`OPENAI_BASE_URL` and `LLM_MODEL` in `.env` are **removed**. Neither is read by any part of Hermes — `config.yaml` is the single source of truth for model and endpoint configuration. If you have stale entries in your `.env`, they are automatically cleared on the next `hermes setup` or config migration. Use `hermes model` or edit `config.yaml` directly. ::: Both approaches persist to `config.yaml`, which is the source of truth for model, provider, and base URL. diff --git a/website/docs/reference/environment-variables.md b/website/docs/reference/environment-variables.md index e8f2e8aee6b..7c14d9f3da4 100644 --- a/website/docs/reference/environment-variables.md +++ b/website/docs/reference/environment-variables.md @@ -53,8 +53,7 @@ All variables go in `~/.hermes/.env`. You can also set them with `hermes config | `OPENCODE_GO_API_KEY` | OpenCode Go API key — $10/month subscription for open models ([opencode.ai](https://opencode.ai/auth)) | | `OPENCODE_GO_BASE_URL` | Override OpenCode Go base URL | | `CLAUDE_CODE_OAUTH_TOKEN` | Explicit Claude Code token override if you export one manually | -| `HERMES_MODEL` | Preferred model name (checked before `LLM_MODEL`, used by gateway) | -| `LLM_MODEL` | Default model name (fallback when not set in config.yaml) | +| `HERMES_MODEL` | Override model name at process level (used by cron scheduler; prefer `config.yaml` for normal use) | | `VOICE_TOOLS_OPENAI_KEY` | Preferred OpenAI key for OpenAI speech-to-text and text-to-speech providers | | `HERMES_LOCAL_STT_COMMAND` | Optional local speech-to-text command template. Supports `{input_path}`, `{output_dir}`, `{language}`, and `{model}` placeholders | | `HERMES_LOCAL_STT_LANGUAGE` | Default language passed to `HERMES_LOCAL_STT_COMMAND` or auto-detected local `whisper` CLI fallback (default: `en`) | From b650957b405b5160b7d4b55758d240e344203d3b Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Thu, 9 Apr 2026 03:57:11 -0700 Subject: [PATCH 41/49] docs(bluebubbles): fix pairing instructions to use existing approve flow (#6548) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The docs incorrectly referenced 'hermes pairing generate bluebubbles' which doesn't exist. The existing reactive pairing flow already handles this — when an unknown user messages the bot, it sends them a code automatically, and the owner approves with 'hermes pairing approve'. --- website/docs/user-guide/messaging/bluebubbles.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/website/docs/user-guide/messaging/bluebubbles.md b/website/docs/user-guide/messaging/bluebubbles.md index 3f023d31792..cde9690316a 100644 --- a/website/docs/user-guide/messaging/bluebubbles.md +++ b/website/docs/user-guide/messaging/bluebubbles.md @@ -43,17 +43,18 @@ BLUEBUBBLES_PASSWORD=your-server-password Choose one approach: **DM Pairing (recommended):** +When someone messages your iMessage, Hermes automatically sends them a pairing code. Approve it with: ```bash -hermes pairing generate bluebubbles +hermes pairing approve bluebubbles ``` -Share the pairing code — the user sends it via iMessage to get approved. +Use `hermes pairing list` to see pending codes and approved users. -**Pre-authorize specific users:** +**Pre-authorize specific users** (in `~/.hermes/.env`): ```bash BLUEBUBBLES_ALLOWED_USERS=user@icloud.com,+15551234567 ``` -**Open access:** +**Open access** (in `~/.hermes/.env`): ```bash BLUEBUBBLES_ALLOW_ALL_USERS=true ``` From 78e6b06518148309a598cdb3267f9036a7315b62 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Thu, 9 Apr 2026 04:00:41 -0700 Subject: [PATCH 42/49] feat: add 'hermes dump' command for copy-pasteable setup summary (#6550) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a new CLI command that outputs a compact, plain-text dump of the user's Hermes setup — version, OS, model/provider, API key presence, toolsets, gateway status, platforms, cron jobs, skills, and any non-default config overrides. Designed for support context: no ANSI colors, ready to paste into Discord/GitHub/Telegram. Secrets shown as 'set/not set' by default; --show-keys reveals redacted prefixes (first/last 4 chars). Files: - hermes_cli/dump.py (new) — run_dump() implementation - hermes_cli/main.py — parser + cmd_dump wiring - hermes_cli/profiles.py — shell completions + subcommand set --- hermes_cli/dump.py | 337 +++++++++++++++++++++++++++++++++++++++++ hermes_cli/main.py | 22 +++ hermes_cli/profiles.py | 6 +- 3 files changed, 362 insertions(+), 3 deletions(-) create mode 100644 hermes_cli/dump.py diff --git a/hermes_cli/dump.py b/hermes_cli/dump.py new file mode 100644 index 00000000000..4ad32ca2c17 --- /dev/null +++ b/hermes_cli/dump.py @@ -0,0 +1,337 @@ +""" +Dump command for hermes CLI. + +Outputs a compact, plain-text summary of the user's Hermes setup +that can be copy-pasted into Discord/GitHub/Telegram for support context. +No ANSI colors, no checkmarks — just data. +""" + +import json +import os +import platform +import subprocess +import sys +from pathlib import Path + +from hermes_cli.config import get_hermes_home, get_env_path, get_project_root, load_config +from hermes_constants import display_hermes_home + + +def _get_git_commit(project_root: Path) -> str: + """Return short git commit hash, or '(unknown)'.""" + try: + result = subprocess.run( + ["git", "rev-parse", "--short=8", "HEAD"], + capture_output=True, text=True, timeout=5, + cwd=str(project_root), + ) + if result.returncode == 0: + return result.stdout.strip() + except Exception: + pass + return "(unknown)" + + +def _key_present(name: str) -> str: + """Return 'set' or 'not set' for an env var.""" + return "set" if os.getenv(name) else "not set" + + +def _redact(value: str) -> str: + """Redact all but first 4 and last 4 chars.""" + if not value: + return "" + if len(value) < 12: + return "***" + return value[:4] + "..." + value[-4:] + + +def _gateway_status() -> str: + """Return a short gateway status string.""" + if sys.platform.startswith("linux"): + try: + from hermes_cli.gateway import get_service_name + svc = get_service_name() + except Exception: + svc = "hermes-gateway" + try: + r = subprocess.run( + ["systemctl", "--user", "is-active", svc], + capture_output=True, text=True, timeout=5, + ) + return "running (systemd)" if r.stdout.strip() == "active" else "stopped" + except Exception: + return "unknown" + elif sys.platform == "darwin": + try: + from hermes_cli.gateway import get_launchd_label + r = subprocess.run( + ["launchctl", "list", get_launchd_label()], + capture_output=True, text=True, timeout=5, + ) + return "loaded (launchd)" if r.returncode == 0 else "not loaded" + except Exception: + return "unknown" + return "N/A" + + +def _count_skills(hermes_home: Path) -> int: + """Count installed skills.""" + skills_dir = hermes_home / "skills" + if not skills_dir.is_dir(): + return 0 + count = 0 + for item in skills_dir.rglob("SKILL.md"): + count += 1 + return count + + +def _count_mcp_servers(config: dict) -> int: + """Count configured MCP servers.""" + mcp = config.get("mcp", {}) + servers = mcp.get("servers", {}) + return len(servers) + + +def _cron_summary(hermes_home: Path) -> str: + """Return cron jobs summary.""" + jobs_file = hermes_home / "cron" / "jobs.json" + if not jobs_file.exists(): + return "0" + try: + with open(jobs_file, encoding="utf-8") as f: + data = json.load(f) + jobs = data.get("jobs", []) + active = sum(1 for j in jobs if j.get("enabled", True)) + return f"{active} active / {len(jobs)} total" + except Exception: + return "(error reading)" + + +def _configured_platforms() -> list[str]: + """Return list of configured messaging platform names.""" + checks = { + "telegram": "TELEGRAM_BOT_TOKEN", + "discord": "DISCORD_BOT_TOKEN", + "slack": "SLACK_BOT_TOKEN", + "whatsapp": "WHATSAPP_ENABLED", + "signal": "SIGNAL_HTTP_URL", + "email": "EMAIL_ADDRESS", + "sms": "TWILIO_ACCOUNT_SID", + "matrix": "MATRIX_HOMESERVER_URL", + "mattermost": "MATTERMOST_URL", + "homeassistant": "HASS_TOKEN", + "dingtalk": "DINGTALK_CLIENT_ID", + "feishu": "FEISHU_APP_ID", + "wecom": "WECOM_BOT_ID", + } + return [name for name, env in checks.items() if os.getenv(env)] + + +def _memory_provider(config: dict) -> str: + """Return the active memory provider name.""" + mem = config.get("memory", {}) + provider = mem.get("provider", "") + return provider if provider else "built-in" + + +def _get_model_and_provider(config: dict) -> tuple[str, str]: + """Extract model and provider from config.""" + model_cfg = config.get("model", "") + if isinstance(model_cfg, dict): + model = model_cfg.get("default") or model_cfg.get("model") or model_cfg.get("name") or "(not set)" + provider = model_cfg.get("provider") or "(auto)" + elif isinstance(model_cfg, str): + model = model_cfg or "(not set)" + provider = "(auto)" + else: + model = "(not set)" + provider = "(auto)" + return model, provider + + +def _config_overrides(config: dict) -> dict[str, str]: + """Find non-default config values worth reporting. + + Returns a flat dict of dotpath -> value for interesting overrides. + """ + from hermes_cli.config import DEFAULT_CONFIG + + overrides = {} + + # Sections with interesting user-facing overrides + interesting_paths = [ + ("agent", "max_turns"), + ("agent", "gateway_timeout"), + ("agent", "tool_use_enforcement"), + ("terminal", "backend"), + ("terminal", "docker_image"), + ("terminal", "persistent_shell"), + ("browser", "allow_private_urls"), + ("compression", "enabled"), + ("compression", "threshold"), + ("display", "streaming"), + ("display", "skin"), + ("display", "show_reasoning"), + ("smart_model_routing", "enabled"), + ("privacy", "redact_pii"), + ("tts", "provider"), + ] + + for section, key in interesting_paths: + default_section = DEFAULT_CONFIG.get(section, {}) + user_section = config.get(section, {}) + if not isinstance(default_section, dict) or not isinstance(user_section, dict): + continue + default_val = default_section.get(key) + user_val = user_section.get(key) + if user_val is not None and user_val != default_val: + overrides[f"{section}.{key}"] = str(user_val) + + # Toolsets (if different from default) + default_toolsets = DEFAULT_CONFIG.get("toolsets", []) + user_toolsets = config.get("toolsets", []) + if user_toolsets != default_toolsets: + overrides["toolsets"] = str(user_toolsets) + + # Fallback providers + fallbacks = config.get("fallback_providers", []) + if fallbacks: + overrides["fallback_providers"] = str(fallbacks) + + return overrides + + +def run_dump(args): + """Output a compact, copy-pasteable setup summary.""" + show_keys = getattr(args, "show_keys", False) + + # Load env from .env file so key checks work + from dotenv import load_dotenv + env_path = get_env_path() + if env_path.exists(): + try: + load_dotenv(env_path, encoding="utf-8") + except UnicodeDecodeError: + load_dotenv(env_path, encoding="latin-1") + # Also try project .env as dev fallback + load_dotenv(get_project_root() / ".env", override=False, encoding="utf-8") + + project_root = get_project_root() + hermes_home = get_hermes_home() + + try: + from hermes_cli import __version__, __release_date__ + except ImportError: + __version__ = "(unknown)" + __release_date__ = "" + + commit = _get_git_commit(project_root) + + try: + config = load_config() + except Exception: + config = {} + + model, provider = _get_model_and_provider(config) + + # Profile + try: + from hermes_cli.profiles import get_active_profile_name + profile = get_active_profile_name() or "(default)" + except Exception: + profile = "(default)" + + # Terminal backend + terminal_cfg = config.get("terminal", {}) + backend = terminal_cfg.get("backend", "local") + + # OpenAI SDK version + try: + import openai + openai_ver = openai.__version__ + except ImportError: + openai_ver = "not installed" + + # OS info + os_info = f"{platform.system()} {platform.release()} {platform.machine()}" + + lines = [] + lines.append("--- hermes dump ---") + ver_str = f"{__version__}" + if __release_date__: + ver_str += f" ({__release_date__})" + ver_str += f" [{commit}]" + lines.append(f"version: {ver_str}") + lines.append(f"os: {os_info}") + lines.append(f"python: {sys.version.split()[0]}") + lines.append(f"openai_sdk: {openai_ver}") + lines.append(f"profile: {profile}") + lines.append(f"hermes_home: {display_hermes_home()}") + lines.append(f"model: {model}") + lines.append(f"provider: {provider}") + lines.append(f"terminal: {backend}") + + # API keys + lines.append("") + lines.append("api_keys:") + api_keys = [ + ("OPENROUTER_API_KEY", "openrouter"), + ("OPENAI_API_KEY", "openai"), + ("ANTHROPIC_API_KEY", "anthropic"), + ("ANTHROPIC_TOKEN", "anthropic_token"), + ("NOUS_API_KEY", "nous"), + ("GLM_API_KEY", "glm/zai"), + ("ZAI_API_KEY", "zai"), + ("KIMI_API_KEY", "kimi"), + ("MINIMAX_API_KEY", "minimax"), + ("DEEPSEEK_API_KEY", "deepseek"), + ("DASHSCOPE_API_KEY", "dashscope"), + ("HF_TOKEN", "huggingface"), + ("AI_GATEWAY_API_KEY", "ai_gateway"), + ("OPENCODE_ZEN_API_KEY", "opencode_zen"), + ("OPENCODE_GO_API_KEY", "opencode_go"), + ("KILOCODE_API_KEY", "kilocode"), + ("FIRECRAWL_API_KEY", "firecrawl"), + ("TAVILY_API_KEY", "tavily"), + ("BROWSERBASE_API_KEY", "browserbase"), + ("FAL_KEY", "fal"), + ("ELEVENLABS_API_KEY", "elevenlabs"), + ("GITHUB_TOKEN", "github"), + ] + + for env_var, label in api_keys: + val = os.getenv(env_var, "") + if show_keys and val: + display = _redact(val) + else: + display = "set" if val else "not set" + lines.append(f" {label:<20} {display}") + + # Features summary + lines.append("") + lines.append("features:") + + toolsets = config.get("toolsets", ["hermes-cli"]) + lines.append(f" toolsets: {', '.join(toolsets) if toolsets else '(default)'}") + lines.append(f" mcp_servers: {_count_mcp_servers(config)}") + lines.append(f" memory_provider: {_memory_provider(config)}") + lines.append(f" gateway: {_gateway_status()}") + + platforms = _configured_platforms() + lines.append(f" platforms: {', '.join(platforms) if platforms else 'none'}") + lines.append(f" cron_jobs: {_cron_summary(hermes_home)}") + lines.append(f" skills: {_count_skills(hermes_home)}") + + # Config overrides (non-default values) + overrides = _config_overrides(config) + if overrides: + lines.append("") + lines.append("config_overrides:") + for key, val in overrides.items(): + lines.append(f" {key}: {val}") + + lines.append("--- end dump ---") + + output = "\n".join(lines) + print(output) diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 96345c48504..a6d616e68bc 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -2643,6 +2643,12 @@ def cmd_doctor(args): run_doctor(args) +def cmd_dump(args): + """Dump setup summary for support/debugging.""" + from hermes_cli.dump import run_dump + run_dump(args) + + def cmd_config(args): """Configuration management.""" from hermes_cli.config import config_command @@ -4724,6 +4730,22 @@ For more help on a command: help="Attempt to fix issues automatically" ) doctor_parser.set_defaults(func=cmd_doctor) + + # ========================================================================= + # dump command + # ========================================================================= + dump_parser = subparsers.add_parser( + "dump", + help="Dump setup summary for support/debugging", + description="Output a compact, plain-text summary of your Hermes setup " + "that can be copy-pasted into Discord/GitHub for support context" + ) + dump_parser.add_argument( + "--show-keys", + action="store_true", + help="Show redacted API key prefixes (first/last 4 chars) instead of just set/not set" + ) + dump_parser.set_defaults(func=cmd_dump) # ========================================================================= # config command diff --git a/hermes_cli/profiles.py b/hermes_cli/profiles.py index 48ecbc4ca4d..9be25e10079 100644 --- a/hermes_cli/profiles.py +++ b/hermes_cli/profiles.py @@ -102,7 +102,7 @@ _RESERVED_NAMES = frozenset({ # Hermes subcommands that cannot be used as profile names/aliases _HERMES_SUBCOMMANDS = frozenset({ "chat", "model", "gateway", "setup", "whatsapp", "login", "logout", - "status", "cron", "doctor", "config", "pairing", "skills", "tools", + "status", "cron", "doctor", "dump", "config", "pairing", "skills", "tools", "mcp", "sessions", "insights", "version", "update", "uninstall", "profile", "plugins", "honcho", "acp", }) @@ -1007,7 +1007,7 @@ _hermes_completion() { # Top-level subcommands if [[ "$COMP_CWORD" == 1 ]]; then - local commands="chat model gateway setup status cron doctor config skills tools mcp sessions profile update version" + local commands="chat model gateway setup status cron doctor dump config skills tools mcp sessions profile update version" COMPREPLY=($(compgen -W "$commands" -- "$cur")) fi } @@ -1032,7 +1032,7 @@ _hermes() { _arguments \\ '-p[Profile name]:profile:($profiles)' \\ '--profile[Profile name]:profile:($profiles)' \\ - '1:command:(chat model gateway setup status cron doctor config skills tools mcp sessions profile update version)' \\ + '1:command:(chat model gateway setup status cron doctor dump config skills tools mcp sessions profile update version)' \\ '*::arg:->args' case $words[1] in From 1a3ae6ac6e2a0df226e472c87daf6c0cebe75beb Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Thu, 9 Apr 2026 04:10:11 -0700 Subject: [PATCH 43/49] feat: structured API error classification for smart failover (#6514) Add agent/error_classifier.py with a priority-ordered classification pipeline that replaces scattered inline string-matching in the retry loop with structured error taxonomy and recovery hints. FailoverReason enum (14 categories): auth, auth_permanent, billing, rate_limit, overloaded, server_error, timeout, context_overflow, payload_too_large, model_not_found, format_error, thinking_signature, long_context_tier, unknown. ClassifiedError dataclass carries reason + recovery action hints (retryable, should_compress, should_rotate_credential, should_fallback). Key improvements over inline matching: - 402 disambiguation: 'insufficient credits' = billing (immediate rotate), 'usage limit, try again' = rate_limit (backoff first) - OpenRouter 403 'key limit exceeded' correctly classified as billing - Error cause chain walking (walks __cause__/__context__ up to 5 levels) - Body message included in pattern matching (SDK str() misses it) - Server disconnect + large session check ordered before generic transport catch so RemoteProtocolError triggers compression when appropriate - Chinese error message support for context overflow run_agent.py: replaced 6 inline detection blocks with classifier calls, net -55 lines. All recovery actions (pool rotation, fallback activation, compression, transport recovery) unchanged. 65 new unit tests + 10 E2E tests + live tests with real SDK error objects. Inspired by OpenClaw's failover error classification system. --- agent/error_classifier.py | 789 +++++++++++++++++++++++++++ run_agent.py | 190 +++---- tests/agent/test_error_classifier.py | 750 +++++++++++++++++++++++++ 3 files changed, 1607 insertions(+), 122 deletions(-) create mode 100644 agent/error_classifier.py create mode 100644 tests/agent/test_error_classifier.py diff --git a/agent/error_classifier.py b/agent/error_classifier.py new file mode 100644 index 00000000000..b227932ad7c --- /dev/null +++ b/agent/error_classifier.py @@ -0,0 +1,789 @@ +"""API error classification for smart failover and recovery. + +Provides a structured taxonomy of API errors and a priority-ordered +classification pipeline that determines the correct recovery action +(retry, rotate credential, fallback to another provider, compress +context, or abort). + +Replaces scattered inline string-matching with a centralized classifier +that the main retry loop in run_agent.py consults for every API failure. +""" + +from __future__ import annotations + +import enum +import logging +import re +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + + +# ── Error taxonomy ────────────────────────────────────────────────────── + +class FailoverReason(enum.Enum): + """Why an API call failed — determines recovery strategy.""" + + # Authentication / authorization + auth = "auth" # Transient auth (401/403) — refresh/rotate + auth_permanent = "auth_permanent" # Auth failed after refresh — abort + + # Billing / quota + billing = "billing" # 402 or confirmed credit exhaustion — rotate immediately + rate_limit = "rate_limit" # 429 or quota-based throttling — backoff then rotate + + # Server-side + overloaded = "overloaded" # 503/529 — provider overloaded, backoff + server_error = "server_error" # 500/502 — internal server error, retry + + # Transport + timeout = "timeout" # Connection/read timeout — rebuild client + retry + + # Context / payload + context_overflow = "context_overflow" # Context too large — compress, not failover + payload_too_large = "payload_too_large" # 413 — compress payload + + # Model + model_not_found = "model_not_found" # 404 or invalid model — fallback to different model + + # Request format + format_error = "format_error" # 400 bad request — abort or strip + retry + + # Provider-specific + thinking_signature = "thinking_signature" # Anthropic thinking block sig invalid + long_context_tier = "long_context_tier" # Anthropic "extra usage" tier gate + + # Catch-all + unknown = "unknown" # Unclassifiable — retry with backoff + + +# ── Classification result ─────────────────────────────────────────────── + +@dataclass +class ClassifiedError: + """Structured classification of an API error with recovery hints.""" + + reason: FailoverReason + status_code: Optional[int] = None + provider: Optional[str] = None + model: Optional[str] = None + message: str = "" + error_context: Dict[str, Any] = field(default_factory=dict) + + # Recovery action hints — the retry loop checks these instead of + # re-classifying the error itself. + retryable: bool = True + should_compress: bool = False + should_rotate_credential: bool = False + should_fallback: bool = False + + @property + def is_auth(self) -> bool: + return self.reason in (FailoverReason.auth, FailoverReason.auth_permanent) + + @property + def is_transient(self) -> bool: + """Error is expected to resolve on retry (with or without backoff).""" + return self.reason in ( + FailoverReason.rate_limit, + FailoverReason.overloaded, + FailoverReason.server_error, + FailoverReason.timeout, + FailoverReason.unknown, + ) + + +# ── Provider-specific patterns ────────────────────────────────────────── + +# Patterns that indicate billing exhaustion (not transient rate limit) +_BILLING_PATTERNS = [ + "insufficient credits", + "insufficient_quota", + "credit balance", + "credits have been exhausted", + "top up your credits", + "payment required", + "billing hard limit", + "exceeded your current quota", + "account is deactivated", + "plan does not include", +] + +# Patterns that indicate rate limiting (transient, will resolve) +_RATE_LIMIT_PATTERNS = [ + "rate limit", + "rate_limit", + "too many requests", + "throttled", + "requests per minute", + "tokens per minute", + "requests per day", + "try again in", + "please retry after", + "resource_exhausted", +] + +# Usage-limit patterns that need disambiguation (could be billing OR rate_limit) +_USAGE_LIMIT_PATTERNS = [ + "usage limit", + "quota", + "limit exceeded", + "key limit exceeded", +] + +# Patterns confirming usage limit is transient (not billing) +_USAGE_LIMIT_TRANSIENT_SIGNALS = [ + "try again", + "retry", + "resets at", + "reset in", + "wait", + "requests remaining", + "periodic", + "window", +] + +# Payload-too-large patterns detected from message text (no status_code attr). +# Proxies and some backends embed the HTTP status in the error message. +_PAYLOAD_TOO_LARGE_PATTERNS = [ + "request entity too large", + "payload too large", + "error code: 413", +] + +# Context overflow patterns +_CONTEXT_OVERFLOW_PATTERNS = [ + "context length", + "context size", + "maximum context", + "token limit", + "too many tokens", + "reduce the length", + "exceeds the limit", + "context window", + "prompt is too long", + "prompt exceeds max length", + "max_tokens", + "maximum number of tokens", + # Chinese error messages (some providers return these) + "超过最大长度", + "上下文长度", +] + +# Model not found patterns +_MODEL_NOT_FOUND_PATTERNS = [ + "is not a valid model", + "invalid model", + "model not found", + "model_not_found", + "does not exist", + "no such model", + "unknown model", + "unsupported model", +] + +# Auth patterns (non-status-code signals) +_AUTH_PATTERNS = [ + "invalid api key", + "invalid_api_key", + "authentication", + "unauthorized", + "forbidden", + "invalid token", + "token expired", + "token revoked", + "access denied", +] + +# Anthropic thinking block signature patterns +_THINKING_SIG_PATTERNS = [ + "signature", # Combined with "thinking" check +] + +# Transport error type names +_TRANSPORT_ERROR_TYPES = frozenset({ + "ReadTimeout", "ConnectTimeout", "PoolTimeout", + "ConnectError", "RemoteProtocolError", + "ConnectionError", "ConnectionResetError", + "ConnectionAbortedError", "BrokenPipeError", + "TimeoutError", "ReadError", + "ServerDisconnectedError", + # OpenAI SDK errors (not subclasses of Python builtins) + "APIConnectionError", + "APITimeoutError", +}) + +# Server disconnect patterns (no status code, but transport-level) +_SERVER_DISCONNECT_PATTERNS = [ + "server disconnected", + "peer closed connection", + "connection reset by peer", + "connection was closed", + "network connection lost", + "unexpected eof", + "incomplete chunked read", +] + + +# ── Classification pipeline ───────────────────────────────────────────── + +def classify_api_error( + error: Exception, + *, + provider: str = "", + model: str = "", + approx_tokens: int = 0, + context_length: int = 200000, + num_messages: int = 0, +) -> ClassifiedError: + """Classify an API error into a structured recovery recommendation. + + Priority-ordered pipeline: + 1. Special-case provider-specific patterns (thinking sigs, tier gates) + 2. HTTP status code + message-aware refinement + 3. Error code classification (from body) + 4. Message pattern matching (billing vs rate_limit vs context vs auth) + 5. Transport error heuristics + 6. Server disconnect + large session → context overflow + 7. Fallback: unknown (retryable with backoff) + + Args: + error: The exception from the API call. + provider: Current provider name (e.g. "openrouter", "anthropic"). + model: Current model slug. + approx_tokens: Approximate token count of the current context. + context_length: Maximum context length for the current model. + + Returns: + ClassifiedError with reason and recovery action hints. + """ + status_code = _extract_status_code(error) + error_type = type(error).__name__ + body = _extract_error_body(error) + error_code = _extract_error_code(body) + + # Build a comprehensive error message string for pattern matching. + # str(error) alone may not include the body message (e.g. OpenAI SDK's + # APIStatusError.__str__ returns the first arg, not the body). Append + # the body message so patterns like "try again" in 402 disambiguation + # are detected even when only present in the structured body. + # + # Also extract metadata.raw — OpenRouter wraps upstream provider errors + # inside {"error": {"message": "Provider returned error", "metadata": + # {"raw": ""}}} and the real error message (e.g. + # "context length exceeded") is only in the inner JSON. + _raw_msg = str(error).lower() + _body_msg = "" + _metadata_msg = "" + if isinstance(body, dict): + _err_obj = body.get("error", {}) + if isinstance(_err_obj, dict): + _body_msg = (_err_obj.get("message") or "").lower() + # Parse metadata.raw for wrapped provider errors + _metadata = _err_obj.get("metadata", {}) + if isinstance(_metadata, dict): + _raw_json = _metadata.get("raw") or "" + if isinstance(_raw_json, str) and _raw_json.strip(): + try: + import json + _inner = json.loads(_raw_json) + if isinstance(_inner, dict): + _inner_err = _inner.get("error", {}) + if isinstance(_inner_err, dict): + _metadata_msg = (_inner_err.get("message") or "").lower() + except (json.JSONDecodeError, TypeError): + pass + if not _body_msg: + _body_msg = (body.get("message") or "").lower() + # Combine all message sources for pattern matching + parts = [_raw_msg] + if _body_msg and _body_msg not in _raw_msg: + parts.append(_body_msg) + if _metadata_msg and _metadata_msg not in _raw_msg and _metadata_msg not in _body_msg: + parts.append(_metadata_msg) + error_msg = " ".join(parts) + provider_lower = (provider or "").strip().lower() + model_lower = (model or "").strip().lower() + + def _result(reason: FailoverReason, **overrides) -> ClassifiedError: + defaults = { + "reason": reason, + "status_code": status_code, + "provider": provider, + "model": model, + "message": _extract_message(error, body), + } + defaults.update(overrides) + return ClassifiedError(**defaults) + + # ── 1. Provider-specific patterns (highest priority) ──────────── + + # Anthropic thinking block signature invalid (400). + # Don't gate on provider — OpenRouter proxies Anthropic errors, so the + # provider may be "openrouter" even though the error is Anthropic-specific. + # The message pattern ("signature" + "thinking") is unique enough. + if ( + status_code == 400 + and "signature" in error_msg + and "thinking" in error_msg + ): + return _result( + FailoverReason.thinking_signature, + retryable=True, + should_compress=False, + ) + + # Anthropic long-context tier gate (429 "extra usage" + "long context") + if ( + status_code == 429 + and "extra usage" in error_msg + and "long context" in error_msg + ): + return _result( + FailoverReason.long_context_tier, + retryable=True, + should_compress=True, + ) + + # ── 2. HTTP status code classification ────────────────────────── + + if status_code is not None: + classified = _classify_by_status( + status_code, error_msg, error_code, body, + provider=provider_lower, model=model_lower, + approx_tokens=approx_tokens, context_length=context_length, + num_messages=num_messages, + result_fn=_result, + ) + if classified is not None: + return classified + + # ── 3. Error code classification ──────────────────────────────── + + if error_code: + classified = _classify_by_error_code(error_code, error_msg, _result) + if classified is not None: + return classified + + # ── 4. Message pattern matching (no status code) ──────────────── + + classified = _classify_by_message( + error_msg, error_type, + approx_tokens=approx_tokens, + context_length=context_length, + result_fn=_result, + ) + if classified is not None: + return classified + + # ── 5. Server disconnect + large session → context overflow ───── + # Must come BEFORE generic transport error catch — a disconnect on + # a large session is more likely context overflow than a transient + # transport hiccup. Without this ordering, RemoteProtocolError + # always maps to timeout regardless of session size. + + is_disconnect = any(p in error_msg for p in _SERVER_DISCONNECT_PATTERNS) + if is_disconnect and not status_code: + is_large = approx_tokens > context_length * 0.6 or approx_tokens > 120000 or num_messages > 200 + if is_large: + return _result( + FailoverReason.context_overflow, + retryable=True, + should_compress=True, + ) + return _result(FailoverReason.timeout, retryable=True) + + # ── 6. Transport / timeout heuristics ─────────────────────────── + + if error_type in _TRANSPORT_ERROR_TYPES or isinstance(error, (TimeoutError, ConnectionError, OSError)): + return _result(FailoverReason.timeout, retryable=True) + + # ── 7. Fallback: unknown ──────────────────────────────────────── + + return _result(FailoverReason.unknown, retryable=True) + + +# ── Status code classification ────────────────────────────────────────── + +def _classify_by_status( + status_code: int, + error_msg: str, + error_code: str, + body: dict, + *, + provider: str, + model: str, + approx_tokens: int, + context_length: int, + num_messages: int = 0, + result_fn, +) -> Optional[ClassifiedError]: + """Classify based on HTTP status code with message-aware refinement.""" + + if status_code == 401: + # Not retryable on its own — credential pool rotation and + # provider-specific refresh (Codex, Anthropic, Nous) run before + # the retryability check in run_agent.py. If those succeed, the + # loop `continue`s. If they fail, retryable=False ensures we + # hit the client-error abort path (which tries fallback first). + return result_fn( + FailoverReason.auth, + retryable=False, + should_rotate_credential=True, + should_fallback=True, + ) + + if status_code == 403: + # OpenRouter 403 "key limit exceeded" is actually billing + if "key limit exceeded" in error_msg or "spending limit" in error_msg: + return result_fn( + FailoverReason.billing, + retryable=False, + should_rotate_credential=True, + should_fallback=True, + ) + return result_fn( + FailoverReason.auth, + retryable=False, + should_fallback=True, + ) + + if status_code == 402: + return _classify_402(error_msg, result_fn) + + if status_code == 404: + if any(p in error_msg for p in _MODEL_NOT_FOUND_PATTERNS): + return result_fn( + FailoverReason.model_not_found, + retryable=False, + should_fallback=True, + ) + # Generic 404 — could be model or endpoint + return result_fn( + FailoverReason.model_not_found, + retryable=False, + should_fallback=True, + ) + + if status_code == 413: + return result_fn( + FailoverReason.payload_too_large, + retryable=True, + should_compress=True, + ) + + if status_code == 429: + # Already checked long_context_tier above; this is a normal rate limit + return result_fn( + FailoverReason.rate_limit, + retryable=True, + should_rotate_credential=True, + should_fallback=True, + ) + + if status_code == 400: + return _classify_400( + error_msg, error_code, body, + provider=provider, model=model, + approx_tokens=approx_tokens, + context_length=context_length, + num_messages=num_messages, + result_fn=result_fn, + ) + + if status_code in (500, 502): + return result_fn(FailoverReason.server_error, retryable=True) + + if status_code in (503, 529): + return result_fn(FailoverReason.overloaded, retryable=True) + + # Other 4xx — non-retryable + if 400 <= status_code < 500: + return result_fn( + FailoverReason.format_error, + retryable=False, + should_fallback=True, + ) + + # Other 5xx — retryable + if 500 <= status_code < 600: + return result_fn(FailoverReason.server_error, retryable=True) + + return None + + +def _classify_402(error_msg: str, result_fn) -> ClassifiedError: + """Disambiguate 402: billing exhaustion vs transient usage limit. + + The key insight from OpenClaw: some 402s are transient rate limits + disguised as payment errors. "Usage limit, try again in 5 minutes" + is NOT a billing problem — it's a periodic quota that resets. + """ + # Check for transient usage-limit signals first + has_usage_limit = any(p in error_msg for p in _USAGE_LIMIT_PATTERNS) + has_transient_signal = any(p in error_msg for p in _USAGE_LIMIT_TRANSIENT_SIGNALS) + + if has_usage_limit and has_transient_signal: + # Transient quota — treat as rate limit, not billing + return result_fn( + FailoverReason.rate_limit, + retryable=True, + should_rotate_credential=True, + should_fallback=True, + ) + + # Confirmed billing exhaustion + return result_fn( + FailoverReason.billing, + retryable=False, + should_rotate_credential=True, + should_fallback=True, + ) + + +def _classify_400( + error_msg: str, + error_code: str, + body: dict, + *, + provider: str, + model: str, + approx_tokens: int, + context_length: int, + num_messages: int = 0, + result_fn, +) -> ClassifiedError: + """Classify 400 Bad Request — context overflow, format error, or generic.""" + + # Context overflow from 400 + if any(p in error_msg for p in _CONTEXT_OVERFLOW_PATTERNS): + return result_fn( + FailoverReason.context_overflow, + retryable=True, + should_compress=True, + ) + + # Some providers return model-not-found as 400 instead of 404 (e.g. OpenRouter). + if any(p in error_msg for p in _MODEL_NOT_FOUND_PATTERNS): + return result_fn( + FailoverReason.model_not_found, + retryable=False, + should_fallback=True, + ) + + # Some providers return rate limit / billing errors as 400 instead of 429/402. + # Check these patterns before falling through to format_error. + if any(p in error_msg for p in _RATE_LIMIT_PATTERNS): + return result_fn( + FailoverReason.rate_limit, + retryable=True, + should_rotate_credential=True, + should_fallback=True, + ) + if any(p in error_msg for p in _BILLING_PATTERNS): + return result_fn( + FailoverReason.billing, + retryable=False, + should_rotate_credential=True, + should_fallback=True, + ) + + # Generic 400 + large session → probable context overflow + # Anthropic sometimes returns a bare "Error" message when context is too large + err_body_msg = "" + if isinstance(body, dict): + err_obj = body.get("error", {}) + if isinstance(err_obj, dict): + err_body_msg = (err_obj.get("message") or "").strip().lower() + is_generic = len(err_body_msg) < 30 or err_body_msg in ("error", "") + is_large = approx_tokens > context_length * 0.4 or approx_tokens > 80000 or num_messages > 80 + + if is_generic and is_large: + return result_fn( + FailoverReason.context_overflow, + retryable=True, + should_compress=True, + ) + + # Non-retryable format error + return result_fn( + FailoverReason.format_error, + retryable=False, + should_fallback=True, + ) + + +# ── Error code classification ─────────────────────────────────────────── + +def _classify_by_error_code( + error_code: str, error_msg: str, result_fn, +) -> Optional[ClassifiedError]: + """Classify by structured error codes from the response body.""" + code_lower = error_code.lower() + + if code_lower in ("resource_exhausted", "throttled", "rate_limit_exceeded"): + return result_fn( + FailoverReason.rate_limit, + retryable=True, + should_rotate_credential=True, + ) + + if code_lower in ("insufficient_quota", "billing_not_active", "payment_required"): + return result_fn( + FailoverReason.billing, + retryable=False, + should_rotate_credential=True, + should_fallback=True, + ) + + if code_lower in ("model_not_found", "model_not_available", "invalid_model"): + return result_fn( + FailoverReason.model_not_found, + retryable=False, + should_fallback=True, + ) + + if code_lower in ("context_length_exceeded", "max_tokens_exceeded"): + return result_fn( + FailoverReason.context_overflow, + retryable=True, + should_compress=True, + ) + + return None + + +# ── Message pattern classification ────────────────────────────────────── + +def _classify_by_message( + error_msg: str, + error_type: str, + *, + approx_tokens: int, + context_length: int, + result_fn, +) -> Optional[ClassifiedError]: + """Classify based on error message patterns when no status code is available.""" + + # Payload-too-large patterns (from message text when no status_code) + if any(p in error_msg for p in _PAYLOAD_TOO_LARGE_PATTERNS): + return result_fn( + FailoverReason.payload_too_large, + retryable=True, + should_compress=True, + ) + + # Billing patterns + if any(p in error_msg for p in _BILLING_PATTERNS): + return result_fn( + FailoverReason.billing, + retryable=False, + should_rotate_credential=True, + should_fallback=True, + ) + + # Rate limit patterns + if any(p in error_msg for p in _RATE_LIMIT_PATTERNS): + return result_fn( + FailoverReason.rate_limit, + retryable=True, + should_rotate_credential=True, + should_fallback=True, + ) + + # Context overflow patterns + if any(p in error_msg for p in _CONTEXT_OVERFLOW_PATTERNS): + return result_fn( + FailoverReason.context_overflow, + retryable=True, + should_compress=True, + ) + + # Auth patterns + if any(p in error_msg for p in _AUTH_PATTERNS): + return result_fn( + FailoverReason.auth, + retryable=True, + should_rotate_credential=True, + ) + + # Model not found patterns + if any(p in error_msg for p in _MODEL_NOT_FOUND_PATTERNS): + return result_fn( + FailoverReason.model_not_found, + retryable=False, + should_fallback=True, + ) + + return None + + +# ── Helpers ───────────────────────────────────────────────────────────── + +def _extract_status_code(error: Exception) -> Optional[int]: + """Walk the error and its cause chain to find an HTTP status code.""" + current = error + for _ in range(5): # Max depth to prevent infinite loops + code = getattr(current, "status_code", None) + if isinstance(code, int): + return code + # Some SDKs use .status instead of .status_code + code = getattr(current, "status", None) + if isinstance(code, int) and 100 <= code < 600: + return code + # Walk cause chain + cause = getattr(current, "__cause__", None) or getattr(current, "__context__", None) + if cause is None or cause is current: + break + current = cause + return None + + +def _extract_error_body(error: Exception) -> dict: + """Extract the structured error body from an SDK exception.""" + body = getattr(error, "body", None) + if isinstance(body, dict): + return body + # Some errors have .response.json() + response = getattr(error, "response", None) + if response is not None: + try: + json_body = response.json() + if isinstance(json_body, dict): + return json_body + except Exception: + pass + return {} + + +def _extract_error_code(body: dict) -> str: + """Extract an error code string from the response body.""" + if not body: + return "" + error_obj = body.get("error", {}) + if isinstance(error_obj, dict): + code = error_obj.get("code") or error_obj.get("type") or "" + if isinstance(code, str) and code.strip(): + return code.strip() + # Top-level code + code = body.get("code") or body.get("error_code") or "" + if isinstance(code, (str, int)): + return str(code).strip() + return "" + + +def _extract_message(error: Exception, body: dict) -> str: + """Extract the most informative error message.""" + # Try structured body first + if body: + error_obj = body.get("error", {}) + if isinstance(error_obj, dict): + msg = error_obj.get("message", "") + if isinstance(msg, str) and msg.strip(): + return msg.strip()[:500] + msg = body.get("message", "") + if isinstance(msg, str) and msg.strip(): + return msg.strip()[:500] + # Fallback to str(error) + return str(error)[:500] diff --git a/run_agent.py b/run_agent.py index ecd0be656fc..8f60b8f0126 100644 --- a/run_agent.py +++ b/run_agent.py @@ -77,6 +77,7 @@ from hermes_constants import OPENROUTER_BASE_URL # Agent internals extracted to agent/ package for modularity from agent.memory_manager import build_memory_context_block from agent.retry_utils import jittered_backoff +from agent.error_classifier import classify_api_error, FailoverReason from agent.prompt_builder import ( DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS, MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE, @@ -8017,6 +8018,25 @@ class AIAgent: status_code = getattr(api_error, "status_code", None) error_context = self._extract_api_error_context(api_error) + + # ── Classify the error for structured recovery decisions ── + _compressor = getattr(self, "context_compressor", None) + _ctx_len = getattr(_compressor, "context_length", 200000) if _compressor else 200000 + classified = classify_api_error( + api_error, + provider=getattr(self, "provider", "") or "", + model=getattr(self, "model", "") or "", + approx_tokens=approx_tokens, + context_length=_ctx_len, + num_messages=len(api_messages) if api_messages else 0, + ) + logger.debug( + "Error classified: reason=%s status=%s retryable=%s compress=%s rotate=%s fallback=%s", + classified.reason.value, classified.status_code, + classified.retryable, classified.should_compress, + classified.should_rotate_credential, classified.should_fallback, + ) + recovered_with_pool, has_retried_429 = self._recover_with_credential_pool( status_code=status_code, has_retried_429=has_retried_429, @@ -8079,27 +8099,24 @@ class AIAgent: # from all messages so the next retry sends no thinking # blocks at all. One-shot — don't retry infinitely. if ( - self.api_mode == "anthropic_messages" - and status_code == 400 + classified.reason == FailoverReason.thinking_signature and not thinking_sig_retry_attempted ): - _err_msg_lower = str(api_error).lower() - if "signature" in _err_msg_lower and "thinking" in _err_msg_lower: - thinking_sig_retry_attempted = True - for _m in messages: - if isinstance(_m, dict): - _m.pop("reasoning_details", None) - self._vprint( - f"{self.log_prefix}⚠️ Thinking block signature invalid — " - f"stripped all thinking blocks, retrying...", - force=True, - ) - logging.warning( - "%sThinking block signature recovery: stripped " - "reasoning_details from %d messages", - self.log_prefix, len(messages), - ) - continue + thinking_sig_retry_attempted = True + for _m in messages: + if isinstance(_m, dict): + _m.pop("reasoning_details", None) + self._vprint( + f"{self.log_prefix}⚠️ Thinking block signature invalid — " + f"stripped all thinking blocks, retrying...", + force=True, + ) + logging.warning( + "%sThinking block signature recovery: stripped " + "reasoning_details from %d messages", + self.log_prefix, len(messages), + ) + continue retry_count += 1 elapsed_time = time.time() - api_start_time @@ -8156,14 +8173,7 @@ class AIAgent: # is NOT a transient rate limit — retrying or switching # credentials won't help. Reduce context to 200k (the # standard tier) and compress. - # Only applies to Sonnet — Opus 1M is general access. - _is_long_context_tier_error = ( - status_code == 429 - and "extra usage" in error_msg - and "long context" in error_msg - and "sonnet" in self.model.lower() - ) - if _is_long_context_tier_error: + if classified.reason == FailoverReason.long_context_tier: _reduced_ctx = 200000 compressor = self.context_compressor old_ctx = compressor.context_length @@ -8208,13 +8218,9 @@ class AIAgent: # When a fallback model is configured, switch immediately instead # of burning through retries with exponential backoff -- the # primary provider won't recover within the retry window. - is_rate_limited = ( - status_code == 429 - or "rate limit" in error_msg - or "too many requests" in error_msg - or "rate_limit" in error_msg - or "usage limit" in error_msg - or "quota" in error_msg + is_rate_limited = classified.reason in ( + FailoverReason.rate_limit, + FailoverReason.billing, ) if is_rate_limited and self._fallback_index < len(self._fallback_chain): # Don't eagerly fallback if credential pool rotation may @@ -8230,10 +8236,7 @@ class AIAgent: continue is_payload_too_large = ( - status_code == 413 - or 'request entity too large' in error_msg - or 'payload too large' in error_msg - or 'error code: 413' in error_msg + classified.reason == FailoverReason.payload_too_large ) if is_payload_too_large: @@ -8277,64 +8280,12 @@ class AIAgent: } # Check for context-length errors BEFORE generic 4xx handler. - # Local backends (LM Studio, Ollama, llama.cpp) often return - # HTTP 400 with messages like "Context size has been exceeded" - # which must trigger compression, not an immediate abort. - is_context_length_error = any(phrase in error_msg for phrase in [ - 'context length', 'context size', 'maximum context', - 'token limit', 'too many tokens', 'reduce the length', - 'exceeds the limit', 'context window', - 'request entity too large', # OpenRouter/Nous 413 safety net - 'prompt is too long', # Anthropic: "prompt is too long: N tokens > M maximum" - 'prompt exceeds max length', # Z.AI / GLM: generic 400 overflow wording - ]) - - # Fallback heuristic: Anthropic sometimes returns a generic - # 400 invalid_request_error with just "Error" as the message - # when the context is too large. If the error message is very - # short/generic AND the session is large, treat it as a - # probable context-length error and attempt compression rather - # than aborting. This prevents an infinite failure loop where - # each failed message gets persisted, making the session even - # larger. (#1630) - if not is_context_length_error and status_code == 400: - ctx_len = getattr(getattr(self, 'context_compressor', None), 'context_length', 200000) - is_large_session = approx_tokens > ctx_len * 0.4 or len(api_messages) > 80 - is_generic_error = len(error_msg.strip()) < 30 # e.g. just "error" - if is_large_session and is_generic_error: - is_context_length_error = True - self._vprint( - f"{self.log_prefix}⚠️ Generic 400 with large session " - f"(~{approx_tokens:,} tokens, {len(api_messages)} msgs) — " - f"treating as probable context overflow.", - force=True, - ) - - # Server disconnects on large sessions are often caused by - # the request exceeding the provider's context/payload limit - # without a proper HTTP error response. Treat these as - # context-length errors to trigger compression rather than - # burning through retries that will all fail the same way. - # This breaks the death spiral: disconnect → no token data - # → no compression → bigger session → more disconnects. - # (#2153) - if not is_context_length_error and not status_code: - _is_server_disconnect = ( - 'server disconnected' in error_msg - or 'peer closed connection' in error_msg - or error_type in ('ReadError', 'RemoteProtocolError', 'ServerDisconnectedError') - ) - if _is_server_disconnect: - ctx_len = getattr(getattr(self, 'context_compressor', None), 'context_length', 200000) - _is_large = approx_tokens > ctx_len * 0.6 or len(api_messages) > 200 - if _is_large: - is_context_length_error = True - self._vprint( - f"{self.log_prefix}⚠️ Server disconnected with large session " - f"(~{approx_tokens:,} tokens, {len(api_messages)} msgs) — " - f"treating as context-length error, attempting compression.", - force=True, - ) + # The classifier detects context overflow from: explicit error + # messages, generic 400 + large session heuristic (#1630), and + # server disconnect + large session pattern (#2153). + is_context_length_error = ( + classified.reason == FailoverReason.context_overflow + ) if is_context_length_error: compressor = self.context_compressor @@ -8406,35 +8357,30 @@ class AIAgent: "partial": True } - # Check for non-retryable client errors (4xx HTTP status codes). - # These indicate a problem with the request itself (bad model ID, - # invalid API key, forbidden, etc.) and will never succeed on retry. - # Note: 413 and context-length errors are excluded — handled above. - # 429 (rate limit) is transient and MUST be retried with backoff. - # 529 (Anthropic overloaded) is also transient. - # Also catch local validation errors (ValueError, TypeError) — these - # are programming bugs, not transient failures. - # Exclude UnicodeEncodeError — it's a ValueError subclass but is - # handled separately by the surrogate sanitization path above. - _RETRYABLE_STATUS_CODES = {413, 429, 529} + # Check for non-retryable client errors. The classifier + # already accounts for 413, 429, 529 (transient), context + # overflow, and generic-400 heuristics. Local validation + # errors (ValueError, TypeError) are programming bugs. is_local_validation_error = ( isinstance(api_error, (ValueError, TypeError)) and not isinstance(api_error, UnicodeEncodeError) ) - # Detect generic 400s from Anthropic OAuth (transient server-side failures). - # Real invalid_request_error responses include a descriptive message; - # transient ones contain only "Error" or are empty. (ref: issue #1608) - _err_body = getattr(api_error, "body", None) or {} - _err_message = (_err_body.get("error", {}).get("message", "") if isinstance(_err_body, dict) else "") - _is_generic_400 = (status_code == 400 and _err_message.strip().lower() in ("error", "")) - is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 and status_code not in _RETRYABLE_STATUS_CODES and not _is_generic_400 - is_client_error = (is_local_validation_error or is_client_status_error or any(phrase in error_msg for phrase in [ - 'error code: 401', 'error code: 403', - 'error code: 404', 'error code: 422', - 'is not a valid model', 'invalid model', 'model not found', - 'invalid api key', 'invalid_api_key', 'authentication', - 'unauthorized', 'forbidden', 'not found', - ])) and not is_context_length_error + is_client_error = ( + is_local_validation_error + or ( + not classified.retryable + and not classified.should_compress + and classified.reason not in ( + FailoverReason.rate_limit, + FailoverReason.billing, + FailoverReason.overloaded, + FailoverReason.context_overflow, + FailoverReason.payload_too_large, + FailoverReason.long_context_tier, + FailoverReason.thinking_signature, + ) + ) + ) and not is_context_length_error if is_client_error: # Try fallback before aborting — a different provider @@ -8454,7 +8400,7 @@ class AIAgent: self._vprint(f"{self.log_prefix} 🔌 Provider: {_provider} Model: {_model}", force=True) self._vprint(f"{self.log_prefix} 🌐 Endpoint: {_base}", force=True) # Actionable guidance for common auth errors - if status_code in (401, 403) or "unauthorized" in error_msg or "forbidden" in error_msg or "permission" in error_msg: + if classified.is_auth or classified.reason == FailoverReason.billing: if _provider == "openai-codex" and status_code == 401: self._vprint(f"{self.log_prefix} 💡 Codex OAuth token was rejected (HTTP 401). Your token may have been", force=True) self._vprint(f"{self.log_prefix} refreshed by another client (Codex CLI, VS Code). To fix:", force=True) diff --git a/tests/agent/test_error_classifier.py b/tests/agent/test_error_classifier.py new file mode 100644 index 00000000000..da248f82184 --- /dev/null +++ b/tests/agent/test_error_classifier.py @@ -0,0 +1,750 @@ +"""Tests for agent.error_classifier — structured API error classification.""" + +import pytest +from agent.error_classifier import ( + ClassifiedError, + FailoverReason, + classify_api_error, + _extract_status_code, + _extract_error_body, + _extract_error_code, + _classify_402, +) + + +# ── Helper: mock API errors ──────────────────────────────────────────── + +class MockAPIError(Exception): + """Simulates an OpenAI SDK APIStatusError.""" + def __init__(self, message, status_code=None, body=None): + super().__init__(message) + self.status_code = status_code + self.body = body or {} + + +class MockTransportError(Exception): + """Simulates a transport-level error with a specific type name.""" + pass + + +class ReadTimeout(MockTransportError): + pass + + +class ConnectError(MockTransportError): + pass + + +class RemoteProtocolError(MockTransportError): + pass + + +class ServerDisconnectedError(MockTransportError): + pass + + +# ── Test: FailoverReason enum ────────────────────────────────────────── + +class TestFailoverReason: + def test_all_reasons_have_string_values(self): + for reason in FailoverReason: + assert isinstance(reason.value, str) + + def test_enum_members_exist(self): + expected = { + "auth", "auth_permanent", "billing", "rate_limit", + "overloaded", "server_error", "timeout", + "context_overflow", "payload_too_large", + "model_not_found", "format_error", + "thinking_signature", "long_context_tier", "unknown", + } + actual = {r.value for r in FailoverReason} + assert expected == actual + + +# ── Test: ClassifiedError ────────────────────────────────────────────── + +class TestClassifiedError: + def test_is_auth_property(self): + e1 = ClassifiedError(reason=FailoverReason.auth) + assert e1.is_auth is True + + e2 = ClassifiedError(reason=FailoverReason.auth_permanent) + assert e2.is_auth is True + + e3 = ClassifiedError(reason=FailoverReason.billing) + assert e3.is_auth is False + + def test_is_transient_property(self): + transient_reasons = [ + FailoverReason.rate_limit, + FailoverReason.overloaded, + FailoverReason.server_error, + FailoverReason.timeout, + FailoverReason.unknown, + ] + for reason in transient_reasons: + e = ClassifiedError(reason=reason) + assert e.is_transient is True, f"{reason} should be transient" + + non_transient = [ + FailoverReason.auth, + FailoverReason.billing, + FailoverReason.model_not_found, + FailoverReason.format_error, + ] + for reason in non_transient: + e = ClassifiedError(reason=reason) + assert e.is_transient is False, f"{reason} should NOT be transient" + + def test_defaults(self): + e = ClassifiedError(reason=FailoverReason.unknown) + assert e.retryable is True + assert e.should_compress is False + assert e.should_rotate_credential is False + assert e.should_fallback is False + assert e.status_code is None + assert e.message == "" + + +# ── Test: Status code extraction ─────────────────────────────────────── + +class TestExtractStatusCode: + def test_from_status_code_attr(self): + e = MockAPIError("fail", status_code=429) + assert _extract_status_code(e) == 429 + + def test_from_status_attr(self): + class ErrWithStatus(Exception): + status = 503 + assert _extract_status_code(ErrWithStatus()) == 503 + + def test_from_cause_chain(self): + inner = MockAPIError("inner", status_code=401) + outer = Exception("outer") + outer.__cause__ = inner + assert _extract_status_code(outer) == 401 + + def test_none_when_missing(self): + assert _extract_status_code(Exception("generic")) is None + + def test_rejects_non_http_status(self): + """Integers outside 100-599 on .status should be ignored.""" + class ErrWeirdStatus(Exception): + status = 42 + assert _extract_status_code(ErrWeirdStatus()) is None + + +# ── Test: Error body extraction ──────────────────────────────────────── + +class TestExtractErrorBody: + def test_from_body_attr(self): + e = MockAPIError("fail", body={"error": {"message": "bad"}}) + assert _extract_error_body(e) == {"error": {"message": "bad"}} + + def test_empty_when_no_body(self): + assert _extract_error_body(Exception("generic")) == {} + + +# ── Test: Error code extraction ──────────────────────────────────────── + +class TestExtractErrorCode: + def test_from_nested_error_code(self): + body = {"error": {"code": "rate_limit_exceeded"}} + assert _extract_error_code(body) == "rate_limit_exceeded" + + def test_from_nested_error_type(self): + body = {"error": {"type": "invalid_request_error"}} + assert _extract_error_code(body) == "invalid_request_error" + + def test_from_top_level_code(self): + body = {"code": "model_not_found"} + assert _extract_error_code(body) == "model_not_found" + + def test_empty_when_no_code(self): + assert _extract_error_code({}) == "" + assert _extract_error_code({"error": {"message": "oops"}}) == "" + + +# ── Test: 402 disambiguation ─────────────────────────────────────────── + +class TestClassify402: + """The critical 402 billing vs rate_limit disambiguation.""" + + def test_billing_exhaustion(self): + """Plain 402 = billing.""" + result = _classify_402( + "payment required", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.billing + assert result.should_rotate_credential is True + + def test_transient_usage_limit(self): + """402 with 'usage limit' + 'try again' = rate limit, not billing.""" + result = _classify_402( + "usage limit exceeded. try again in 5 minutes", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.rate_limit + assert result.should_rotate_credential is True + + def test_quota_with_retry(self): + """402 with 'quota' + 'retry' = rate limit.""" + result = _classify_402( + "quota exceeded, please retry after the window resets", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.rate_limit + + def test_quota_without_retry(self): + """402 with just 'quota' but no transient signal = billing.""" + result = _classify_402( + "quota exceeded", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.billing + + def test_insufficient_credits(self): + result = _classify_402( + "insufficient credits to complete request", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.billing + + +# ── Test: Full classification pipeline ───────────────────────────────── + +class TestClassifyApiError: + """End-to-end classification tests.""" + + # ── Auth errors ── + + def test_401_classified_as_auth(self): + e = MockAPIError("Unauthorized", status_code=401) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.auth + assert result.should_rotate_credential is True + # 401 is non-retryable on its own — credential rotation runs + # before the retryability check in the agent loop. + assert result.retryable is False + assert result.should_fallback is True + + def test_403_classified_as_auth(self): + e = MockAPIError("Forbidden", status_code=403) + result = classify_api_error(e, provider="anthropic") + assert result.reason == FailoverReason.auth + assert result.should_fallback is True + + def test_403_key_limit_classified_as_billing(self): + """OpenRouter 403 'key limit exceeded' is billing, not auth.""" + e = MockAPIError("Key limit exceeded for this key", status_code=403) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.billing + assert result.should_rotate_credential is True + assert result.should_fallback is True + + def test_403_spending_limit_classified_as_billing(self): + e = MockAPIError("spending limit reached", status_code=403) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.billing + + # ── Billing ── + + def test_402_plain_billing(self): + e = MockAPIError("Payment Required", status_code=402) + result = classify_api_error(e) + assert result.reason == FailoverReason.billing + assert result.retryable is False + + def test_402_transient_usage_limit(self): + e = MockAPIError("usage limit exceeded, try again later", status_code=402) + result = classify_api_error(e) + assert result.reason == FailoverReason.rate_limit + assert result.retryable is True + + # ── Rate limit ── + + def test_429_rate_limit(self): + e = MockAPIError("Too Many Requests", status_code=429) + result = classify_api_error(e) + assert result.reason == FailoverReason.rate_limit + assert result.should_fallback is True + + # ── Server errors ── + + def test_500_server_error(self): + e = MockAPIError("Internal Server Error", status_code=500) + result = classify_api_error(e) + assert result.reason == FailoverReason.server_error + assert result.retryable is True + + def test_502_server_error(self): + e = MockAPIError("Bad Gateway", status_code=502) + result = classify_api_error(e) + assert result.reason == FailoverReason.server_error + + def test_503_overloaded(self): + e = MockAPIError("Service Unavailable", status_code=503) + result = classify_api_error(e) + assert result.reason == FailoverReason.overloaded + + def test_529_anthropic_overloaded(self): + e = MockAPIError("Overloaded", status_code=529) + result = classify_api_error(e) + assert result.reason == FailoverReason.overloaded + + # ── Model not found ── + + def test_404_model_not_found(self): + e = MockAPIError("model not found", status_code=404) + result = classify_api_error(e) + assert result.reason == FailoverReason.model_not_found + assert result.should_fallback is True + assert result.retryable is False + + def test_404_generic(self): + e = MockAPIError("Not Found", status_code=404) + result = classify_api_error(e) + assert result.reason == FailoverReason.model_not_found + + # ── Payload too large ── + + def test_413_payload_too_large(self): + e = MockAPIError("Request Entity Too Large", status_code=413) + result = classify_api_error(e) + assert result.reason == FailoverReason.payload_too_large + assert result.should_compress is True + + # ── Context overflow ── + + def test_400_context_length(self): + e = MockAPIError("context length exceeded: 250000 > 200000", status_code=400) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + assert result.should_compress is True + + def test_400_too_many_tokens(self): + e = MockAPIError("This model's maximum context is 128000 tokens, too many tokens", status_code=400) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + def test_400_prompt_too_long(self): + e = MockAPIError("prompt is too long: 300000 tokens > 200000 maximum", status_code=400) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + def test_400_generic_large_session(self): + """Generic 400 with large session → context overflow heuristic.""" + e = MockAPIError( + "Error", + status_code=400, + body={"error": {"message": "Error"}}, + ) + result = classify_api_error(e, approx_tokens=100000, context_length=200000) + assert result.reason == FailoverReason.context_overflow + + def test_400_generic_small_session_is_format_error(self): + """Generic 400 with small session → format error, not context overflow.""" + e = MockAPIError( + "Error", + status_code=400, + body={"error": {"message": "Error"}}, + ) + result = classify_api_error(e, approx_tokens=1000, context_length=200000) + assert result.reason == FailoverReason.format_error + + # ── Server disconnect + large session ── + + def test_disconnect_large_session_context_overflow(self): + """Server disconnect with large session → context overflow.""" + e = Exception("server disconnected without sending complete message") + result = classify_api_error(e, approx_tokens=150000, context_length=200000) + assert result.reason == FailoverReason.context_overflow + assert result.should_compress is True + + def test_disconnect_small_session_timeout(self): + """Server disconnect with small session → timeout.""" + e = Exception("server disconnected without sending complete message") + result = classify_api_error(e, approx_tokens=5000, context_length=200000) + assert result.reason == FailoverReason.timeout + + # ── Provider-specific: Anthropic thinking signature ── + + def test_anthropic_thinking_signature(self): + e = MockAPIError( + "thinking block has invalid signature", + status_code=400, + ) + result = classify_api_error(e, provider="anthropic") + assert result.reason == FailoverReason.thinking_signature + assert result.retryable is True + + def test_non_anthropic_400_with_signature_not_classified_as_thinking(self): + """400 with 'signature' but from non-Anthropic → format error.""" + e = MockAPIError("invalid signature", status_code=400) + result = classify_api_error(e, provider="openrouter", approx_tokens=0) + # Without "thinking" in the message, it shouldn't be thinking_signature + assert result.reason != FailoverReason.thinking_signature + + # ── Provider-specific: Anthropic long-context tier ── + + def test_anthropic_long_context_tier(self): + e = MockAPIError( + "Extra usage is required for long context requests over 200k tokens", + status_code=429, + ) + result = classify_api_error(e, provider="anthropic", model="claude-sonnet-4") + assert result.reason == FailoverReason.long_context_tier + assert result.should_compress is True + + def test_normal_429_not_long_context(self): + """Normal 429 without 'extra usage' + 'long context' → rate_limit.""" + e = MockAPIError("Too Many Requests", status_code=429) + result = classify_api_error(e, provider="anthropic") + assert result.reason == FailoverReason.rate_limit + + # ── Transport errors ── + + def test_read_timeout(self): + e = ReadTimeout("Read timed out") + result = classify_api_error(e) + assert result.reason == FailoverReason.timeout + assert result.retryable is True + + def test_connect_error(self): + e = ConnectError("Connection refused") + result = classify_api_error(e) + assert result.reason == FailoverReason.timeout + + def test_connection_error_builtin(self): + e = ConnectionError("Connection reset by peer") + result = classify_api_error(e) + assert result.reason == FailoverReason.timeout + + def test_timeout_error_builtin(self): + e = TimeoutError("timed out") + result = classify_api_error(e) + assert result.reason == FailoverReason.timeout + + # ── Error code classification ── + + def test_error_code_resource_exhausted(self): + e = MockAPIError( + "Resource exhausted", + body={"error": {"code": "resource_exhausted", "message": "Too many requests"}}, + ) + result = classify_api_error(e) + assert result.reason == FailoverReason.rate_limit + + def test_error_code_model_not_found(self): + e = MockAPIError( + "Model not available", + body={"error": {"code": "model_not_found"}}, + ) + result = classify_api_error(e) + assert result.reason == FailoverReason.model_not_found + + def test_error_code_context_length_exceeded(self): + e = MockAPIError( + "Context too large", + body={"error": {"code": "context_length_exceeded"}}, + ) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + # ── Message-only patterns (no status code) ── + + def test_message_billing_pattern(self): + e = Exception("insufficient credits to complete this request") + result = classify_api_error(e) + assert result.reason == FailoverReason.billing + + def test_message_rate_limit_pattern(self): + e = Exception("rate limit reached for this model") + result = classify_api_error(e) + assert result.reason == FailoverReason.rate_limit + + def test_message_auth_pattern(self): + e = Exception("invalid api key provided") + result = classify_api_error(e) + assert result.reason == FailoverReason.auth + + def test_message_model_not_found_pattern(self): + e = Exception("gpt-99 is not a valid model") + result = classify_api_error(e) + assert result.reason == FailoverReason.model_not_found + + def test_message_context_overflow_pattern(self): + e = Exception("maximum context length exceeded") + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + # ── Unknown / fallback ── + + def test_generic_exception_is_unknown(self): + e = Exception("something weird happened") + result = classify_api_error(e) + assert result.reason == FailoverReason.unknown + assert result.retryable is True + + # ── Format error ── + + def test_400_descriptive_format_error(self): + """400 with descriptive message (not context overflow) → format error.""" + e = MockAPIError( + "Invalid value for parameter 'temperature': must be between 0 and 2", + status_code=400, + body={"error": {"message": "Invalid value for parameter 'temperature': must be between 0 and 2"}}, + ) + result = classify_api_error(e, approx_tokens=1000) + assert result.reason == FailoverReason.format_error + assert result.retryable is False + + def test_422_format_error(self): + e = MockAPIError("Unprocessable Entity", status_code=422) + result = classify_api_error(e) + assert result.reason == FailoverReason.format_error + assert result.retryable is False + + # ── Peer closed + large session ── + + def test_peer_closed_large_session(self): + e = Exception("peer closed connection without sending complete message") + result = classify_api_error(e, approx_tokens=130000, context_length=200000) + assert result.reason == FailoverReason.context_overflow + + # ── Chinese error messages ── + + def test_chinese_context_overflow(self): + e = MockAPIError("超过最大长度限制", status_code=400) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + # ── Result metadata ── + + def test_provider_and_model_in_result(self): + e = MockAPIError("fail", status_code=500) + result = classify_api_error(e, provider="openrouter", model="gpt-5") + assert result.provider == "openrouter" + assert result.model == "gpt-5" + assert result.status_code == 500 + + def test_message_extracted(self): + e = MockAPIError( + "outer", + status_code=500, + body={"error": {"message": "Internal server error occurred"}}, + ) + result = classify_api_error(e) + assert result.message == "Internal server error occurred" + + +# ── Test: Adversarial / edge cases (from live testing) ───────────────── + +class TestAdversarialEdgeCases: + """Edge cases discovered during live testing with real SDK objects.""" + + def test_empty_exception_message(self): + result = classify_api_error(Exception("")) + assert result.reason == FailoverReason.unknown + assert result.retryable is True + + def test_500_with_none_body(self): + e = MockAPIError("fail", status_code=500, body=None) + result = classify_api_error(e) + assert result.reason == FailoverReason.server_error + + def test_non_dict_body(self): + """Some providers return strings instead of JSON.""" + class StringBodyError(Exception): + status_code = 400 + body = "just a string" + result = classify_api_error(StringBodyError("bad")) + assert result.reason == FailoverReason.format_error + + def test_list_body(self): + class ListBodyError(Exception): + status_code = 500 + body = [{"error": "something"}] + result = classify_api_error(ListBodyError("server error")) + assert result.reason == FailoverReason.server_error + + def test_circular_cause_chain(self): + """Must not infinite-loop on circular __cause__.""" + e = Exception("circular") + e.__cause__ = e + result = classify_api_error(e) + assert result.reason == FailoverReason.unknown + + def test_three_level_cause_chain(self): + inner = MockAPIError("inner", status_code=429) + middle = Exception("middle") + middle.__cause__ = inner + outer = RuntimeError("outer") + outer.__cause__ = middle + result = classify_api_error(outer) + assert result.status_code == 429 + assert result.reason == FailoverReason.rate_limit + + def test_400_with_rate_limit_text(self): + """Some providers send rate limits as 400 instead of 429.""" + e = MockAPIError( + "rate limit policy", + status_code=400, + body={"error": {"message": "rate limit exceeded on this model"}}, + ) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.rate_limit + + def test_400_with_billing_text(self): + """Some providers send billing errors as 400.""" + e = MockAPIError( + "billing", + status_code=400, + body={"error": {"message": "insufficient credits for this request"}}, + ) + result = classify_api_error(e) + assert result.reason == FailoverReason.billing + + def test_200_with_error_body(self): + """200 status with error in body — should be unknown, not crash.""" + class WeirdSuccess(Exception): + status_code = 200 + body = {"error": {"message": "loading"}} + result = classify_api_error(WeirdSuccess("model loading")) + assert result.reason == FailoverReason.unknown + + def test_ollama_context_size_exceeded(self): + e = MockAPIError( + "Error", + status_code=400, + body={"error": {"message": "context size has been exceeded"}}, + ) + result = classify_api_error(e, provider="ollama") + assert result.reason == FailoverReason.context_overflow + + def test_connection_refused_error(self): + e = ConnectionRefusedError("Connection refused: localhost:11434") + result = classify_api_error(e, provider="ollama") + assert result.reason == FailoverReason.timeout + + def test_body_message_enrichment(self): + """Body message must be included in pattern matching even when + str(error) doesn't contain it (OpenAI SDK APIStatusError).""" + e = MockAPIError( + "Usage limit", # str(e) = "usage limit" + status_code=402, + body={"error": {"message": "Usage limit reached, try again in 5 minutes"}}, + ) + result = classify_api_error(e) + # "try again" is only in body, not in str(e) + assert result.reason == FailoverReason.rate_limit + + def test_disconnect_pattern_ordering(self): + """Disconnect + large session must beat generic transport catch.""" + class FakeRemoteProtocol(Exception): + pass + # Type name isn't in _TRANSPORT_ERROR_TYPES but message has disconnect pattern + e = Exception("peer closed connection without sending complete message") + result = classify_api_error(e, approx_tokens=150000, context_length=200000) + assert result.reason == FailoverReason.context_overflow + assert result.should_compress is True + + def test_credit_balance_too_low(self): + e = MockAPIError( + "Credits low", + status_code=402, + body={"error": {"message": "Your credit balance is too low"}}, + ) + result = classify_api_error(e, provider="anthropic") + assert result.reason == FailoverReason.billing + + def test_deepseek_402_chinese(self): + """Chinese billing message should still match billing patterns.""" + # "余额不足" doesn't match English billing patterns, but 402 defaults to billing + e = MockAPIError("余额不足", status_code=402) + result = classify_api_error(e, provider="deepseek") + assert result.reason == FailoverReason.billing + + def test_openrouter_wrapped_context_overflow_in_metadata_raw(self): + """OpenRouter wraps provider errors in metadata.raw JSON string.""" + e = MockAPIError( + "Provider returned error", + status_code=400, + body={ + "error": { + "message": "Provider returned error", + "code": 400, + "metadata": { + "raw": '{"error":{"message":"context length exceeded: 50000 > 32768"}}' + } + } + }, + ) + result = classify_api_error(e, provider="openrouter", approx_tokens=10000) + assert result.reason == FailoverReason.context_overflow + assert result.should_compress is True + + def test_openrouter_wrapped_rate_limit_in_metadata_raw(self): + e = MockAPIError( + "Provider returned error", + status_code=400, + body={ + "error": { + "message": "Provider returned error", + "metadata": { + "raw": '{"error":{"message":"Rate limit exceeded. Please retry after 30s."}}' + } + } + }, + ) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.rate_limit + + def test_thinking_signature_via_openrouter(self): + """Thinking signature errors proxied through OpenRouter must be caught.""" + e = MockAPIError( + "thinking block has invalid signature", + status_code=400, + ) + # provider is openrouter, not anthropic — old code missed this + result = classify_api_error(e, provider="openrouter", model="anthropic/claude-sonnet-4") + assert result.reason == FailoverReason.thinking_signature + + def test_generic_400_large_by_message_count(self): + """Many small messages (>80) should trigger context overflow heuristic.""" + e = MockAPIError( + "Error", + status_code=400, + body={"error": {"message": "Error"}}, + ) + # Low token count but high message count + result = classify_api_error( + e, approx_tokens=5000, context_length=200000, num_messages=100, + ) + assert result.reason == FailoverReason.context_overflow + + def test_disconnect_large_by_message_count(self): + """Server disconnect with 200+ messages should trigger context overflow.""" + e = Exception("server disconnected without sending complete message") + result = classify_api_error( + e, approx_tokens=5000, context_length=200000, num_messages=250, + ) + assert result.reason == FailoverReason.context_overflow + + def test_openrouter_wrapped_model_not_found_in_metadata_raw(self): + e = MockAPIError( + "Provider returned error", + status_code=400, + body={ + "error": { + "message": "Provider returned error", + "metadata": { + "raw": '{"error":{"message":"The model gpt-99 does not exist"}}' + } + } + }, + ) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.model_not_found From 173289b64fbe16ab322d0f0331cb0a3ad27cd5f3 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Thu, 9 Apr 2026 04:11:03 -0700 Subject: [PATCH 44/49] docs: add hermes dump and hermes logs to CLI commands reference (#6552) Documents both debugging commands with full option tables, examples, and usage guidance. Adds both to the top-level commands table and as detailed sections with subsections for log files, filtering behavior, and log rotation. --- website/docs/reference/cli-commands.md | 145 +++++++++++++++++++++++++ 1 file changed, 145 insertions(+) diff --git a/website/docs/reference/cli-commands.md b/website/docs/reference/cli-commands.md index 55983b1c699..a7362b06ff7 100644 --- a/website/docs/reference/cli-commands.md +++ b/website/docs/reference/cli-commands.md @@ -43,6 +43,8 @@ hermes [global-options] [subcommand/options] | `hermes cron` | Inspect and tick the cron scheduler. | | `hermes webhook` | Manage dynamic webhook subscriptions for event-driven activation. | | `hermes doctor` | Diagnose config and dependency issues. | +| `hermes dump` | Copy-pasteable setup summary for support/debugging. | +| `hermes logs` | View, tail, and filter agent/gateway/error log files. | | `hermes config` | Show, edit, migrate, and query configuration files. | | `hermes pairing` | Approve or revoke messaging pairing codes. | | `hermes skills` | Browse, install, publish, audit, and configure skills. | @@ -272,6 +274,149 @@ hermes doctor [--fix] |--------|-------------| | `--fix` | Attempt automatic repairs where possible. | +## `hermes dump` + +```bash +hermes dump [--show-keys] +``` + +Outputs a compact, plain-text summary of your entire Hermes setup. Designed to be copy-pasted into Discord, GitHub issues, or Telegram when asking for support — no ANSI colors, no special formatting, just data. + +| Option | Description | +|--------|-------------| +| `--show-keys` | Show redacted API key prefixes (first and last 4 characters) instead of just `set`/`not set`. | + +### What it includes + +| Section | Details | +|---------|---------| +| **Header** | Hermes version, release date, git commit hash | +| **Environment** | OS, Python version, OpenAI SDK version | +| **Identity** | Active profile name, HERMES_HOME path | +| **Model** | Configured default model and provider | +| **Terminal** | Backend type (local, docker, ssh, etc.) | +| **API keys** | Presence check for all 22 provider/tool API keys | +| **Features** | Enabled toolsets, MCP server count, memory provider | +| **Services** | Gateway status, configured messaging platforms | +| **Workload** | Cron job counts, installed skill count | +| **Config overrides** | Any config values that differ from defaults | + +### Example output + +``` +--- hermes dump --- +version: 0.8.0 (2026.4.8) [af4abd2f] +os: Linux 6.14.0-37-generic x86_64 +python: 3.11.14 +openai_sdk: 2.24.0 +profile: default +hermes_home: ~/.hermes +model: anthropic/claude-opus-4.6 +provider: openrouter +terminal: local + +api_keys: + openrouter set + openai not set + anthropic set + nous not set + firecrawl set + ... + +features: + toolsets: all + mcp_servers: 0 + memory_provider: built-in + gateway: running (systemd) + platforms: telegram, discord + cron_jobs: 3 active / 5 total + skills: 42 + +config_overrides: + agent.max_turns: 250 + compression.threshold: 0.85 + display.streaming: True +--- end dump --- +``` + +### When to use + +- Reporting a bug on GitHub — paste the dump into your issue +- Asking for help in Discord — share it in a code block +- Comparing your setup to someone else's +- Quick sanity check when something isn't working + +:::tip +`hermes dump` is specifically designed for sharing. For interactive diagnostics, use `hermes doctor`. For a visual overview, use `hermes status`. +::: + +## `hermes logs` + +```bash +hermes logs [log_name] [options] +``` + +View, tail, and filter Hermes log files. All logs are stored in `~/.hermes/logs/` (or `/logs/` for non-default profiles). + +### Log files + +| Name | File | What it captures | +|------|------|-----------------| +| `agent` (default) | `agent.log` | All agent activity — API calls, tool dispatch, session lifecycle (INFO and above) | +| `errors` | `errors.log` | Warnings and errors only — a filtered subset of agent.log | +| `gateway` | `gateway.log` | Messaging gateway activity — platform connections, message dispatch, webhook events | + +### Options + +| Option | Description | +|--------|-------------| +| `log_name` | Which log to view: `agent` (default), `errors`, `gateway`, or `list` to show available files with sizes. | +| `-n`, `--lines ` | Number of lines to show (default: 50). | +| `-f`, `--follow` | Follow the log in real time, like `tail -f`. Press Ctrl+C to stop. | +| `--level ` | Minimum log level to show: `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`. | +| `--session ` | Filter lines containing a session ID substring. | +| `--since