Merge branch 'main' of github.com:NousResearch/hermes-agent into fix/show-reasoning-per-platform

This commit is contained in:
Leon 2026-04-20 00:58:25 +08:00
commit 26bd52c1ba
291 changed files with 23713 additions and 2985 deletions

View file

@ -697,7 +697,12 @@ class TestIsConnectionError:
class TestKimiForCodingTemperature:
"""kimi-for-coding now requires temperature=0.6 exactly."""
"""Moonshot kimi-for-coding models require fixed temperatures.
k2.5 / k2-turbo-preview / k2-0905-preview 0.6 (non-thinking lock).
k2-thinking / k2-thinking-turbo 1.0 (thinking lock).
kimi-k2-instruct* and every other model preserve the caller's temperature.
"""
def test_build_call_kwargs_forces_fixed_temperature(self):
from agent.auxiliary_client import _build_call_kwargs
@ -772,12 +777,55 @@ class TestKimiForCodingTemperature:
assert kwargs["model"] == "kimi-for-coding"
assert kwargs["temperature"] == 0.6
def test_non_kimi_model_still_preserves_temperature(self):
@pytest.mark.parametrize(
"model,expected",
[
("kimi-k2.5", 0.6),
("kimi-k2-turbo-preview", 0.6),
("kimi-k2-0905-preview", 0.6),
("kimi-k2-thinking", 1.0),
("kimi-k2-thinking-turbo", 1.0),
("moonshotai/kimi-k2.5", 0.6),
("moonshotai/Kimi-K2-Thinking", 1.0),
],
)
def test_kimi_k2_family_temperature_override(self, model, expected):
"""Moonshot kimi-k2.* models only accept fixed temperatures.
Non-thinking models 0.6, thinking-mode models 1.0.
"""
from agent.auxiliary_client import _build_call_kwargs
kwargs = _build_call_kwargs(
provider="kimi-coding",
model="kimi-k2.5",
model=model,
messages=[{"role": "user", "content": "hello"}],
temperature=0.3,
)
assert kwargs["temperature"] == expected
@pytest.mark.parametrize(
"model",
[
"anthropic/claude-sonnet-4-6",
"gpt-5.4",
# kimi-k2-instruct is the non-coding K2 family — temperature is
# variable (recommended 0.6 but not enforced). Must not clamp.
"kimi-k2-instruct",
"moonshotai/Kimi-K2-Instruct",
"moonshotai/Kimi-K2-Instruct-0905",
"kimi-k2-instruct-0905",
# Hypothetical future kimi name not in the whitelist.
"kimi-k2-experimental",
],
)
def test_non_restricted_model_preserves_temperature(self, model):
from agent.auxiliary_client import _build_call_kwargs
kwargs = _build_call_kwargs(
provider="openrouter",
model=model,
messages=[{"role": "user", "content": "hello"}],
temperature=0.3,
)

View file

@ -781,3 +781,127 @@ class TestTokenBudgetTailProtection:
# Tool at index 2 is outside the protected tail (last 3 = indices 2,3,4)
# so it might or might not be pruned depending on boundary
assert isinstance(pruned, int)
class TestTruncateToolCallArgsJson:
"""Regression tests for #11762.
The previous implementation produced invalid JSON by slicing
``function.arguments`` mid-string, which caused non-retryable 400s from
strict providers (observed on MiniMax) and stuck long sessions in a
re-send loop. The helper here must always emit parseable JSON whose
shape matches the original shrunken, not corrupted.
"""
def _helper(self):
from agent.context_compressor import _truncate_tool_call_args_json
return _truncate_tool_call_args_json
def test_shrunken_args_remain_valid_json(self):
import json as _json
shrink = self._helper()
original = _json.dumps({
"path": "~/.hermes/skills/shopping/browser-setup-notes.md",
"content": "# Shopping Browser Setup Notes\n\n" + "abc " * 400,
})
assert len(original) > 500
shrunk = shrink(original)
parsed = _json.loads(shrunk) # must not raise
assert parsed["path"] == "~/.hermes/skills/shopping/browser-setup-notes.md"
assert parsed["content"].endswith("...[truncated]")
assert len(shrunk) < len(original)
def test_non_json_arguments_pass_through(self):
shrink = self._helper()
not_json = "this is not json at all, " * 50
assert shrink(not_json) == not_json
def test_short_string_leaves_unchanged(self):
import json as _json
shrink = self._helper()
payload = _json.dumps({"command": "ls -la", "cwd": "/tmp"})
assert _json.loads(shrink(payload)) == {"command": "ls -la", "cwd": "/tmp"}
def test_nested_structures_are_walked(self):
import json as _json
shrink = self._helper()
payload = _json.dumps({
"messages": [
{"role": "user", "content": "x" * 500},
{"role": "assistant", "content": "ok"},
],
"meta": {"note": "y" * 500},
})
parsed = _json.loads(shrink(payload))
assert parsed["messages"][0]["content"].endswith("...[truncated]")
assert parsed["messages"][1]["content"] == "ok"
assert parsed["meta"]["note"].endswith("...[truncated]")
def test_non_string_leaves_preserved(self):
import json as _json
shrink = self._helper()
payload = _json.dumps({
"retries": 3,
"enabled": True,
"timeout": None,
"items": [1, 2, 3],
"note": "z" * 500,
})
parsed = _json.loads(shrink(payload))
assert parsed["retries"] == 3
assert parsed["enabled"] is True
assert parsed["timeout"] is None
assert parsed["items"] == [1, 2, 3]
assert parsed["note"].endswith("...[truncated]")
def test_scalar_json_string_gets_shrunk(self):
import json as _json
shrink = self._helper()
payload = _json.dumps("q" * 500)
parsed = _json.loads(shrink(payload))
assert isinstance(parsed, str)
assert parsed.endswith("...[truncated]")
def test_unicode_preserved(self):
import json as _json
shrink = self._helper()
payload = _json.dumps({"content": "非德满" + ("a" * 500)})
out = shrink(payload)
# ensure_ascii=False keeps CJK intact rather than emitting \uXXXX
assert "非德满" in out
def test_pass3_emits_valid_json_for_downstream_provider(self):
"""End-to-end: Pass 3 must never produce the exact failure payload
that caused the 400 loop (unterminated string, missing brace)."""
import json as _json
with patch("agent.context_compressor.get_model_context_length", return_value=100000):
c = ContextCompressor(
model="test/model",
threshold_percent=0.85,
protect_first_n=1,
protect_last_n=1,
quiet_mode=True,
)
huge_content = "# Shopping Browser Setup Notes\n\n## Overview\n" + "x " * 400
args_payload = _json.dumps({
"path": "~/.hermes/skills/shopping/browser-setup-notes.md",
"content": huge_content,
})
assert len(args_payload) > 500 # triggers the Pass-3 shrink
messages = [
{"role": "user", "content": "please write two files"},
{"role": "assistant", "content": None, "tool_calls": [
{"id": "call_1", "type": "function",
"function": {"name": "write_file", "arguments": args_payload}},
]},
{"role": "tool", "tool_call_id": "call_1",
"content": '{"bytes_written": 727}'},
{"role": "user", "content": "ok"},
{"role": "assistant", "content": "done"},
]
result, _ = c._prune_old_tool_results(messages, protect_tail_count=2)
shrunk = result[1]["tool_calls"][0]["function"]["arguments"]
# Must parse — otherwise downstream provider returns 400
parsed = _json.loads(shrunk)
assert parsed["path"] == "~/.hermes/skills/shopping/browser-setup-notes.md"
assert parsed["content"].endswith("...[truncated]")

View file

@ -971,8 +971,6 @@ class TestHonchoCadenceTracking:
class FakeManager:
def prefetch_context(self, key, query=None):
pass
def prefetch_dialectic(self, key, query):
pass
p._manager = FakeManager()

View file

@ -208,34 +208,81 @@ class TestMem0UserIdScoping:
class TestHonchoUserIdScoping:
"""Verify Honcho plugin uses gateway user_id for peer_name when provided."""
"""Verify Honcho plugin keeps runtime user scoping separate from config peer_name."""
def test_gateway_user_id_overrides_peer_name(self):
"""When user_id is in kwargs and no explicit peer_name, user_id should be used."""
def test_gateway_user_id_is_passed_as_runtime_peer(self):
"""Gateway user_id should scope Honcho sessions without mutating config peer_name."""
from plugins.memory.honcho import HonchoMemoryProvider
provider = HonchoMemoryProvider()
# Create a mock config with NO explicit peer_name
mock_cfg = MagicMock()
mock_cfg.enabled = True
mock_cfg.api_key = "test-key"
mock_cfg.base_url = None
mock_cfg.peer_name = "" # No explicit peer_name — user_id should fill it
mock_cfg.recall_mode = "tools" # Use tools mode to defer session init
mock_cfg.peer_name = "static-user"
mock_cfg.recall_mode = "context"
mock_cfg.context_tokens = None
mock_cfg.raw = {}
mock_cfg.dialectic_depth = 1
mock_cfg.dialectic_depth_levels = None
mock_cfg.init_on_session_start = False
mock_cfg.ai_peer = "hermes"
mock_cfg.resolve_session_name.return_value = "test-sess"
mock_cfg.session_strategy = "shared"
with patch(
"plugins.memory.honcho.client.HonchoClientConfig.from_global_config",
return_value=mock_cfg,
):
), patch(
"plugins.memory.honcho.client.get_honcho_client",
return_value=MagicMock(),
), patch(
"plugins.memory.honcho.session.HonchoSessionManager",
) as mock_manager_cls:
mock_manager = MagicMock()
mock_manager.get_or_create.return_value = MagicMock(messages=[])
mock_manager_cls.return_value = mock_manager
provider.initialize(
session_id="test-sess",
user_id="discord_user_789",
platform="discord",
)
# The config's peer_name should have been overridden with the user_id
assert mock_cfg.peer_name == "discord_user_789"
assert mock_cfg.peer_name == "static-user"
assert mock_manager_cls.call_args.kwargs["runtime_user_peer_name"] == "discord_user_789"
def test_session_manager_prefers_runtime_user_id_over_config_peer_name(self):
"""Session manager should isolate gateway users even when config peer_name is static."""
from plugins.memory.honcho.session import HonchoSessionManager
mock_cfg = MagicMock()
mock_cfg.peer_name = "static-user"
mock_cfg.ai_peer = "hermes"
mock_cfg.write_frequency = "sync"
mock_cfg.dialectic_reasoning_level = "low"
mock_cfg.dialectic_dynamic = True
mock_cfg.dialectic_max_chars = 600
mock_cfg.observation_mode = "directional"
mock_cfg.user_observe_me = True
mock_cfg.user_observe_others = True
mock_cfg.ai_observe_me = True
mock_cfg.ai_observe_others = True
manager = HonchoSessionManager(
honcho=MagicMock(),
config=mock_cfg,
runtime_user_peer_name="discord_user_789",
)
with patch.object(manager, "_get_or_create_peer", return_value=MagicMock()), patch.object(
manager,
"_get_or_create_honcho_session",
return_value=(MagicMock(), []),
):
session = manager.get_or_create("discord:channel-1")
assert session.user_peer_id == "discord_user_789"
def test_no_user_id_preserves_config_peer_name(self):
"""Without user_id, the config peer_name should be preserved."""

View file

@ -79,7 +79,7 @@ class TestBuildChildProgressCallback:
parent._delegate_spinner = None
parent.tool_progress_callback = None
cb = _build_child_progress_callback(0, parent)
cb = _build_child_progress_callback(0, "test goal", parent)
assert cb is None
def test_cli_spinner_tool_event(self):
@ -93,7 +93,7 @@ class TestBuildChildProgressCallback:
parent._delegate_spinner = spinner
parent.tool_progress_callback = None
cb = _build_child_progress_callback(0, parent)
cb = _build_child_progress_callback(0, "test goal", parent)
assert cb is not None
cb("tool.started", "web_search", "quantum computing", {})
@ -113,7 +113,7 @@ class TestBuildChildProgressCallback:
parent._delegate_spinner = spinner
parent.tool_progress_callback = None
cb = _build_child_progress_callback(0, parent)
cb = _build_child_progress_callback(0, "test goal", parent)
cb("_thinking", "I'll search for papers first")
output = buf.getvalue()
@ -121,54 +121,64 @@ class TestBuildChildProgressCallback:
assert "search for papers" in output
def test_gateway_batched_progress(self):
"""Gateway path should batch tool calls and flush at BATCH_SIZE."""
"""Gateway path: each tool.started relays a subagent.tool event, and a
subagent.progress summary fires once BATCH_SIZE tools accumulate."""
parent = MagicMock()
parent._delegate_spinner = None
parent_cb = MagicMock()
parent.tool_progress_callback = parent_cb
cb = _build_child_progress_callback(0, parent)
# Send 4 tool calls — shouldn't flush yet (BATCH_SIZE = 5)
cb = _build_child_progress_callback(0, "test goal", parent)
# Each tool.started relays a subagent.tool event immediately (per-tool relay).
for i in range(4):
cb("tool.started", f"tool_{i}", f"arg_{i}", {})
parent_cb.assert_not_called()
# 5th call should trigger flush
cb("tool.started", "tool_4", "arg_4", {})
parent_cb.assert_called_once()
call_args = parent_cb.call_args
assert "tool_0" in call_args[0][1]
assert "tool_4" in call_args[0][1]
# 4 per-tool relays so far, no batch summary yet (BATCH_SIZE=5)
events = [c.args[0] for c in parent_cb.call_args_list]
assert events == ["subagent.tool"] * 4
def test_thinking_not_relayed_to_gateway(self):
"""Thinking events should NOT be sent to gateway (too noisy)."""
# 5th call triggers another per-tool relay PLUS the batch-size summary
cb("tool.started", "tool_4", "arg_4", {})
events = [c.args[0] for c in parent_cb.call_args_list]
assert events == ["subagent.tool"] * 5 + ["subagent.progress"]
summary_call = parent_cb.call_args_list[-1]
summary_text = summary_call.kwargs.get("preview") or summary_call.args[2]
assert "tool_0" in summary_text
assert "tool_4" in summary_text
def test_thinking_relayed_to_gateway(self):
"""Thinking events are relayed as subagent.thinking events."""
parent = MagicMock()
parent._delegate_spinner = None
parent_cb = MagicMock()
parent.tool_progress_callback = parent_cb
cb = _build_child_progress_callback(0, parent)
cb = _build_child_progress_callback(0, "test goal", parent)
cb("_thinking", "some reasoning text")
parent_cb.assert_not_called()
parent_cb.assert_called_once()
assert parent_cb.call_args.args[0] == "subagent.thinking"
assert parent_cb.call_args.args[2] == "some reasoning text"
def test_parallel_callbacks_independent(self):
"""Each child's callback should have independent batch state."""
"""Each child's callback batches tool names independently."""
parent = MagicMock()
parent._delegate_spinner = None
parent_cb = MagicMock()
parent.tool_progress_callback = parent_cb
cb0 = _build_child_progress_callback(0, parent)
cb1 = _build_child_progress_callback(1, parent)
# Send 3 calls to each — neither should flush (batch size = 5)
cb0 = _build_child_progress_callback(0, "goal a", parent)
cb1 = _build_child_progress_callback(1, "goal b", parent)
# 3 tool.started per child = 6 per-tool relays; neither should hit
# the batch-size summary (batch size = 5, counted per-child).
for i in range(3):
cb0(f"tool_{i}")
cb1(f"other_{i}")
parent_cb.assert_not_called()
cb0("tool.started", f"tool_{i}", f"a_{i}", {})
cb1("tool.started", f"other_{i}", f"b_{i}", {})
events = [c.args[0] for c in parent_cb.call_args_list]
assert events.count("subagent.tool") == 6
assert "subagent.progress" not in events
def test_task_index_prefix_in_batch_mode(self):
"""Batch mode (task_count > 1) should show 1-indexed prefix for all tasks."""
@ -182,7 +192,7 @@ class TestBuildChildProgressCallback:
parent.tool_progress_callback = None
# task_index=0 in a batch of 3 → prefix "[1]"
cb0 = _build_child_progress_callback(0, parent, task_count=3)
cb0 = _build_child_progress_callback(0, "test goal", parent, task_count=3)
cb0("web_search", "test")
output = buf.getvalue()
assert "[1]" in output
@ -190,7 +200,7 @@ class TestBuildChildProgressCallback:
# task_index=2 in a batch of 3 → prefix "[3]"
buf.truncate(0)
buf.seek(0)
cb2 = _build_child_progress_callback(2, parent, task_count=3)
cb2 = _build_child_progress_callback(2, "test goal", parent, task_count=3)
cb2("web_search", "test")
output = buf.getvalue()
assert "[3]" in output
@ -206,7 +216,7 @@ class TestBuildChildProgressCallback:
parent._delegate_spinner = spinner
parent.tool_progress_callback = None
cb = _build_child_progress_callback(0, parent, task_count=1)
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
cb("tool.started", "web_search", "test", {})
output = buf.getvalue()
@ -321,26 +331,31 @@ class TestBatchFlush:
"""Tests for gateway batch flush on subagent completion."""
def test_flush_sends_remaining_batch(self):
"""_flush should send remaining tool names to gateway."""
"""_flush should send a final subagent.progress summary of any unsent
tool names in the batch (less than BATCH_SIZE)."""
parent = MagicMock()
parent._delegate_spinner = None
parent_cb = MagicMock()
parent.tool_progress_callback = parent_cb
cb = _build_child_progress_callback(0, parent)
cb = _build_child_progress_callback(0, "test goal", parent)
# Send 3 tools (below batch size of 5)
# Send 3 tools (below batch size of 5) — each relays subagent.tool
cb("tool.started", "web_search", "query1", {})
cb("tool.started", "read_file", "file.txt", {})
cb("tool.started", "write_file", "out.txt", {})
parent_cb.assert_not_called()
events = [c.args[0] for c in parent_cb.call_args_list]
assert events == ["subagent.tool"] * 3 # per-tool relays so far
assert "subagent.progress" not in events # no batch-size summary yet
# Flush should send the remaining 3
# Flush should send the remaining 3 as a summary
cb._flush()
parent_cb.assert_called_once()
summary = parent_cb.call_args[0][1]
assert "web_search" in summary
assert "write_file" in summary
events = [c.args[0] for c in parent_cb.call_args_list]
assert events[-1] == "subagent.progress"
summary_call = parent_cb.call_args_list[-1]
summary_text = summary_call.kwargs.get("preview") or summary_call.args[2]
assert "web_search" in summary_text
assert "write_file" in summary_text
def test_flush_noop_when_batch_empty(self):
"""_flush should not send anything when batch is empty."""
@ -349,7 +364,7 @@ class TestBatchFlush:
parent_cb = MagicMock()
parent.tool_progress_callback = parent_cb
cb = _build_child_progress_callback(0, parent)
cb = _build_child_progress_callback(0, "test goal", parent)
cb._flush()
parent_cb.assert_not_called()
@ -364,7 +379,7 @@ class TestBatchFlush:
parent._delegate_spinner = spinner
parent.tool_progress_callback = None
cb = _build_child_progress_callback(0, parent)
cb = _build_child_progress_callback(0, "test goal", parent)
cb("tool.started", "web_search", "test", {})
cb._flush() # Should not crash

View file

@ -237,6 +237,13 @@ class TestCLIStatusBar:
cli_obj._spinner_text = ""
assert cli_obj._spinner_widget_height(width=90) == 0
def test_spinner_height_uses_display_width_for_wide_characters(self):
cli_obj = _make_cli()
cli_obj._spinner_text = "" * 40
cli_obj._tool_start_time = 0
assert cli_obj._spinner_widget_height(width=64) == 2
def test_voice_status_bar_compacts_on_narrow_terminals(self):
cli_obj = _make_cli()
cli_obj._voice_mode = True

View file

@ -0,0 +1,21 @@
from unittest.mock import MagicMock, patch
def test_gquota_uses_chat_console_when_tui_is_live():
from agent.google_oauth import GoogleOAuthError
from cli import HermesCLI
cli = HermesCLI.__new__(HermesCLI)
cli.console = MagicMock()
cli._app = object()
live_console = MagicMock()
with patch("cli.ChatConsole", return_value=live_console), \
patch("agent.google_oauth.get_valid_access_token", side_effect=GoogleOAuthError("No Google OAuth credentials found")), \
patch("agent.google_oauth.load_credentials", return_value=None), \
patch("agent.google_code_assist.retrieve_user_quota"):
cli._handle_gquota_command("/gquota")
assert live_console.print.call_count == 2
cli.console.print.assert_not_called()

View file

@ -33,6 +33,20 @@ class TestCLIQuickCommands:
printed = self._printed_plain(cli.console.print.call_args[0][0])
assert printed == "daily-note"
def test_exec_command_uses_chat_console_when_tui_is_live(self):
cli = self._make_cli({"dn": {"type": "exec", "command": "echo daily-note"}})
cli._app = object()
live_console = MagicMock()
with patch("cli.ChatConsole", return_value=live_console):
result = cli.process_command("/dn")
assert result is True
live_console.print.assert_called_once()
printed = self._printed_plain(live_console.print.call_args[0][0])
assert printed == "daily-note"
cli.console.print.assert_not_called()
def test_exec_command_stderr_shown_on_no_stdout(self):
cli = self._make_cli({"err": {"type": "exec", "command": "echo error >&2"}})
result = cli.process_command("/err")

View file

@ -473,6 +473,7 @@ class TestInlineThinkBlockExtraction(unittest.TestCase):
agent.verbose_logging = False
agent.reasoning_callback = None
agent.stream_delta_callback = None # non-streaming by default
agent._stream_callback = None # non-streaming by default
return agent
def test_single_think_block_extracted(self):
@ -619,6 +620,7 @@ class TestReasoningDeltasFiredFlag(unittest.TestCase):
agent = AIAgent.__new__(AIAgent)
agent.reasoning_callback = None
agent.stream_delta_callback = None
agent._stream_callback = None
agent.verbose_logging = False
return agent

View file

@ -344,6 +344,127 @@ class TestDisplayResumedHistory:
assert "Just thinking" not in output
assert "Hi there!" in output
def test_think_tags_stripped(self):
"""<think>...</think> blocks should be stripped from display (#11316)."""
cli = _make_cli()
cli.conversation_history = [
{"role": "user", "content": "Solve this"},
{
"role": "assistant",
"content": "<think>\nI need to reason carefully here.\n</think>\n\nThe answer is 7.",
},
]
output = self._capture_display(cli)
assert "<think>" not in output
assert "</think>" not in output
assert "I need to reason carefully here" not in output
assert "The answer is 7" in output
def test_thinking_tags_stripped(self):
"""<thinking>...</thinking> blocks should be stripped from display."""
cli = _make_cli()
cli.conversation_history = [
{"role": "user", "content": "What is 2+2?"},
{
"role": "assistant",
"content": "<thinking>\nLet me compute: 2 + 2 = 4\n</thinking>\n\nThe answer is 4.",
},
]
output = self._capture_display(cli)
assert "<thinking>" not in output
assert "Let me compute" not in output
assert "The answer is 4" in output
def test_reasoning_tags_stripped(self):
"""<reasoning>...</reasoning> blocks should be stripped from display."""
cli = _make_cli()
cli.conversation_history = [
{"role": "user", "content": "Explain gravity"},
{
"role": "assistant",
"content": (
"<reasoning>\nGravity is a fundamental force...\n</reasoning>\n\n"
"Gravity pulls objects together."
),
},
]
output = self._capture_display(cli)
assert "<reasoning>" not in output
assert "fundamental force" not in output
assert "Gravity pulls objects together" in output
def test_thought_tags_stripped(self):
"""<thought>...</thought> blocks (Gemma 4) should be stripped."""
cli = _make_cli()
cli.conversation_history = [
{"role": "user", "content": "Say hello"},
{
"role": "assistant",
"content": "<thought>\nInternal thought here.\n</thought>\n\nHello!",
},
]
output = self._capture_display(cli)
assert "<thought>" not in output
assert "Internal thought here" not in output
assert "Hello!" in output
def test_unclosed_think_tag_stripped(self):
"""Unclosed <think> (truncated generation) should not leak reasoning."""
cli = _make_cli()
cli.conversation_history = [
{"role": "user", "content": "Truncated response"},
{
"role": "assistant",
"content": "Some text before.\n<think>\nUnfinished reasoning...",
},
]
output = self._capture_display(cli)
assert "<think>" not in output
assert "Unfinished reasoning" not in output
assert "Some text before" in output
def test_multiple_reasoning_blocks_all_stripped(self):
"""Multiple interleaved reasoning blocks are all stripped."""
cli = _make_cli()
cli.conversation_history = [
{"role": "user", "content": "Complex question"},
{
"role": "assistant",
"content": (
"<think>\nFirst thought.\n</think>\n"
"Partial text.\n"
"<reasoning>\nSecond thought.\n</reasoning>\n"
"Final answer."
),
},
]
output = self._capture_display(cli)
assert "First thought" not in output
assert "Second thought" not in output
assert "Partial text" in output
assert "Final answer" in output
def test_orphan_closing_think_tag_stripped(self):
"""A stray </think> with no matching open should not render to user."""
cli = _make_cli()
cli.conversation_history = [
{"role": "user", "content": "Broken output"},
{
"role": "assistant",
"content": "some leftover reasoning</think>Visible answer.",
},
]
output = self._capture_display(cli)
assert "</think>" not in output
assert "Visible answer" in output
def test_assistant_with_text_and_tool_calls(self):
"""When an assistant message has both text content AND tool_calls."""
cli = _make_cli()

View file

@ -1024,7 +1024,7 @@ class TestRunJobSkillBacked:
"id": "multi-skill-job",
"name": "multi skill test",
"prompt": "Combine the results.",
"skills": ["blogwatcher", "find-nearby"],
"skills": ["blogwatcher", "maps"],
}
fake_db = MagicMock()
@ -1057,12 +1057,12 @@ class TestRunJobSkillBacked:
assert error is None
assert final_response == "ok"
assert skill_view_mock.call_count == 2
assert [call.args[0] for call in skill_view_mock.call_args_list] == ["blogwatcher", "find-nearby"]
assert [call.args[0] for call in skill_view_mock.call_args_list] == ["blogwatcher", "maps"]
prompt_arg = mock_agent.run_conversation.call_args.args[0]
assert prompt_arg.index("blogwatcher") < prompt_arg.index("find-nearby")
assert prompt_arg.index("blogwatcher") < prompt_arg.index("maps")
assert "Instructions for blogwatcher." in prompt_arg
assert "Instructions for find-nearby." in prompt_arg
assert "Instructions for maps." in prompt_arg
assert "Combine the results." in prompt_arg
@ -1175,6 +1175,180 @@ class TestBuildJobPromptSilentHint:
assert system_pos < prompt_pos
class TestParseWakeGate:
"""Unit tests for _parse_wake_gate — pure function, no side effects."""
def test_empty_output_wakes(self):
from cron.scheduler import _parse_wake_gate
assert _parse_wake_gate("") is True
assert _parse_wake_gate(None) is True
def test_whitespace_only_wakes(self):
from cron.scheduler import _parse_wake_gate
assert _parse_wake_gate(" \n\n \t\n") is True
def test_non_json_last_line_wakes(self):
from cron.scheduler import _parse_wake_gate
assert _parse_wake_gate("hello world") is True
assert _parse_wake_gate("line 1\nline 2\nplain text") is True
def test_json_non_dict_wakes(self):
"""Bare arrays, numbers, strings must not be interpreted as a gate."""
from cron.scheduler import _parse_wake_gate
assert _parse_wake_gate("[1, 2, 3]") is True
assert _parse_wake_gate("42") is True
assert _parse_wake_gate('"wakeAgent"') is True
def test_wake_gate_false_skips(self):
from cron.scheduler import _parse_wake_gate
assert _parse_wake_gate('{"wakeAgent": false}') is False
def test_wake_gate_true_wakes(self):
from cron.scheduler import _parse_wake_gate
assert _parse_wake_gate('{"wakeAgent": true}') is True
def test_wake_gate_missing_wakes(self):
"""A JSON dict without a wakeAgent key defaults to waking."""
from cron.scheduler import _parse_wake_gate
assert _parse_wake_gate('{"data": {"foo": "bar"}}') is True
def test_non_boolean_false_still_wakes(self):
"""Only strict ``False`` skips — truthy/falsy shortcuts are too risky."""
from cron.scheduler import _parse_wake_gate
assert _parse_wake_gate('{"wakeAgent": 0}') is True
assert _parse_wake_gate('{"wakeAgent": null}') is True
assert _parse_wake_gate('{"wakeAgent": ""}') is True
def test_only_last_non_empty_line_parsed(self):
from cron.scheduler import _parse_wake_gate
multi = 'some log output\nmore output\n{"wakeAgent": false}'
assert _parse_wake_gate(multi) is False
def test_trailing_blank_lines_ignored(self):
from cron.scheduler import _parse_wake_gate
multi = '{"wakeAgent": false}\n\n\n'
assert _parse_wake_gate(multi) is False
def test_non_last_json_line_does_not_gate(self):
"""A JSON gate on an earlier line with plain text after it does NOT trigger."""
from cron.scheduler import _parse_wake_gate
multi = '{"wakeAgent": false}\nactually this is the real output'
assert _parse_wake_gate(multi) is True
class TestRunJobWakeGate:
"""Integration tests for run_job wake-gate short-circuit."""
def _make_job(self, name="wake-gate-test", script="check.py"):
"""Minimal valid cron job dict for run_job."""
return {
"id": f"job_{name}",
"name": name,
"prompt": "Do a thing",
"schedule": "*/5 * * * *",
"script": script,
}
def test_wake_false_skips_agent_and_returns_silent(self, caplog):
"""When _run_job_script output ends with {wakeAgent: false}, the agent
is not invoked and run_job returns the SILENT marker so delivery is
suppressed."""
from cron.scheduler import SILENT_MARKER
import cron.scheduler as scheduler
with patch.object(scheduler, "_run_job_script",
return_value=(True, '{"wakeAgent": false}')), \
patch("run_agent.AIAgent") as agent_cls:
success, doc, final, err = scheduler.run_job(self._make_job())
assert success is True
assert err is None
assert final == SILENT_MARKER
assert "Script gate returned `wakeAgent=false`" in doc
agent_cls.assert_not_called()
def test_wake_true_runs_agent_with_injected_output(self):
"""When the script returns {wakeAgent: true, data: ...}, the agent is
invoked and the data line still shows up in the prompt."""
import cron.scheduler as scheduler
script_output = '{"wakeAgent": true, "data": {"new": 3}}'
agent = MagicMock()
agent.run_conversation = MagicMock(return_value={
"final_response": "ok", "messages": []
})
with patch.object(scheduler, "_run_job_script",
return_value=(True, script_output)), \
patch("run_agent.AIAgent", return_value=agent) as agent_cls:
success, doc, final, err = scheduler.run_job(self._make_job())
agent_cls.assert_called_once()
# The script output should be visible in the prompt passed to
# run_conversation.
call_kwargs = agent.run_conversation.call_args
prompt_arg = call_kwargs.args[0] if call_kwargs.args else call_kwargs.kwargs.get("user_message", "")
assert script_output in prompt_arg
assert success is True
assert err is None
def test_script_runs_only_once_on_wake(self):
"""Wake-true path must not re-run the script inside _build_job_prompt
(script would execute twice otherwise, wasting work and risking
double-side-effects)."""
import cron.scheduler as scheduler
call_count = 0
def _script_stub(path):
nonlocal call_count
call_count += 1
return (True, "regular output")
agent = MagicMock()
agent.run_conversation = MagicMock(return_value={
"final_response": "ok", "messages": []
})
with patch.object(scheduler, "_run_job_script", side_effect=_script_stub), \
patch("run_agent.AIAgent", return_value=agent):
scheduler.run_job(self._make_job())
assert call_count == 1, f"script ran {call_count}x, expected exactly 1"
def test_script_failure_does_not_trigger_gate(self):
"""If _run_job_script returns success=False, the gate is NOT evaluated
and the agent still runs (the failure is reported as context)."""
import cron.scheduler as scheduler
# Malicious or broken script whose stderr happens to contain the
# gate JSON — we must NOT honor it because ran_ok is False.
agent = MagicMock()
agent.run_conversation = MagicMock(return_value={
"final_response": "ok", "messages": []
})
with patch.object(scheduler, "_run_job_script",
return_value=(False, '{"wakeAgent": false}')), \
patch("run_agent.AIAgent", return_value=agent) as agent_cls:
success, doc, final, err = scheduler.run_job(self._make_job())
agent_cls.assert_called_once() # Agent DID wake despite the gate-like text
def test_no_script_path_runs_agent_normally(self):
"""Regression: jobs without a script still work."""
import cron.scheduler as scheduler
agent = MagicMock()
agent.run_conversation = MagicMock(return_value={
"final_response": "ok", "messages": []
})
job = self._make_job(script=None)
job.pop("script", None)
with patch.object(scheduler, "_run_job_script") as script_fn, \
patch("run_agent.AIAgent", return_value=agent) as agent_cls:
scheduler.run_job(job)
script_fn.assert_not_called()
agent_cls.assert_called_once()
class TestBuildJobPromptMissingSkill:
"""Verify that a missing skill logs a warning and does not crash the job."""

View file

@ -0,0 +1,148 @@
"""Regression test: cancel_background_tasks must drain late-arrival tasks.
During gateway shutdown, a message arriving while
cancel_background_tasks is mid-await can spawn a fresh
_process_message_background task via handle_message, which is added
to self._background_tasks. Without the re-drain loop, the subsequent
_background_tasks.clear() drops the reference; the task runs
untracked against a disconnecting adapter.
"""
import asyncio
from unittest.mock import AsyncMock
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
from gateway.session import SessionSource, build_session_key
class _StubAdapter(BasePlatformAdapter):
async def connect(self):
pass
async def disconnect(self):
pass
async def send(self, chat_id, text, **kwargs):
return None
async def get_chat_info(self, chat_id):
return {}
def _make_adapter():
adapter = _StubAdapter(PlatformConfig(enabled=True, token="t"), Platform.TELEGRAM)
adapter._send_with_retry = AsyncMock(return_value=None)
return adapter
def _event(text, cid="42"):
return MessageEvent(
text=text,
message_type=MessageType.TEXT,
source=SessionSource(platform=Platform.TELEGRAM, chat_id=cid, chat_type="dm"),
)
@pytest.mark.asyncio
async def test_cancel_background_tasks_drains_late_arrivals():
"""A message that arrives during the gather window must be picked
up by the re-drain loop, not leaked as an untracked task."""
adapter = _make_adapter()
sk = build_session_key(
SessionSource(platform=Platform.TELEGRAM, chat_id="42", chat_type="dm")
)
m1_started = asyncio.Event()
m1_cleanup_running = asyncio.Event()
m2_started = asyncio.Event()
m2_cancelled = asyncio.Event()
async def handler(event):
if event.text == "M1":
m1_started.set()
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
m1_cleanup_running.set()
# Widen the gather window with a shielded cleanup
# delay so M2 can get injected during it.
await asyncio.shield(asyncio.sleep(0.2))
raise
else: # M2 — the late arrival
m2_started.set()
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
m2_cancelled.set()
raise
adapter._message_handler = handler
# Spawn M1.
await adapter.handle_message(_event("M1"))
await asyncio.wait_for(m1_started.wait(), timeout=1.0)
# Kick off shutdown. This will cancel M1 and await its cleanup.
cancel_task = asyncio.create_task(adapter.cancel_background_tasks())
# Wait until M1's cleanup is running (inside the shielded sleep).
# This is the race window: cancel_task is awaiting gather, M1 is
# shielded in cleanup, the _active_sessions entry has been cleared
# by M1's own finally.
await asyncio.wait_for(m1_cleanup_running.wait(), timeout=1.0)
# Clear the active-session entry (M1's finally hasn't fully run yet,
# but in production the platform dispatcher would deliver a new
# message that takes the no-active-session spawn path). For this
# repro, make it deterministic.
adapter._active_sessions.pop(sk, None)
# Inject late arrival — spawns a fresh _process_message_background
# task and adds it to _background_tasks while cancel_task is still
# in gather.
await adapter.handle_message(_event("M2"))
await asyncio.wait_for(m2_started.wait(), timeout=1.0)
# Let cancel_task finish. Round 1's gather completes when M1's
# shielded cleanup finishes. Round 2 should pick up M2.
await asyncio.wait_for(cancel_task, timeout=5.0)
# Assert M2 was drained, not leaked.
assert m2_cancelled.is_set(), (
"Late-arrival M2 was NOT cancelled by cancel_background_tasks — "
"the re-drain loop is missing and the task leaked"
)
assert adapter._background_tasks == set()
@pytest.mark.asyncio
async def test_cancel_background_tasks_handles_no_tasks():
"""Regression guard: no tasks, no hang, no error."""
adapter = _make_adapter()
await adapter.cancel_background_tasks()
assert adapter._background_tasks == set()
@pytest.mark.asyncio
async def test_cancel_background_tasks_bounded_rounds():
"""Regression guard: the drain loop is bounded — it does not spin
forever even if late-arrival tasks keep getting spawned."""
adapter = _make_adapter()
# Single well-behaved task that cancels cleanly — baseline check
# that the loop terminates in one round.
async def quick():
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
raise
task = asyncio.create_task(quick())
adapter._background_tasks.add(task)
await adapter.cancel_background_tasks()
assert task.done()
assert adapter._background_tasks == set()

View file

@ -200,6 +200,25 @@ class TestCommandBypassActiveSession:
"/background response was not sent back to the user"
)
@pytest.mark.asyncio
async def test_steer_bypasses_guard(self):
"""/steer must bypass the Level-1 active-session guard so it reaches
the gateway runner's /steer handler and injects into the running
agent instead of being queued as user text for the next turn.
"""
adapter = _make_adapter()
sk = _session_key()
adapter._active_sessions[sk] = asyncio.Event()
await adapter.handle_message(_make_event("/steer also check auth.log"))
assert sk not in adapter._pending_messages, (
"/steer was queued as a pending message instead of being dispatched"
)
assert any("handled:steer" in r for r in adapter.sent_responses), (
"/steer response was not sent back to the user"
)
@pytest.mark.asyncio
async def test_help_bypasses_guard(self):
"""/help must bypass so it is not silently dropped as pending slash text."""
@ -249,6 +268,82 @@ class TestCommandBypassActiveSession:
)
# ---------------------------------------------------------------------------
# Tests: non-bypass-set commands (no dedicated Level-2 handler) also bypass
# instead of interrupting + being discarded. Regression for the Discord
# ghost-slash-command bug where /model, /reasoning, /voice, /insights, /title,
# /resume, /retry, /undo, /compress, /usage, /provider, /reload-mcp,
# /sethome, /reset silently interrupted the running agent.
# ---------------------------------------------------------------------------
class TestAllResolvableCommandsBypassGuard:
"""Every recognized slash command must bypass the Level-1 active-session
guard. Without this, commands the user fires mid-run interrupt the agent
AND get silently discarded by the slash-command safety net (zero-char
response)."""
@pytest.mark.parametrize(
"command_text,canonical",
[
("/model claude-sonnet-4", "model"),
("/model", "model"),
("/reasoning high", "reasoning"),
("/personality default", "personality"),
("/voice on", "voice"),
("/insights 7", "insights"),
("/title my session", "title"),
("/resume yesterday", "resume"),
("/retry", "retry"),
("/undo", "undo"),
("/compress", "compress"),
("/usage", "usage"),
("/provider", "provider"),
("/reload-mcp", "reload-mcp"),
("/sethome", "sethome"),
],
)
@pytest.mark.asyncio
async def test_command_bypasses_guard(self, command_text, canonical):
"""Any resolvable slash command bypasses instead of being queued."""
adapter = _make_adapter()
sk = _session_key()
adapter._active_sessions[sk] = asyncio.Event()
await adapter.handle_message(_make_event(command_text))
assert sk not in adapter._pending_messages, (
f"{command_text} was queued as pending — it should bypass the guard"
)
assert len(adapter.sent_responses) > 0, (
f"{command_text} produced no response — it should be dispatched, "
"not silently discarded"
)
def test_should_bypass_returns_true_for_every_registered_command(self):
"""Spot-check: the commands previously-broken on Discord all bypass."""
from hermes_cli.commands import should_bypass_active_session
for cmd in (
"model", "reasoning", "personality", "voice", "insights", "title",
"resume", "retry", "undo", "compress", "usage", "provider",
"reload-mcp", "sethome", "reset",
):
assert should_bypass_active_session(cmd) is True, (
f"/{cmd} must bypass the active-session guard"
)
def test_should_bypass_returns_false_for_unknown(self):
"""Unknown words don't bypass — they get queued as user text."""
from hermes_cli.commands import should_bypass_active_session
assert should_bypass_active_session("foobar") is False
assert should_bypass_active_session(None) is False
assert should_bypass_active_session("") is False
# A file path split on whitespace: '/path/to/file.py' -> 'path/to/file.py'
assert should_bypass_active_session("path/to/file.py") is False
# ---------------------------------------------------------------------------
# Tests: non-bypass messages still get queued
# ---------------------------------------------------------------------------

View file

@ -0,0 +1,122 @@
"""Regression tests for the Discord adapter race-polish fix.
Two races are addressed:
1. on_message allowlist check racing on_ready's _resolve_allowed_usernames
resolution window. Username-based entries in DISCORD_ALLOWED_USERS
appear in the set as raw strings for several seconds after
connect/reconnect; author.id is always numeric, so legitimate users
are silently rejected until resolution finishes.
2. join_voice_channel check-and-connect: concurrent /voice channel
invocations both see _voice_clients.get(guild_id) is None, both call
channel.connect(), second raises ClientException ('Already connected').
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import Platform, PlatformConfig
def _make_adapter():
"""Bare DiscordAdapter for testing — object.__new__ pattern per AGENTS.md."""
from gateway.platforms.discord import DiscordAdapter
adapter = object.__new__(DiscordAdapter)
adapter._platform = Platform.DISCORD
adapter.config = PlatformConfig(enabled=True, token="t")
adapter._ready_event = asyncio.Event()
adapter._allowed_user_ids = set()
adapter._allowed_role_ids = set()
adapter._voice_clients = {}
adapter._voice_locks = {}
adapter._voice_receivers = {}
adapter._voice_listen_tasks = {}
adapter._voice_timeout_tasks = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._client = MagicMock()
return adapter
class TestJoinVoiceSerialization:
@pytest.mark.asyncio
async def test_concurrent_joins_do_not_double_connect(self):
"""Two concurrent join_voice_channel calls on the same guild
must serialize through the per-guild lock only ONE
channel.connect() actually fires; the second sees the
_voice_clients entry the first just installed."""
adapter = _make_adapter()
connect_count = [0]
connect_event = asyncio.Event()
class FakeVC:
def __init__(self, channel):
self.channel = channel
def is_connected(self):
return True
async def move_to(self, _channel):
return None
async def disconnect(self):
return None
async def slow_connect(self):
connect_count[0] += 1
# Widen the race window
await connect_event.wait()
return FakeVC(self)
channel = MagicMock()
channel.id = 111
channel.guild.id = 42
channel.connect = lambda: slow_connect(channel)
# Swap out VoiceReceiver so it doesn't try to set up real audio
from gateway.platforms import discord as discord_mod
with patch.object(discord_mod, "VoiceReceiver", MagicMock(return_value=MagicMock(start=lambda: None))):
with patch.object(discord_mod.asyncio, "ensure_future", lambda _c: asyncio.create_task(asyncio.sleep(0))):
# Fire two joins concurrently
t1 = asyncio.create_task(adapter.join_voice_channel(channel))
t2 = asyncio.create_task(adapter.join_voice_channel(channel))
# Let them run until they're blocked on our event
await asyncio.sleep(0.05)
# Release connect so both can finish
connect_event.set()
r1, r2 = await asyncio.gather(t1, t2)
assert connect_count[0] == 1, (
f"Expected exactly 1 channel.connect() call, got {connect_count[0]}"
"per-guild voice lock is not serializing join_voice_channel"
)
assert r1 is True and r2 is True
assert 42 in adapter._voice_clients
class TestOnMessageWaitsForReadyEvent:
@pytest.mark.asyncio
async def test_on_message_blocks_until_ready_event_set(self):
"""A message arriving before on_ready finishes
_resolve_allowed_usernames must wait, not proceed with a
half-resolved allowlist."""
# This is an integration-style check — we pull out the
# on_message handler by asserting the source contains the
# expected wait pattern. A full end-to-end test would require
# setting up the discord.py client machinery, which is not
# practical here.
import inspect
from gateway.platforms import discord as discord_mod
src = inspect.getsource(discord_mod.DiscordAdapter.connect)
assert "_ready_event.is_set()" in src, (
"on_message must gate on _ready_event so username-based "
"allowlist entries are resolved before the allowlist check"
)
assert "await asyncio.wait_for(" in src and "_ready_event.wait()" in src, (
"Expected asyncio.wait_for(_ready_event.wait(), timeout=...) "
"pattern in on_message"
)

View file

@ -645,3 +645,54 @@ def test_group_topic_chat_id_int_string_coercion():
assert event.auto_skill == "hermes-agent-dev"
assert event.source.chat_topic == "Dev"
# ── _build_message_event: from_user=None fallback in DMs ──
def test_build_message_event_dm_from_user_none_falls_back_to_chat_id():
"""When from_user is None in a DM, user_id should fall back to chat.id."""
from gateway.platforms.base import MessageType
adapter = _make_adapter()
msg = _make_mock_message(chat_id=12345, user_id=42, user_name="Alice")
# Simulate from_user being None (edge case on fresh restart / forwarded msg)
msg.from_user = None
event = adapter._build_message_event(msg, MessageType.TEXT)
# Should fall back to chat.id since chat_type is "dm"
assert event.source.user_id == "12345"
assert event.source.user_name == "Alice" # falls back to chat.full_name
def test_build_message_event_group_from_user_none_stays_none():
"""When from_user is None in a group, user_id should remain None."""
from gateway.platforms.base import MessageType
adapter = _make_adapter()
msg = _make_mock_message(
chat_id=-1001234567890, chat_type=_ChatType.SUPERGROUP,
user_id=42, user_name="Alice"
)
msg.from_user = None
event = adapter._build_message_event(msg, MessageType.TEXT)
# Groups should NOT fall back — anonymous senders stay None
assert event.source.user_id is None
assert event.source.user_name is None
def test_build_message_event_dm_from_user_present_uses_user():
"""When from_user is present in a DM, it should be used (no fallback)."""
from gateway.platforms.base import MessageType
adapter = _make_adapter()
msg = _make_mock_message(chat_id=12345, user_id=99999, user_name="Bob")
event = adapter._build_message_event(msg, MessageType.TEXT)
# Normal case — from_user is used directly
assert event.source.user_id == "99999"
assert event.source.user_name == "Bob"

View file

@ -2370,6 +2370,134 @@ class TestAdapterBehavior(unittest.TestCase):
elements = payload["zh_cn"]["content"][0]
self.assertEqual(elements, [{"tag": "md", "text": "可以用 **粗体** 和 *斜体*。"}])
@patch.dict(os.environ, {}, clear=True)
def test_send_splits_fenced_code_blocks_into_separate_post_rows(self):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
captured = {}
class _MessageAPI:
def create(self, request):
captured["request"] = request
return SimpleNamespace(
success=lambda: True,
data=SimpleNamespace(message_id="om_codeblock"),
)
adapter._client = SimpleNamespace(
im=SimpleNamespace(
v1=SimpleNamespace(
message=_MessageAPI(),
)
)
)
async def _direct(func, *args, **kwargs):
return func(*args, **kwargs)
content = (
"确认已入库 ✓\n"
"文件路径:`/root/.hermes/profiles/agent_cto/cron/jobs.json`\n"
"**解码后的内容:**\n"
"```json\n"
'{"cron": "list"}\n'
"```\n"
"后续说明仍应保留。"
)
with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct):
result = asyncio.run(
adapter.send(
chat_id="oc_chat",
content=content,
)
)
self.assertTrue(result.success)
self.assertEqual(captured["request"].request_body.msg_type, "post")
payload = json.loads(captured["request"].request_body.content)
rows = payload["zh_cn"]["content"]
self.assertEqual(
rows,
[
[
{
"tag": "md",
"text": "确认已入库 ✓\n文件路径:`/root/.hermes/profiles/agent_cto/cron/jobs.json`\n**解码后的内容:**",
}
],
[{"tag": "md", "text": "```json\n{\"cron\": \"list\"}\n```"}],
[{"tag": "md", "text": "后续说明仍应保留。"}],
],
)
@patch.dict(os.environ, {}, clear=True)
def test_build_post_payload_keeps_fence_like_code_lines_inside_code_block(self):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
payload = json.loads(
adapter._build_post_payload(
"before\n```python\n```oops\n```\nafter"
)
)
self.assertEqual(
payload["zh_cn"]["content"],
[
[{"tag": "md", "text": "before"}],
[{"tag": "md", "text": "```python\n```oops\n```"}],
[{"tag": "md", "text": "after"}],
],
)
@patch.dict(os.environ, {}, clear=True)
def test_build_post_payload_preserves_trailing_spaces_in_code_block(self):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
payload = json.loads(
adapter._build_post_payload(
"before\n```python\nline with two spaces \n```\nafter"
)
)
self.assertEqual(
payload["zh_cn"]["content"],
[
[{"tag": "md", "text": "before"}],
[{"tag": "md", "text": "```python\nline with two spaces \n```"}],
[{"tag": "md", "text": "after"}],
],
)
@patch.dict(os.environ, {}, clear=True)
def test_build_post_payload_splits_multiple_fenced_code_blocks(self):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
payload = json.loads(
adapter._build_post_payload(
"before\n```python\nprint(1)\n```\nmiddle\n```json\n{}\n```\nafter"
)
)
self.assertEqual(
payload["zh_cn"]["content"],
[
[{"tag": "md", "text": "before"}],
[{"tag": "md", "text": "```python\nprint(1)\n```"}],
[{"tag": "md", "text": "middle"}],
[{"tag": "md", "text": "```json\n{}\n```"}],
[{"tag": "md", "text": "after"}],
],
)
@patch.dict(os.environ, {}, clear=True)
def test_send_falls_back_to_text_when_post_payload_is_rejected(self):
from gateway.config import PlatformConfig

View file

@ -0,0 +1,212 @@
"""Regression tests: pending-drain + finally-cleanup races must not spawn
duplicate agents OR silently drop messages that arrived during cleanup.
Two related races in gateway/platforms/base.py:_process_message_background:
1. Pending-drain path (previous line 1931):
``del self._active_sessions[session_key]`` opened a window where a
concurrent inbound message could pass the Level-1 guard, spawn its
own _process_message_background, and run simultaneously with the
recursive drain. Two agents on one session_key = duplicate responses.
2. Finally-cleanup path (previous line 1990-1991):
Between the awaits in finally (typing_task, stop_typing) and the
``del self._active_sessions[session_key]``, a new message could
land in _pending_messages. The del ran anyway, and the message was
silently dropped user never got a reply.
Fix: keep the _active_sessions entry live across the turn chain and
clear the Event instead of deleting; in finally, drain any
late-arrival pending message by spawning a task instead of
dropping it.
"""
import asyncio
from unittest.mock import AsyncMock
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
MessageType,
)
from gateway.session import SessionSource, build_session_key
class _StubAdapter(BasePlatformAdapter):
async def connect(self):
pass
async def disconnect(self):
pass
async def send(self, chat_id, text, **kwargs):
return None
async def get_chat_info(self, chat_id):
return {}
def _make_adapter():
adapter = _StubAdapter(PlatformConfig(enabled=True, token="t"), Platform.TELEGRAM)
adapter._send_with_retry = AsyncMock(return_value=None)
return adapter
def _make_event(text="hi", chat_id="42"):
return MessageEvent(
text=text,
message_type=MessageType.TEXT,
source=SessionSource(platform=Platform.TELEGRAM, chat_id=chat_id, chat_type="dm"),
)
def _sk(chat_id="42"):
return build_session_key(
SessionSource(platform=Platform.TELEGRAM, chat_id=chat_id, chat_type="dm")
)
@pytest.mark.asyncio
async def test_pending_drain_keeps_active_session_guard_live():
"""Fix for R5: during pending-drain cleanup, _active_sessions must stay
populated so concurrent inbound messages can't spawn a duplicate
_process_message_background. We only CLEAR the Event, never delete."""
adapter = _make_adapter()
sk = _sk()
# Register a slow handler so the agent is "mid-processing" when the
# pending message arrives.
first_started = asyncio.Event()
release_first = asyncio.Event()
async def handler(event):
first_started.set()
await release_first.wait()
return "done"
adapter._message_handler = handler
# Spawn M1 through handle_message.
await adapter.handle_message(_make_event(text="M1"))
# Wait until M1 is actively running inside the handler.
await asyncio.wait_for(first_started.wait(), timeout=1.0)
# Assert: session is active.
assert sk in adapter._active_sessions
active_event = adapter._active_sessions[sk]
# Simulate pending message (M2) queued while M1 runs.
adapter._pending_messages[sk] = _make_event(text="M2")
# Release M1 — pending-drain block now runs. During its cleanup
# awaits, _active_sessions[sk] must remain populated (same object
# reference) so any M3 arriving in that window hits the busy-handler.
release_first.set()
# Give the drain a moment to execute its .clear() + await typing_task
# without letting it fully finish the recursive call.
await asyncio.sleep(0)
await asyncio.sleep(0)
# Across the drain transition, the Event object must be the SAME
# reference (not replaced, not deleted). If del happened, the key
# would be missing briefly; if a new Event was installed, the
# identity would differ.
assert sk in adapter._active_sessions, (
"_active_sessions[session_key] was deleted during pending-drain — "
"opens a window for duplicate-agent spawn"
)
assert adapter._active_sessions[sk] is active_event, (
"_active_sessions[session_key] was replaced during pending-drain — "
"the old Event may have waiters that now won't be signaled"
)
# Finish drain.
await asyncio.sleep(0.1)
await adapter.cancel_background_tasks()
@pytest.mark.asyncio
async def test_finally_cleanup_drains_late_arrival_pending():
"""Fix for R6: if a message lands in _pending_messages during the
finally-block cleanup awaits, the finally must spawn a drain task
instead of deleting _active_sessions and dropping the message."""
adapter = _make_adapter()
sk = _sk()
processed = []
async def handler(event):
processed.append(event.text)
return "ok"
adapter._message_handler = handler
# Instrument stop_typing to inject a late-arrival pending message
# during the finally-block await window. This exactly simulates the
# R6 race: the message arrives after the response has been sent but
# before _active_sessions is deleted.
original_stop = adapter.stop_typing if hasattr(adapter, "stop_typing") else None
injected = {"done": False}
async def stop_typing_injects_pending(*args, **kwargs):
# Yield so the injection happens mid-await.
await asyncio.sleep(0)
if not injected["done"]:
adapter._pending_messages[sk] = _make_event(text="LATE")
injected["done"] = True
if original_stop:
return await original_stop(*args, **kwargs)
return None
adapter.stop_typing = stop_typing_injects_pending
# Send M1.
await adapter.handle_message(_make_event(text="M1"))
# Drain: wait for M1 to finish and the late-drain task to process LATE.
for _ in range(50): # up to ~0.5s
if "LATE" in processed:
break
await asyncio.sleep(0.01)
await adapter.cancel_background_tasks()
assert "M1" in processed, "M1 was not processed"
assert "LATE" in processed, (
"Late-arrival pending message was silently dropped — finally "
"cleanup should have spawned a drain task"
)
@pytest.mark.asyncio
async def test_no_pending_cleans_up_normally():
"""Regression guard: when no pending message exists, the finally
block must still delete _active_sessions as before (no leak)."""
adapter = _make_adapter()
sk = _sk()
async def handler(event):
return "ok"
adapter._message_handler = handler
await adapter.handle_message(_make_event(text="solo"))
# Wait for background task to finish.
for _ in range(50):
if sk not in adapter._active_sessions:
break
await asyncio.sleep(0.01)
assert sk not in adapter._active_sessions, (
"_active_sessions was not cleaned up after a normal turn with no pending"
)
assert sk not in adapter._pending_messages
await adapter.cancel_background_tasks()

View file

@ -1,13 +1,18 @@
"""Tests for the pending_event None guard in recursive _run_agent calls.
"""Tests for pending follow-up extraction in recursive _run_agent calls.
When pending_event is None (Path B: pending comes from interrupt_message),
accessing pending_event.channel_prompt previously raised AttributeError.
This verifies the fix: channel_prompt is captured inside the
`if pending_event is not None:` block and falls back to None otherwise.
Also verifies that internal control interrupt reasons like "Stop requested"
do not get recycled into the pending-user-message follow-up path.
"""
from types import SimpleNamespace
from gateway.run import _is_control_interrupt_message
def _extract_channel_prompt(pending_event):
"""Reproduce the fixed logic from gateway/run.py.
@ -21,6 +26,15 @@ def _extract_channel_prompt(pending_event):
return next_channel_prompt
def _extract_pending_text(interrupted, pending_event, interrupt_message):
"""Reproduce the fixed pending-text selection from gateway/run.py."""
if interrupted and pending_event is None and interrupt_message:
if _is_control_interrupt_message(interrupt_message):
return None
return interrupt_message
return None
class TestPendingEventNoneChannelPrompt:
"""Guard against AttributeError when pending_event is None."""
@ -40,3 +54,19 @@ class TestPendingEventNoneChannelPrompt:
event = SimpleNamespace()
result = _extract_channel_prompt(event)
assert result is None
class TestControlInterruptMessages:
"""Control interrupt reasons must not become follow-up user input."""
def test_stop_requested_is_not_treated_as_pending_user_message(self):
result = _extract_pending_text(True, None, "Stop requested")
assert result is None
def test_session_reset_requested_is_not_treated_as_pending_user_message(self):
result = _extract_pending_text(True, None, "Session reset requested")
assert result is None
def test_real_user_interrupt_message_still_requeues(self):
result = _extract_pending_text(True, None, "actually use postgres instead")
assert result == "actually use postgres instead"

View file

@ -19,6 +19,7 @@ def _make_runner(proxy_url=None):
runner.config = MagicMock()
runner.config.streaming = StreamingConfig()
runner._running_agents = {}
runner._session_run_generation = {}
runner._session_model_overrides = {}
runner._agent_cache = {}
runner._agent_cache_lock = None
@ -160,10 +161,12 @@ class TestRunAgentProxyDispatch:
source=source,
session_id="test-session-123",
session_key="test-key",
run_generation=7,
)
assert result["final_response"] == "Hello from remote!"
runner._run_agent_via_proxy.assert_called_once()
assert runner._run_agent_via_proxy.call_args.kwargs["run_generation"] == 7
@pytest.mark.asyncio
async def test_run_agent_skips_proxy_when_not_configured(self, monkeypatch):
@ -370,6 +373,40 @@ class TestRunAgentViaProxy:
assert "session_id" in result
assert result["session_id"] == "sess-123"
@pytest.mark.asyncio
async def test_proxy_stale_generation_returns_empty_result(self, monkeypatch):
monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642")
monkeypatch.delenv("GATEWAY_PROXY_KEY", raising=False)
runner = _make_runner()
source = _make_source()
runner._session_run_generation["test-key"] = 2
resp = _FakeSSEResponse(
status=200,
sse_chunks=[
'data: {"choices":[{"delta":{"content":"stale"}}]}\n\n',
"data: [DONE]\n\n",
],
)
session = _FakeSession(resp)
with patch("gateway.run._load_gateway_config", return_value={}):
with _patch_aiohttp(session):
with patch("aiohttp.ClientTimeout"):
result = await runner._run_agent_via_proxy(
message="hi",
context_prompt="",
history=[],
source=source,
session_id="sess-123",
session_key="test-key",
run_generation=1,
)
assert result["final_response"] == ""
assert result["messages"] == []
assert result["api_calls"] == 0
@pytest.mark.asyncio
async def test_no_auth_header_without_key(self, monkeypatch):
monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642")

View file

@ -0,0 +1,247 @@
"""Tests for /restart idempotency guard against Telegram update re-delivery.
When PTB's graceful-shutdown ACK call (the final `get_updates` on exit) fails
with a network error, Telegram re-delivers the `/restart` message to the new
gateway process. Without a dedup guard, the new gateway would process
`/restart` again and immediately restart a self-perpetuating loop.
"""
import asyncio
import json
import time
from unittest.mock import MagicMock
import pytest
import gateway.run as gateway_run
from gateway.platforms.base import MessageEvent, MessageType
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
def _make_restart_event(update_id: int | None = 100) -> MessageEvent:
return MessageEvent(
text="/restart",
message_type=MessageType.TEXT,
source=make_restart_source(),
message_id="m1",
platform_update_id=update_id,
)
@pytest.mark.asyncio
async def test_restart_handler_writes_dedup_marker_with_update_id(tmp_path, monkeypatch):
"""First /restart writes .restart_last_processed.json with the triggering update_id."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.delenv("INVOCATION_ID", raising=False)
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock(return_value=True)
event = _make_restart_event(update_id=12345)
result = await runner._handle_restart_command(event)
assert "Restarting gateway" in result
marker_path = tmp_path / ".restart_last_processed.json"
assert marker_path.exists()
data = json.loads(marker_path.read_text())
assert data["platform"] == "telegram"
assert data["update_id"] == 12345
assert isinstance(data["requested_at"], (int, float))
@pytest.mark.asyncio
async def test_redelivered_restart_with_same_update_id_is_ignored(tmp_path, monkeypatch):
"""A /restart with update_id <= recorded marker is silently ignored as a redelivery."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.delenv("INVOCATION_ID", raising=False)
# Previous gateway recorded update_id=12345 a few seconds ago
marker = tmp_path / ".restart_last_processed.json"
marker.write_text(json.dumps({
"platform": "telegram",
"update_id": 12345,
"requested_at": time.time() - 5,
}))
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock()
event = _make_restart_event(update_id=12345) # same update_id → redelivery
result = await runner._handle_restart_command(event)
assert result == "" # silently ignored
runner.request_restart.assert_not_called()
@pytest.mark.asyncio
async def test_redelivered_restart_with_older_update_id_is_ignored(tmp_path, monkeypatch):
"""update_id strictly LESS than the recorded one is also a redelivery."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.delenv("INVOCATION_ID", raising=False)
marker = tmp_path / ".restart_last_processed.json"
marker.write_text(json.dumps({
"platform": "telegram",
"update_id": 12345,
"requested_at": time.time() - 5,
}))
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock()
event = _make_restart_event(update_id=12344) # older update — shouldn't happen,
# but if Telegram does re-deliver
# something older, treat as stale
result = await runner._handle_restart_command(event)
assert result == ""
runner.request_restart.assert_not_called()
@pytest.mark.asyncio
async def test_fresh_restart_with_higher_update_id_is_processed(tmp_path, monkeypatch):
"""A NEW /restart from the user (higher update_id) bypasses the dedup guard."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.delenv("INVOCATION_ID", raising=False)
# Previous restart recorded update_id=12345
marker = tmp_path / ".restart_last_processed.json"
marker.write_text(json.dumps({
"platform": "telegram",
"update_id": 12345,
"requested_at": time.time() - 5,
}))
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock(return_value=True)
event = _make_restart_event(update_id=12346) # strictly higher → fresh
result = await runner._handle_restart_command(event)
assert "Restarting gateway" in result
runner.request_restart.assert_called_once()
# Marker is overwritten with the new update_id
data = json.loads(marker.read_text())
assert data["update_id"] == 12346
@pytest.mark.asyncio
async def test_stale_marker_older_than_5min_does_not_block(tmp_path, monkeypatch):
"""A marker older than the 5-minute window is ignored — fresh /restart proceeds."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.delenv("INVOCATION_ID", raising=False)
marker = tmp_path / ".restart_last_processed.json"
marker.write_text(json.dumps({
"platform": "telegram",
"update_id": 12345,
"requested_at": time.time() - 600, # 10 minutes ago
}))
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock(return_value=True)
# Same update_id as the stale marker, but the marker is too old to trust
event = _make_restart_event(update_id=12345)
result = await runner._handle_restart_command(event)
assert "Restarting gateway" in result
runner.request_restart.assert_called_once()
@pytest.mark.asyncio
async def test_no_marker_file_allows_restart(tmp_path, monkeypatch):
"""Clean gateway start (no prior marker) processes /restart normally."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.delenv("INVOCATION_ID", raising=False)
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock(return_value=True)
event = _make_restart_event(update_id=100)
result = await runner._handle_restart_command(event)
assert "Restarting gateway" in result
runner.request_restart.assert_called_once()
@pytest.mark.asyncio
async def test_corrupt_marker_file_is_treated_as_absent(tmp_path, monkeypatch):
"""Malformed JSON in the marker file doesn't crash — /restart proceeds."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.delenv("INVOCATION_ID", raising=False)
marker = tmp_path / ".restart_last_processed.json"
marker.write_text("not-json{")
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock(return_value=True)
event = _make_restart_event(update_id=100)
result = await runner._handle_restart_command(event)
assert "Restarting gateway" in result
runner.request_restart.assert_called_once()
@pytest.mark.asyncio
async def test_event_without_update_id_bypasses_dedup(tmp_path, monkeypatch):
"""Events with no platform_update_id (non-Telegram, CLI fallback) aren't gated."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.delenv("INVOCATION_ID", raising=False)
marker = tmp_path / ".restart_last_processed.json"
marker.write_text(json.dumps({
"platform": "telegram",
"update_id": 999999,
"requested_at": time.time(),
}))
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock(return_value=True)
# No update_id — the dedup check should NOT kick in
event = _make_restart_event(update_id=None)
result = await runner._handle_restart_command(event)
assert "Restarting gateway" in result
runner.request_restart.assert_called_once()
@pytest.mark.asyncio
async def test_different_platform_bypasses_dedup(tmp_path, monkeypatch):
"""Marker from Telegram doesn't block a /restart from another platform."""
from gateway.config import Platform
from gateway.session import SessionSource
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.delenv("INVOCATION_ID", raising=False)
marker = tmp_path / ".restart_last_processed.json"
marker.write_text(json.dumps({
"platform": "telegram",
"update_id": 12345,
"requested_at": time.time(),
}))
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock(return_value=True)
# /restart from Discord — not a redelivery candidate
discord_source = SessionSource(
platform=Platform.DISCORD,
chat_id="discord-chan",
chat_type="dm",
user_id="u1",
)
event = MessageEvent(
text="/restart",
message_type=MessageType.TEXT,
source=discord_source,
message_id="m1",
platform_update_id=12345,
)
result = await runner._handle_restart_command(event)
assert "Restarting gateway" in result
runner.request_restart.assert_called_once()

View file

@ -0,0 +1,688 @@
"""Tests for the resume_pending session continuity path.
Covers the behaviour introduced to fix the ``Gateway shutting down ...
task will be interrupted`` follow-up bug (spec: PR #11852, builds on
PRs #9850, #9934, #7536):
1. When a gateway restart drain times out and agents are force-interrupted,
the affected sessions are flagged ``resume_pending=True`` not
``suspended`` so the next user message on the same session_key
auto-resumes from the existing transcript instead of getting routed
through ``suspend_recently_active()`` and converted into a fresh
session.
2. ``suspended=True`` (from ``/stop`` or stuck-loop escalation) still
wins over ``resume_pending`` the forced-wipe path is preserved.
3. The restart-resume system note injected into the next user message is
a superset of the existing tool-tail auto-continue note (from
PR #9934), using session-entry metadata rather than just transcript
shape so it fires even when the interrupted transcript does NOT end
with a ``tool`` role.
4. The existing ``.restart_failure_counts`` stuck-loop counter from
PR #7536 remains the single source of escalation — no parallel
counter is added on ``SessionEntry``.
"""
import asyncio
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.session import SessionEntry, SessionSource, SessionStore
from tests.gateway.restart_test_helpers import (
make_restart_runner,
make_restart_source,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_source(platform=Platform.TELEGRAM, chat_id="123", user_id="u1"):
return SessionSource(platform=platform, chat_id=chat_id, user_id=user_id)
def _make_store(tmp_path):
return SessionStore(sessions_dir=tmp_path, config=GatewayConfig())
def _simulate_note_injection(
agent_history: list,
user_message: str,
resume_entry: SessionEntry | None,
) -> str:
"""Mirror the note-injection logic in gateway/run.py _run_agent().
Matches the production code in the ``run_sync`` closure so we can
test the decision tree without a full gateway runner.
"""
message = user_message
is_resume_pending = bool(
resume_entry is not None and getattr(resume_entry, "resume_pending", False)
)
if is_resume_pending:
reason = getattr(resume_entry, "resume_reason", None) or "restart_timeout"
reason_phrase = (
"a gateway restart"
if reason == "restart_timeout"
else "a gateway shutdown"
if reason == "shutdown_timeout"
else "a gateway interruption"
)
message = (
f"[System note: Your previous turn in this session was interrupted "
f"by {reason_phrase}. The conversation history below is intact. "
f"If it contains unfinished tool result(s), process them first and "
f"summarize what was accomplished, then address the user's new "
f"message below.]\n\n"
+ message
)
elif agent_history and agent_history[-1].get("role") == "tool":
message = (
"[System note: Your previous turn was interrupted before you could "
"process the last tool result(s). The conversation history contains "
"tool outputs you haven't responded to yet. Please finish processing "
"those results and summarize what was accomplished, then address the "
"user's new message below.]\n\n"
+ message
)
return message
# ---------------------------------------------------------------------------
# SessionEntry field + serialization
# ---------------------------------------------------------------------------
class TestSessionEntryResumeFields:
def test_defaults(self):
now = datetime.now()
entry = SessionEntry(
session_key="agent:main:telegram:dm:1",
session_id="sid",
created_at=now,
updated_at=now,
)
assert entry.resume_pending is False
assert entry.resume_reason is None
assert entry.last_resume_marked_at is None
def test_roundtrip_with_resume_fields(self):
now = datetime(2026, 4, 18, 12, 0, 0)
entry = SessionEntry(
session_key="agent:main:telegram:dm:1",
session_id="sid",
created_at=now,
updated_at=now,
resume_pending=True,
resume_reason="restart_timeout",
last_resume_marked_at=now,
)
restored = SessionEntry.from_dict(entry.to_dict())
assert restored.resume_pending is True
assert restored.resume_reason == "restart_timeout"
assert restored.last_resume_marked_at == now
def test_from_dict_legacy_without_resume_fields(self):
"""Old sessions.json without the new fields deserialize cleanly."""
now = datetime.now()
legacy = {
"session_key": "agent:main:telegram:dm:1",
"session_id": "sid",
"created_at": now.isoformat(),
"updated_at": now.isoformat(),
"chat_type": "dm",
}
restored = SessionEntry.from_dict(legacy)
assert restored.resume_pending is False
assert restored.resume_reason is None
assert restored.last_resume_marked_at is None
def test_malformed_timestamp_is_tolerated(self):
now = datetime.now()
data = {
"session_key": "k",
"session_id": "sid",
"created_at": now.isoformat(),
"updated_at": now.isoformat(),
"resume_pending": True,
"resume_reason": "restart_timeout",
"last_resume_marked_at": "not-a-timestamp",
}
restored = SessionEntry.from_dict(data)
# resume_pending still honoured, only the broken timestamp drops
assert restored.resume_pending is True
assert restored.resume_reason == "restart_timeout"
assert restored.last_resume_marked_at is None
# ---------------------------------------------------------------------------
# SessionStore.mark_resume_pending / clear_resume_pending
# ---------------------------------------------------------------------------
class TestMarkResumePending:
def test_marks_existing_session(self, tmp_path):
store = _make_store(tmp_path)
source = _make_source()
entry = store.get_or_create_session(source)
assert store.mark_resume_pending(entry.session_key) is True
refreshed = store._entries[entry.session_key]
assert refreshed.resume_pending is True
assert refreshed.resume_reason == "restart_timeout"
assert refreshed.last_resume_marked_at is not None
def test_custom_reason_persists(self, tmp_path):
store = _make_store(tmp_path)
source = _make_source()
entry = store.get_or_create_session(source)
store.mark_resume_pending(entry.session_key, reason="shutdown_timeout")
assert store._entries[entry.session_key].resume_reason == "shutdown_timeout"
def test_returns_false_for_unknown_key(self, tmp_path):
store = _make_store(tmp_path)
assert store.mark_resume_pending("no-such-key") is False
def test_does_not_override_suspended(self, tmp_path):
"""suspended wins — mark_resume_pending is a no-op on a suspended entry."""
store = _make_store(tmp_path)
source = _make_source()
entry = store.get_or_create_session(source)
store.suspend_session(entry.session_key)
assert store.mark_resume_pending(entry.session_key) is False
e = store._entries[entry.session_key]
assert e.suspended is True
assert e.resume_pending is False
def test_survives_roundtrip_through_json(self, tmp_path):
store = _make_store(tmp_path)
source = _make_source()
entry = store.get_or_create_session(source)
store.mark_resume_pending(entry.session_key, reason="restart_timeout")
# Reload from disk
store2 = _make_store(tmp_path)
store2._ensure_loaded()
reloaded = store2._entries[entry.session_key]
assert reloaded.resume_pending is True
assert reloaded.resume_reason == "restart_timeout"
class TestClearResumePending:
def test_clears_flag(self, tmp_path):
store = _make_store(tmp_path)
source = _make_source()
entry = store.get_or_create_session(source)
store.mark_resume_pending(entry.session_key)
assert store.clear_resume_pending(entry.session_key) is True
e = store._entries[entry.session_key]
assert e.resume_pending is False
assert e.resume_reason is None
assert e.last_resume_marked_at is None
def test_returns_false_when_not_pending(self, tmp_path):
store = _make_store(tmp_path)
source = _make_source()
entry = store.get_or_create_session(source)
# Not marked
assert store.clear_resume_pending(entry.session_key) is False
def test_returns_false_for_unknown_key(self, tmp_path):
store = _make_store(tmp_path)
assert store.clear_resume_pending("no-such-key") is False
# ---------------------------------------------------------------------------
# SessionStore.get_or_create_session resume_pending behaviour
# ---------------------------------------------------------------------------
class TestGetOrCreateResumePending:
def test_resume_pending_preserves_session_id(self, tmp_path):
"""This is THE core behavioural fix — resume_pending ≠ new session."""
store = _make_store(tmp_path)
source = _make_source()
first = store.get_or_create_session(source)
original_sid = first.session_id
store.mark_resume_pending(first.session_key)
second = store.get_or_create_session(source)
assert second.session_id == original_sid
assert second.was_auto_reset is False
assert second.auto_reset_reason is None
# Flag is NOT cleared on read — only on successful turn completion.
assert second.resume_pending is True
def test_suspended_still_creates_new_session(self, tmp_path):
"""Regression guard — suspended must still force a clean slate."""
store = _make_store(tmp_path)
source = _make_source()
first = store.get_or_create_session(source)
original_sid = first.session_id
store.suspend_session(first.session_key)
second = store.get_or_create_session(source)
assert second.session_id != original_sid
assert second.was_auto_reset is True
assert second.auto_reset_reason == "suspended"
def test_suspended_overrides_resume_pending(self, tmp_path):
"""Terminal escalation: a session that somehow has BOTH flags must
behave like ``suspended`` forced wipe + auto_reset_reason."""
store = _make_store(tmp_path)
source = _make_source()
first = store.get_or_create_session(source)
original_sid = first.session_id
# Force the pathological state directly (normally mark_resume_pending
# refuses to run when suspended=True, but a stuck-loop escalation
# can set suspended=True AFTER resume_pending is set).
with store._lock:
e = store._entries[first.session_key]
e.resume_pending = True
e.resume_reason = "restart_timeout"
e.suspended = True
store._save()
second = store.get_or_create_session(source)
assert second.session_id != original_sid
assert second.was_auto_reset is True
assert second.auto_reset_reason == "suspended"
# ---------------------------------------------------------------------------
# SessionStore.suspend_recently_active skip behaviour
# ---------------------------------------------------------------------------
class TestSuspendRecentlyActiveSkipsResumePending:
def test_resume_pending_entries_not_suspended(self, tmp_path):
store = _make_store(tmp_path)
source = _make_source()
entry = store.get_or_create_session(source)
store.mark_resume_pending(entry.session_key)
count = store.suspend_recently_active()
assert count == 0
e = store._entries[entry.session_key]
assert e.suspended is False
assert e.resume_pending is True
def test_non_resume_pending_still_suspended(self, tmp_path):
"""Non-resume sessions still get the old crash-recovery suspension."""
store = _make_store(tmp_path)
source_a = _make_source(chat_id="a")
source_b = _make_source(chat_id="b")
entry_a = store.get_or_create_session(source_a)
entry_b = store.get_or_create_session(source_b)
store.mark_resume_pending(entry_a.session_key)
count = store.suspend_recently_active()
assert count == 1
assert store._entries[entry_a.session_key].suspended is False
assert store._entries[entry_b.session_key].suspended is True
# ---------------------------------------------------------------------------
# Restart-resume system-note injection
# ---------------------------------------------------------------------------
class TestResumePendingSystemNote:
def _pending_entry(self, reason="restart_timeout") -> SessionEntry:
now = datetime.now()
return SessionEntry(
session_key="agent:main:telegram:dm:1",
session_id="sid",
created_at=now,
updated_at=now,
resume_pending=True,
resume_reason=reason,
last_resume_marked_at=now,
)
def test_resume_pending_restart_note_mentions_restart(self):
entry = self._pending_entry(reason="restart_timeout")
result = _simulate_note_injection(
agent_history=[{"role": "assistant", "content": "in progress"}],
user_message="what happened?",
resume_entry=entry,
)
assert "[System note:" in result
assert "gateway restart" in result
assert "what happened?" in result
def test_resume_pending_shutdown_note_mentions_shutdown(self):
entry = self._pending_entry(reason="shutdown_timeout")
result = _simulate_note_injection(
agent_history=[{"role": "assistant", "content": "in progress"}],
user_message="ping",
resume_entry=entry,
)
assert "gateway shutdown" in result
def test_resume_pending_fires_without_tool_tail(self):
"""Key improvement over PR #9934: the restart-resume note fires
even when the transcript's last role is NOT ``tool``."""
entry = self._pending_entry()
history = [
{"role": "user", "content": "run a long thing"},
{"role": "assistant", "content": "ok, starting..."},
]
result = _simulate_note_injection(history, "ping", resume_entry=entry)
assert "[System note:" in result
assert "gateway restart" in result
def test_resume_pending_subsumes_tool_tail_note(self):
"""When BOTH conditions are true, the restart-resume note wins —
no duplicate notes."""
entry = self._pending_entry()
history = [
{"role": "assistant", "content": None, "tool_calls": [
{"id": "c1", "function": {"name": "x", "arguments": "{}"}},
]},
{"role": "tool", "tool_call_id": "c1", "content": "result"},
]
result = _simulate_note_injection(history, "ping", resume_entry=entry)
assert result.count("[System note:") == 1
assert "gateway restart" in result
# Old tool-tail wording absent
assert "haven't responded to yet" not in result
def test_no_resume_pending_preserves_tool_tail_note(self):
"""Regression: the old PR #9934 tool-tail behaviour is unchanged."""
history = [
{"role": "assistant", "content": None, "tool_calls": [
{"id": "c1", "function": {"name": "x", "arguments": "{}"}},
]},
{"role": "tool", "tool_call_id": "c1", "content": "result"},
]
result = _simulate_note_injection(history, "ping", resume_entry=None)
assert "[System note:" in result
assert "tool result" in result
def test_no_note_when_nothing_to_resume(self):
history = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi"},
]
result = _simulate_note_injection(history, "ping", resume_entry=None)
assert result == "ping"
# ---------------------------------------------------------------------------
# Drain-timeout path marks sessions resume_pending
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_drain_timeout_marks_resume_pending():
"""End-to-end: a drain timeout during gateway stop should flag every
active session as resume_pending BEFORE the interrupt fires, so the
next startup's suspend_recently_active() does not destroy them."""
runner, adapter = make_restart_runner()
adapter.disconnect = AsyncMock()
runner._restart_drain_timeout = 0.05
running_agent = MagicMock()
session_key_one = "agent:main:telegram:dm:A"
session_key_two = "agent:main:telegram:dm:B"
runner._running_agents = {
session_key_one: running_agent,
session_key_two: MagicMock(),
}
# Plug a mock session_store that records marks.
session_store = MagicMock()
session_store.mark_resume_pending = MagicMock(return_value=True)
runner.session_store = session_store
with patch("gateway.status.remove_pid_file"), patch(
"gateway.status.write_runtime_status"
):
await runner.stop()
# Both active sessions were marked with the shutdown_timeout reason.
calls = session_store.mark_resume_pending.call_args_list
marked = {args[0][0] for args in calls}
assert marked == {session_key_one, session_key_two}
for args in calls:
assert args[0][1] == "shutdown_timeout"
@pytest.mark.asyncio
async def test_drain_timeout_uses_restart_reason_when_restarting():
runner, adapter = make_restart_runner()
adapter.disconnect = AsyncMock()
runner._restart_drain_timeout = 0.05
runner._restart_requested = True
running_agent = MagicMock()
runner._running_agents = {"agent:main:telegram:dm:A": running_agent}
session_store = MagicMock()
session_store.mark_resume_pending = MagicMock(return_value=True)
runner.session_store = session_store
with patch("gateway.status.remove_pid_file"), patch(
"gateway.status.write_runtime_status"
):
await runner.stop(restart=True, detached_restart=False, service_restart=True)
calls = session_store.mark_resume_pending.call_args_list
assert calls, "expected at least one mark_resume_pending call"
for args in calls:
assert args[0][1] == "restart_timeout"
@pytest.mark.asyncio
async def test_clean_drain_does_not_mark_resume_pending():
"""If the drain completes within timeout (no force-interrupt), no
sessions should be flagged the normal shutdown path is unchanged."""
runner, adapter = make_restart_runner()
adapter.disconnect = AsyncMock()
running_agent = MagicMock()
runner._running_agents = {"agent:main:telegram:dm:A": running_agent}
# Finish the agent before the (generous) drain deadline
async def finish_agent():
await asyncio.sleep(0.05)
runner._running_agents.clear()
asyncio.create_task(finish_agent())
session_store = MagicMock()
session_store.mark_resume_pending = MagicMock(return_value=True)
runner.session_store = session_store
with patch("gateway.status.remove_pid_file"), patch(
"gateway.status.write_runtime_status"
):
await runner.stop()
session_store.mark_resume_pending.assert_not_called()
running_agent.interrupt.assert_not_called()
@pytest.mark.asyncio
async def test_drain_timeout_only_marks_still_running_sessions():
"""A session that finished gracefully during the drain window must
NOT be marked ``resume_pending`` it completed cleanly and its
next turn should be a normal fresh turn, not one prefixed with the
restart-interruption system note.
Regression guard for using ``self._running_agents`` at timeout
rather than the ``active_agents`` drain-start snapshot.
"""
runner, adapter = make_restart_runner()
adapter.disconnect = AsyncMock()
# Long enough for the finisher to exit, short enough to still time out
# with the stuck session still present.
runner._restart_drain_timeout = 0.3
session_key_finisher = "agent:main:telegram:dm:A"
session_key_stuck = "agent:main:telegram:dm:B"
runner._running_agents = {
session_key_finisher: MagicMock(),
session_key_stuck: MagicMock(),
}
async def finish_one():
await asyncio.sleep(0.05)
runner._running_agents.pop(session_key_finisher, None)
asyncio.create_task(finish_one())
session_store = MagicMock()
session_store.mark_resume_pending = MagicMock(return_value=True)
runner.session_store = session_store
with patch("gateway.status.remove_pid_file"), patch(
"gateway.status.write_runtime_status"
):
await runner.stop()
calls = session_store.mark_resume_pending.call_args_list
marked = {args[0][0] for args in calls}
# Only the session still running at timeout is marked; the finisher is not.
assert marked == {session_key_stuck}
@pytest.mark.asyncio
async def test_drain_timeout_skips_pending_sentinel_sessions():
"""Pending sentinels — sessions whose AIAgent construction hasn't
produced a real agent yet are skipped by
``_interrupt_running_agents()``. The resume_pending marking must
mirror that: no agent started means no turn was interrupted.
"""
from gateway.run import _AGENT_PENDING_SENTINEL
runner, adapter = make_restart_runner()
adapter.disconnect = AsyncMock()
runner._restart_drain_timeout = 0.05
session_key_real = "agent:main:telegram:dm:A"
session_key_sentinel = "agent:main:telegram:dm:B"
runner._running_agents = {
session_key_real: MagicMock(),
session_key_sentinel: _AGENT_PENDING_SENTINEL,
}
session_store = MagicMock()
session_store.mark_resume_pending = MagicMock(return_value=True)
runner.session_store = session_store
with patch("gateway.status.remove_pid_file"), patch(
"gateway.status.write_runtime_status"
):
await runner.stop()
calls = session_store.mark_resume_pending.call_args_list
marked = {args[0][0] for args in calls}
assert marked == {session_key_real}
# ---------------------------------------------------------------------------
# Shutdown banner wording
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_restart_banner_uses_try_to_resume_wording():
"""The notification sent before drain should hedge the resume promise
the session-continuity fix is best-effort (stuck-loop counter can
still escalate to suspended)."""
runner, adapter = make_restart_runner()
runner._restart_requested = True
runner._running_agents["agent:main:telegram:dm:999"] = MagicMock()
await runner._notify_active_sessions_of_shutdown()
assert len(adapter.sent) == 1
msg = adapter.sent[0]
assert "restarting" in msg
assert "try to resume" in msg
# ---------------------------------------------------------------------------
# Stuck-loop escalation integration
# ---------------------------------------------------------------------------
class TestStuckLoopEscalation:
"""The existing .restart_failure_counts counter (PR #7536) remains the
single source of terminal escalation no parallel counter on
SessionEntry was added. After the configured threshold, the startup
path flips suspended=True which overrides resume_pending."""
def test_escalation_via_stuck_loop_counter_overrides_resume_pending(
self, tmp_path, monkeypatch
):
"""Simulate a session that keeps getting restart-interrupted and
hits the stuck-loop threshold: next startup should force it to
fresh-session despite resume_pending being set."""
import json
from gateway.run import GatewayRunner
store = _make_store(tmp_path)
source = _make_source()
entry = store.get_or_create_session(source)
store.mark_resume_pending(entry.session_key, reason="restart_timeout")
# Simulate counter already at threshold (3 consecutive interrupted
# restarts). _suspend_stuck_loop_sessions will flip suspended=True.
counts_file = tmp_path / ".restart_failure_counts"
counts_file.write_text(json.dumps({entry.session_key: 3}))
monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
runner = object.__new__(GatewayRunner)
runner.session_store = store
suspended_count = GatewayRunner._suspend_stuck_loop_sessions(runner)
assert suspended_count == 1
assert store._entries[entry.session_key].suspended is True
# resume_pending is still set on the entry, but suspended wins in
# get_or_create_session so the next message still gets a new sid.
second = store.get_or_create_session(source)
assert second.session_id != entry.session_id
assert second.auto_reset_reason == "suspended"
def test_successful_turn_flow_clears_both_counter_and_resume_pending(
self, tmp_path, monkeypatch
):
"""The gateway's post-turn cleanup should clear both signals so a
future restart-interrupt starts with a fresh counter."""
import json
from gateway.run import GatewayRunner
store = _make_store(tmp_path)
source = _make_source()
entry = store.get_or_create_session(source)
store.mark_resume_pending(entry.session_key, reason="restart_timeout")
counts_file = tmp_path / ".restart_failure_counts"
counts_file.write_text(json.dumps({entry.session_key: 2}))
monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
runner = object.__new__(GatewayRunner)
runner.session_store = store
GatewayRunner._clear_restart_failure_count(runner, entry.session_key)
store.clear_resume_pending(entry.session_key)
assert store._entries[entry.session_key].resume_pending is False
assert not counts_file.exists()

View file

@ -51,6 +51,9 @@ class ProgressCaptureAdapter(BasePlatformAdapter):
async def send_typing(self, chat_id, metadata=None) -> None:
self.typing.append({"chat_id": chat_id, "metadata": metadata})
async def stop_typing(self, chat_id) -> None:
self.typing.append({"chat_id": chat_id, "metadata": {"stopped": True}})
async def get_chat_info(self, chat_id: str):
return {"id": chat_id}
@ -90,6 +93,40 @@ class LongPreviewAgent:
}
class DelayedProgressAgent:
def __init__(self, **kwargs):
self.tool_progress_callback = kwargs.get("tool_progress_callback")
self.tools = []
def run_conversation(self, message, conversation_history=None, task_id=None):
self.tool_progress_callback("tool.started", "terminal", "first command", {})
time.sleep(0.45)
self.tool_progress_callback("tool.started", "terminal", "second command", {})
time.sleep(0.1)
return {
"final_response": "done",
"messages": [],
"api_calls": 1,
}
class DelayedInterimAgent:
def __init__(self, **kwargs):
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
self.tools = []
def run_conversation(self, message, conversation_history=None, task_id=None):
self.interim_assistant_callback("first interim")
time.sleep(0.45)
self.interim_assistant_callback("second interim")
time.sleep(0.1)
return {
"final_response": "done",
"messages": [],
"api_calls": 1,
}
def _make_runner(adapter):
gateway_run = importlib.import_module("gateway.run")
GatewayRunner = gateway_run.GatewayRunner
@ -104,6 +141,7 @@ def _make_runner(adapter):
runner._fallback_model = None
runner._session_db = None
runner._running_agents = {}
runner._session_run_generation = {}
runner.hooks = SimpleNamespace(loaded_hooks=False)
runner.config = SimpleNamespace(
thread_sessions_per_user=False,
@ -744,6 +782,154 @@ async def test_base_processing_releases_post_delivery_callback_after_main_send()
assert released == [True]
@pytest.mark.asyncio
async def test_run_agent_drops_tool_progress_after_generation_invalidation(monkeypatch, tmp_path):
import yaml
(tmp_path / "config.yaml").write_text(
yaml.dump({"display": {"tool_progress": "all"}}),
encoding="utf-8",
)
fake_dotenv = types.ModuleType("dotenv")
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
fake_run_agent = types.ModuleType("run_agent")
fake_run_agent.AIAgent = DelayedProgressAgent
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
import tools.terminal_tool # noqa: F401 - register terminal tool metadata
adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
runner = _make_runner(adapter)
gateway_run = importlib.import_module("gateway.run")
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
source = SessionSource(
platform=Platform.DISCORD,
chat_id="dm-1",
chat_type="dm",
thread_id=None,
)
session_key = "agent:main:discord:dm:dm-1"
runner._session_run_generation[session_key] = 1
original_send = adapter.send
invalidated = {"done": False}
async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None):
result = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata)
if "first command" in content and not invalidated["done"]:
invalidated["done"] = True
runner._invalidate_session_run_generation(session_key, reason="test_stop")
return result
adapter.send = send_and_invalidate
result = await runner._run_agent(
message="hello",
context_prompt="",
history=[],
source=source,
session_id="sess-progress-stop",
session_key=session_key,
run_generation=1,
)
all_progress_text = " ".join(call["content"] for call in adapter.sent)
all_progress_text += " ".join(call["content"] for call in adapter.edits)
assert result["final_response"] == "done"
assert 'first command' in all_progress_text
assert 'second command' not in all_progress_text
@pytest.mark.asyncio
async def test_run_agent_drops_interim_commentary_after_generation_invalidation(monkeypatch, tmp_path):
import yaml
(tmp_path / "config.yaml").write_text(
yaml.dump({"display": {"tool_progress": "off", "interim_assistant_messages": True}}),
encoding="utf-8",
)
fake_dotenv = types.ModuleType("dotenv")
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
fake_run_agent = types.ModuleType("run_agent")
fake_run_agent.AIAgent = DelayedInterimAgent
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
runner = _make_runner(adapter)
gateway_run = importlib.import_module("gateway.run")
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
source = SessionSource(
platform=Platform.DISCORD,
chat_id="dm-2",
chat_type="dm",
thread_id=None,
)
session_key = "agent:main:discord:dm:dm-2"
runner._session_run_generation[session_key] = 1
original_send = adapter.send
invalidated = {"done": False}
async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None):
result = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata)
if content == "first interim" and not invalidated["done"]:
invalidated["done"] = True
runner._invalidate_session_run_generation(session_key, reason="test_stop")
return result
adapter.send = send_and_invalidate
result = await runner._run_agent(
message="hello",
context_prompt="",
history=[],
source=source,
session_id="sess-commentary-stop",
session_key=session_key,
run_generation=1,
)
sent_texts = [call["content"] for call in adapter.sent]
assert result["final_response"] == "done"
assert "first interim" in sent_texts
assert "second interim" not in sent_texts
@pytest.mark.asyncio
async def test_keep_typing_stops_immediately_when_interrupt_event_is_set():
adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
stop_event = asyncio.Event()
task = asyncio.create_task(
adapter._keep_typing(
"dm-typing-stop",
interval=30.0,
stop_event=stop_event,
)
)
await asyncio.sleep(0.05)
stop_event.set()
await asyncio.wait_for(task, timeout=0.5)
normal_typing_calls = [
call for call in adapter.typing if call.get("metadata") != {"stopped": True}
]
stopped_calls = [
call for call in adapter.typing if call.get("metadata") == {"stopped": True}
]
assert len(normal_typing_calls) == 1
assert len(stopped_calls) == 1
@pytest.mark.asyncio
async def test_verbose_mode_does_not_truncate_args_by_default(monkeypatch, tmp_path):
"""Verbose mode with default tool_preview_length (0) should NOT truncate args.

View file

@ -319,3 +319,23 @@ async def test_start_gateway_replace_clears_marker_on_permission_denied(
assert ok is False
# Marker must NOT be left behind
assert not (tmp_path / ".gateway-takeover.json").exists()
def test_runner_warns_when_docker_gateway_lacks_explicit_output_mount(monkeypatch, tmp_path, caplog):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
monkeypatch.setenv("TERMINAL_ENV", "docker")
monkeypatch.setenv("TERMINAL_DOCKER_VOLUMES", '["/etc/localtime:/etc/localtime:ro"]')
config = GatewayConfig(
platforms={
Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")
},
sessions_dir=tmp_path / "sessions",
)
with caplog.at_level("WARNING"):
GatewayRunner(config)
assert any(
"host-visible output mount" in record.message
for record in caplog.records
)

View file

@ -0,0 +1,59 @@
"""Regression tests: failed-connect path must call adapter.disconnect().
When adapter.connect() returns False or raises, the adapter may have
allocated resources (aiohttp.ClientSession, poll tasks, child
subprocesses) before giving up. Without a defensive disconnect() call
these leak and surface as "Unclosed client session" warnings at
process exit (seen on the 2026-04-18 18:08:16 gateway restart).
The fix: gateway/run.py wraps each adapter connect() with a safety-net
call to _safe_adapter_disconnect() in the failure branches.
"""
from unittest.mock import AsyncMock, MagicMock
import pytest
from gateway.config import Platform
from gateway.run import GatewayRunner
@pytest.fixture
def bare_runner():
"""A GatewayRunner shell that only needs to support _safe_adapter_disconnect."""
return object.__new__(GatewayRunner)
@pytest.mark.asyncio
async def test_safe_disconnect_calls_adapter_disconnect(bare_runner):
"""The helper forwards to adapter.disconnect()."""
adapter = MagicMock()
adapter.disconnect = AsyncMock(return_value=None)
await bare_runner._safe_adapter_disconnect(adapter, Platform.TELEGRAM)
adapter.disconnect.assert_awaited_once()
@pytest.mark.asyncio
async def test_safe_disconnect_swallows_exceptions(bare_runner):
"""An exception in adapter.disconnect() must not propagate — the
caller is already on an error path."""
adapter = MagicMock()
adapter.disconnect = AsyncMock(side_effect=RuntimeError("partial init"))
# Must NOT raise
await bare_runner._safe_adapter_disconnect(adapter, Platform.TELEGRAM)
adapter.disconnect.assert_awaited_once()
@pytest.mark.asyncio
async def test_safe_disconnect_handles_none_platform(bare_runner):
"""Logging path must tolerate platform=None."""
adapter = MagicMock()
adapter.disconnect = AsyncMock(side_effect=ValueError("nope"))
await bare_runner._safe_adapter_disconnect(adapter, None)
adapter.disconnect.assert_awaited_once()

View file

@ -24,10 +24,18 @@ class _FakeAdapter:
def __init__(self):
self._pending_messages = {}
self._active_sessions = {}
self.interrupted_sessions = []
async def send(self, chat_id, text, **kwargs):
pass
async def interrupt_session_activity(self, session_key, chat_id):
self.interrupted_sessions.append((session_key, chat_id))
event = self._active_sessions.get(session_key)
if event is not None:
event.set()
def _make_runner():
runner = object.__new__(GatewayRunner)
@ -37,6 +45,7 @@ def _make_runner():
runner.adapters = {Platform.TELEGRAM: _FakeAdapter()}
runner._running_agents = {}
runner._running_agents_ts = {}
runner._session_run_generation = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._voice_mode = {}
@ -81,7 +90,7 @@ async def test_sentinel_placed_before_agent_setup():
# Patch _handle_message_with_agent to capture state at entry
sentinel_was_set = False
async def mock_inner(self_inner, ev, src, qk):
async def mock_inner(self_inner, ev, src, qk, generation):
nonlocal sentinel_was_set
sentinel_was_set = runner._running_agents.get(qk) is _AGENT_PENDING_SENTINEL
return "ok"
@ -105,7 +114,7 @@ async def test_sentinel_cleaned_up_after_handler_returns():
event = _make_event()
session_key = build_session_key(event.source)
async def mock_inner(self_inner, ev, src, qk):
async def mock_inner(self_inner, ev, src, qk, generation):
return "ok"
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
@ -127,7 +136,7 @@ async def test_sentinel_cleaned_up_on_exception():
event = _make_event()
session_key = build_session_key(event.source)
async def mock_inner(self_inner, ev, src, qk):
async def mock_inner(self_inner, ev, src, qk, generation):
raise RuntimeError("boom")
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
@ -154,7 +163,7 @@ async def test_second_message_during_sentinel_queued_not_duplicate():
barrier = asyncio.Event()
async def slow_inner(self_inner, ev, src, qk):
async def slow_inner(self_inner, ev, src, qk, generation):
# Simulate slow setup — wait until test tells us to proceed
await barrier.wait()
return "ok"
@ -333,7 +342,7 @@ async def test_stop_during_sentinel_force_cleans_session():
barrier = asyncio.Event()
async def slow_inner(self_inner, ev, src, qk):
async def slow_inner(self_inner, ev, src, qk, generation):
await barrier.wait()
return "ok"
@ -381,6 +390,7 @@ async def test_stop_hard_kills_running_agent():
fake_agent = MagicMock()
fake_agent.get_activity_summary.return_value = {"seconds_since_activity": 0}
runner._running_agents[session_key] = fake_agent
runner.adapters[Platform.TELEGRAM]._active_sessions[session_key] = asyncio.Event()
# Send /stop
stop_event = _make_event(text="/stop")
@ -393,6 +403,10 @@ async def test_stop_hard_kills_running_agent():
assert session_key not in runner._running_agents, (
"/stop must remove the agent from _running_agents so the session is unlocked"
)
assert runner.adapters[Platform.TELEGRAM].interrupted_sessions == [
(session_key, "12345")
]
assert runner.adapters[Platform.TELEGRAM]._active_sessions[session_key].is_set()
# Must return a confirmation
assert result is not None

View file

@ -740,3 +740,140 @@ class TestSignalStopTyping:
await adapter.stop_typing("+155****4567")
adapter._stop_typing_indicator.assert_awaited_once_with("+155****4567")
# ---------------------------------------------------------------------------
# Typing-indicator backoff on repeated failures (Signal RPC spam fix)
# ---------------------------------------------------------------------------
class TestSignalTypingBackoff:
"""When base.py's _keep_typing refresh loop calls send_typing every ~2s
and the recipient is unreachable (NETWORK_FAILURE), the adapter must:
- log WARNING only for the first failure (subsequent failures use DEBUG
via log_failures=False on the _rpc call)
- after 3 consecutive failures, skip the RPC entirely during an
exponential cooldown window instead of hammering signal-cli every 2s
- reset counters on a successful sendTyping
- reset counters when _stop_typing_indicator() is called for the chat
"""
@pytest.mark.asyncio
async def test_first_failure_logs_at_warning_subsequent_at_debug(
self, monkeypatch
):
adapter = _make_signal_adapter(monkeypatch)
calls = []
async def _fake_rpc(method, params, rpc_id=None, *, log_failures=True):
calls.append({"log_failures": log_failures})
return None # simulate NETWORK_FAILURE
adapter._rpc = _fake_rpc
await adapter.send_typing("+155****4567")
await adapter.send_typing("+155****4567")
assert len(calls) == 2
assert calls[0]["log_failures"] is True # first failure — warn
assert calls[1]["log_failures"] is False # subsequent — debug
@pytest.mark.asyncio
async def test_three_consecutive_failures_trigger_cooldown(
self, monkeypatch
):
adapter = _make_signal_adapter(monkeypatch)
call_count = {"n": 0}
async def _fake_rpc(method, params, rpc_id=None, *, log_failures=True):
call_count["n"] += 1
return None
adapter._rpc = _fake_rpc
# Three failures engage the cooldown.
await adapter.send_typing("+155****4567")
await adapter.send_typing("+155****4567")
await adapter.send_typing("+155****4567")
assert call_count["n"] == 3
assert "+155****4567" in adapter._typing_skip_until
# Fourth, fifth, ... calls during the cooldown window are short-
# circuited — the RPC is not issued at all.
await adapter.send_typing("+155****4567")
await adapter.send_typing("+155****4567")
assert call_count["n"] == 3
@pytest.mark.asyncio
async def test_cooldown_is_per_chat_not_global(self, monkeypatch):
adapter = _make_signal_adapter(monkeypatch)
call_log = []
async def _fake_rpc(method, params, rpc_id=None, *, log_failures=True):
call_log.append(params.get("recipient") or params.get("groupId"))
return None
adapter._rpc = _fake_rpc
# Drive chat A into cooldown.
for _ in range(3):
await adapter.send_typing("+155****4567")
assert "+155****4567" in adapter._typing_skip_until
# Chat B is unaffected — still makes RPCs.
await adapter.send_typing("+155****9999")
await adapter.send_typing("+155****9999")
assert "+155****9999" not in adapter._typing_skip_until
# Chat A cooldown untouched
assert "+155****4567" in adapter._typing_skip_until
@pytest.mark.asyncio
async def test_success_resets_failure_counter_and_cooldown(
self, monkeypatch
):
adapter = _make_signal_adapter(monkeypatch)
result_queue = [None, None, {"timestamp": 12345}]
call_log = []
async def _fake_rpc(method, params, rpc_id=None, *, log_failures=True):
call_log.append(log_failures)
return result_queue.pop(0)
adapter._rpc = _fake_rpc
await adapter.send_typing("+155****4567") # fail 1 — warn
await adapter.send_typing("+155****4567") # fail 2 — debug
await adapter.send_typing("+155****4567") # success — reset
assert adapter._typing_failures.get("+155****4567", 0) == 0
assert "+155****4567" not in adapter._typing_skip_until
# Next failure after recovery logs at WARNING again (fresh counter).
async def _fail(method, params, rpc_id=None, *, log_failures=True):
call_log.append(log_failures)
return None
adapter._rpc = _fail
await adapter.send_typing("+155****4567")
assert call_log[-1] is True # first failure in a fresh cycle
@pytest.mark.asyncio
async def test_stop_typing_indicator_clears_backoff_state(
self, monkeypatch
):
adapter = _make_signal_adapter(monkeypatch)
async def _fail(method, params, rpc_id=None, *, log_failures=True):
return None
adapter._rpc = _fail
for _ in range(3):
await adapter.send_typing("+155****4567")
assert adapter._typing_failures.get("+155****4567") == 3
assert "+155****4567" in adapter._typing_skip_until
await adapter._stop_typing_indicator("+155****4567")
assert "+155****4567" not in adapter._typing_failures
assert "+155****4567" not in adapter._typing_skip_until

View file

@ -50,6 +50,7 @@ def _make_runner(session_entry: SessionEntry):
runner.session_store.rewrite_transcript = MagicMock()
runner.session_store.update_session = MagicMock()
runner._running_agents = {}
runner._session_run_generation = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._session_db = MagicMock()
@ -223,6 +224,121 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
)
@pytest.mark.asyncio
async def test_handle_message_discards_stale_result_after_session_invalidation(monkeypatch):
import gateway.run as gateway_run
session_entry = SessionEntry(
session_key=build_session_key(_make_source()),
session_id="sess-1",
created_at=datetime.now(),
updated_at=datetime.now(),
platform=Platform.TELEGRAM,
chat_type="dm",
)
runner = _make_runner(session_entry)
runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
session_key = session_entry.session_key
runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks = {session_key: object()}
async def _stale_result(**kwargs):
runner._invalidate_session_run_generation(kwargs["session_key"], reason="test_stale_result")
return {
"final_response": "late reply",
"messages": [],
"tools": [],
"history_offset": 0,
"last_prompt_tokens": 80,
"input_tokens": 120,
"output_tokens": 45,
"model": "openai/test-model",
}
runner._run_agent = AsyncMock(side_effect=_stale_result)
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
monkeypatch.setattr(
"agent.model_metadata.get_model_context_length",
lambda *_args, **_kwargs: 100000,
)
result = await runner._handle_message(_make_event("hello"))
assert result is None
runner.session_store.append_to_transcript.assert_not_called()
runner.session_store.update_session.assert_not_called()
assert session_key not in runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks
@pytest.mark.asyncio
async def test_handle_message_stale_result_keeps_newer_generation_callback(monkeypatch):
import gateway.run as gateway_run
class _Adapter:
def __init__(self):
self._post_delivery_callbacks = {}
async def send(self, *args, **kwargs):
return None
def pop_post_delivery_callback(self, session_key, *, generation=None):
entry = self._post_delivery_callbacks.get(session_key)
if entry is None:
return None
if isinstance(entry, tuple):
entry_generation, callback = entry
if generation is not None and entry_generation != generation:
return None
self._post_delivery_callbacks.pop(session_key, None)
return callback
if generation is not None:
return None
return self._post_delivery_callbacks.pop(session_key, None)
session_entry = SessionEntry(
session_key=build_session_key(_make_source()),
session_id="sess-1",
created_at=datetime.now(),
updated_at=datetime.now(),
platform=Platform.TELEGRAM,
chat_type="dm",
)
runner = _make_runner(session_entry)
runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
session_key = session_entry.session_key
adapter = _Adapter()
runner.adapters[Platform.TELEGRAM] = adapter
async def _stale_result(**kwargs):
# Simulate a newer run claiming the callback slot before the stale run unwinds.
runner._session_run_generation[session_key] = 2
adapter._post_delivery_callbacks[session_key] = (2, lambda: None)
return {
"final_response": "late reply",
"messages": [],
"tools": [],
"history_offset": 0,
"last_prompt_tokens": 80,
"input_tokens": 120,
"output_tokens": 45,
"model": "openai/test-model",
}
runner._run_agent = AsyncMock(side_effect=_stale_result)
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
monkeypatch.setattr(
"agent.model_metadata.get_model_context_length",
lambda *_args, **_kwargs: 100000,
)
result = await runner._handle_message(_make_event("hello"))
assert result is None
assert session_key in adapter._post_delivery_callbacks
assert adapter._post_delivery_callbacks[session_key][0] == 2
@pytest.mark.asyncio
async def test_status_command_bypasses_active_session_guard():

View file

@ -0,0 +1,191 @@
"""Tests for the gateway /steer command handler.
/steer injects a user message into the agent's next tool result without
interrupting. The gateway runner must:
1. When an agent IS running call ``agent.steer(text)``, do NOT set
``_interrupt_requested``, do NOT touch ``_pending_messages``.
2. When the agent is the PENDING sentinel fall back to /queue
semantics (store in ``adapter._pending_messages``).
3. When no agent is active strip the slash prefix and let the normal
prompt pipeline handle it as a regular user message.
"""
from __future__ import annotations
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent
from gateway.session import SessionEntry, SessionSource, build_session_key
def _make_source() -> SessionSource:
return SessionSource(
platform=Platform.TELEGRAM,
user_id="u1",
chat_id="c1",
user_name="tester",
chat_type="dm",
)
def _make_event(text: str) -> MessageEvent:
return MessageEvent(
text=text,
source=_make_source(),
message_id="m1",
)
def _make_runner(session_entry: SessionEntry):
from gateway.run import GatewayRunner
runner = object.__new__(GatewayRunner)
runner.config = GatewayConfig(
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
)
adapter = MagicMock()
adapter.send = AsyncMock()
adapter._pending_messages = {}
runner.adapters = {Platform.TELEGRAM: adapter}
runner._voice_mode = {}
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
runner.session_store = MagicMock()
runner.session_store.get_or_create_session.return_value = session_entry
runner.session_store.load_transcript.return_value = []
runner.session_store.has_any_sessions.return_value = True
runner._running_agents = {}
runner._running_agents_ts = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._session_db = MagicMock()
runner._session_db.get_session_title.return_value = None
runner._reasoning_config = None
runner._provider_routing = {}
runner._fallback_model = None
runner._show_reasoning = False
runner._is_user_authorized = lambda _source: True
runner._set_session_env = lambda _context: None
runner._should_send_voice_reply = lambda *_args, **_kwargs: False
runner._send_voice_reply = AsyncMock()
runner._capture_gateway_honcho_if_configured = lambda *args, **kwargs: None
runner._emit_gateway_run_progress = AsyncMock()
return runner, adapter
def _session_entry() -> SessionEntry:
return SessionEntry(
session_key=build_session_key(_make_source()),
session_id="sess-1",
created_at=datetime.now(),
updated_at=datetime.now(),
platform=Platform.TELEGRAM,
chat_type="dm",
total_tokens=0,
)
@pytest.mark.asyncio
async def test_steer_calls_agent_steer_and_does_not_interrupt():
"""When an agent is running, /steer must call agent.steer(text) and
leave interrupt state untouched."""
runner, adapter = _make_runner(_session_entry())
sk = build_session_key(_make_source())
running_agent = MagicMock()
running_agent.steer.return_value = True
runner._running_agents[sk] = running_agent
result = await runner._handle_message(_make_event("/steer also check auth.log"))
# The handler replied with a confirmation
assert result is not None
assert "steer" in result.lower() or "queued" in result.lower()
# The agent's steer() was called with the payload (prefix stripped)
running_agent.steer.assert_called_once_with("also check auth.log")
# Critically: interrupt was NOT called
running_agent.interrupt.assert_not_called()
# And no user-text queueing happened — the steer doesn't go into
# _pending_messages (that would be turn-boundary /queue semantics).
assert runner._pending_messages == {}
assert adapter._pending_messages == {}
@pytest.mark.asyncio
async def test_steer_without_payload_returns_usage():
runner, _adapter = _make_runner(_session_entry())
sk = build_session_key(_make_source())
running_agent = MagicMock()
runner._running_agents[sk] = running_agent
result = await runner._handle_message(_make_event("/steer"))
assert result is not None
assert "Usage" in result or "usage" in result
running_agent.steer.assert_not_called()
running_agent.interrupt.assert_not_called()
@pytest.mark.asyncio
async def test_steer_with_pending_sentinel_falls_back_to_queue():
"""When the agent hasn't finished booting (sentinel), /steer should
queue as a turn-boundary follow-up instead of crashing."""
from gateway.run import _AGENT_PENDING_SENTINEL
runner, adapter = _make_runner(_session_entry())
sk = build_session_key(_make_source())
runner._running_agents[sk] = _AGENT_PENDING_SENTINEL
result = await runner._handle_message(_make_event("/steer wait up"))
assert result is not None
assert "queued" in result.lower() or "starting" in result.lower()
# The fallback put the text into the adapter's pending queue.
assert sk in adapter._pending_messages
assert adapter._pending_messages[sk].text == "wait up"
@pytest.mark.asyncio
async def test_steer_agent_without_steer_method_falls_back():
"""If the running agent somehow lacks the steer() method (older build,
test stub), the handler must not explode fall back to /queue."""
runner, adapter = _make_runner(_session_entry())
sk = build_session_key(_make_source())
# A bare object that does NOT have steer() — use a spec'd Mock so
# hasattr(agent, "steer") returns False.
running_agent = MagicMock(spec=[])
runner._running_agents[sk] = running_agent
result = await runner._handle_message(_make_event("/steer fallback"))
assert result is not None
# Must mention queueing since steer wasn't available
assert "queued" in result.lower()
assert sk in adapter._pending_messages
assert adapter._pending_messages[sk].text == "fallback"
@pytest.mark.asyncio
async def test_steer_rejected_payload_returns_rejection_message():
"""If agent.steer() returns False (e.g. empty after strip — though
the gateway already guards this), surface a rejection message."""
runner, _adapter = _make_runner(_session_entry())
sk = build_session_key(_make_source())
running_agent = MagicMock()
running_agent.steer.return_value = False
runner._running_agents[sk] = running_agent
result = await runner._handle_message(_make_event("/steer hello"))
assert result is not None
assert "rejected" in result.lower() or "empty" in result.lower()
if __name__ == "__main__": # pragma: no cover
pytest.main([__file__, "-v"])

View file

@ -502,11 +502,13 @@ class TestSegmentBreakOnToolBoundary:
@pytest.mark.asyncio
async def test_segment_break_clears_failed_edit_fallback_state(self):
"""A tool boundary after edit failure must not duplicate the next segment."""
"""A tool boundary after edit failure must flush the undelivered tail
without duplicating the prefix the user already saw (#8124)."""
adapter = MagicMock()
send_results = [
SimpleNamespace(success=True, message_id="msg_1"),
SimpleNamespace(success=True, message_id="msg_2"),
SimpleNamespace(success=True, message_id="msg_3"),
]
adapter.send = AsyncMock(side_effect=send_results)
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
@ -526,7 +528,60 @@ class TestSegmentBreakOnToolBoundary:
await task
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
assert sent_texts == ["Hello ▉", "Next segment"]
# The undelivered "world" tail must reach the user, and the next
# segment must not duplicate "Hello" that was already visible.
assert sent_texts == ["Hello ▉", "world", "Next segment"]
@pytest.mark.asyncio
async def test_segment_break_after_mid_stream_edit_failure_preserves_tail(self):
"""Regression for #8124: when an earlier edit succeeded but later edits
fail (persistent flood control) and a tool boundary arrives before the
fallback threshold is reached, the pre-boundary tail must still be
delivered not silently dropped by the segment reset."""
adapter = MagicMock()
# msg_1 for the initial partial, msg_2 for the flushed tail,
# msg_3 for the post-boundary segment.
send_results = [
SimpleNamespace(success=True, message_id="msg_1"),
SimpleNamespace(success=True, message_id="msg_2"),
SimpleNamespace(success=True, message_id="msg_3"),
]
adapter.send = AsyncMock(side_effect=send_results)
# First two edits succeed, everything after fails with flood control
# — simulating Telegram's "edit once then get rate-limited" pattern.
edit_results = [
SimpleNamespace(success=True), # "Hello world ▉" — succeeds
SimpleNamespace(success=False, error="flood_control:6.0"), # "Hello world more ▉" — flood triggered
SimpleNamespace(success=False, error="flood_control:6.0"), # finalize edit at segment break
SimpleNamespace(success=False, error="flood_control:6.0"), # cursor-strip attempt
]
adapter.edit_message = AsyncMock(side_effect=edit_results + [edit_results[-1]] * 10)
adapter.MAX_MESSAGE_LENGTH = 4096
config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor="")
consumer = GatewayStreamConsumer(adapter, "chat_123", config)
consumer.on_delta("Hello")
task = asyncio.create_task(consumer.run())
await asyncio.sleep(0.08)
consumer.on_delta(" world")
await asyncio.sleep(0.08)
consumer.on_delta(" more")
await asyncio.sleep(0.08)
consumer.on_delta(None) # tool boundary
consumer.on_delta("Here is the tool result.")
consumer.finish()
await task
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
# "more" must have been delivered, not dropped.
all_text = " ".join(sent_texts)
assert "more" in all_text, (
f"Pre-boundary tail 'more' was silently dropped: sends={sent_texts}"
)
# Post-boundary text must also reach the user.
assert "Here is the tool result." in all_text
@pytest.mark.asyncio
async def test_no_message_id_enters_fallback_mode(self):
@ -1161,3 +1216,87 @@ class TestBufferOnlyMode:
# text, the consumer may send then edit, or just send once at got_done.
# The key assertion: this doesn't break.
assert adapter.send.call_count >= 1
# ── Cursor stripping on fallback (#7183) ────────────────────────────────────
class TestCursorStrippingOnFallback:
"""Regression: cursor must be stripped when fallback continuation is empty (#7183).
When _send_fallback_final is called with nothing new to deliver (the visible
partial already matches final_text), the last edit may still show the cursor
character because fallback mode was entered after a failed edit. Before the
fix this would leave the message permanently frozen with a visible .
"""
@pytest.mark.asyncio
async def test_cursor_stripped_when_continuation_empty(self):
"""_send_fallback_final must attempt a final edit to strip the cursor."""
adapter = MagicMock()
adapter.MAX_MESSAGE_LENGTH = 4096
adapter.edit_message = AsyncMock(
return_value=SimpleNamespace(success=True, message_id="msg-1")
)
consumer = GatewayStreamConsumer(
adapter, "chat-1",
config=StreamConsumerConfig(cursor=""),
)
consumer._message_id = "msg-1"
consumer._last_sent_text = "Hello world ▉"
consumer._fallback_final_send = False
await consumer._send_fallback_final("Hello world")
adapter.edit_message.assert_called_once()
call_args = adapter.edit_message.call_args
assert call_args.kwargs["content"] == "Hello world"
assert consumer._already_sent is True
# _last_sent_text should reflect the cleaned text after a successful strip
assert consumer._last_sent_text == "Hello world"
@pytest.mark.asyncio
async def test_cursor_not_stripped_when_no_cursor_configured(self):
"""No edit attempted when cursor is not configured."""
adapter = MagicMock()
adapter.MAX_MESSAGE_LENGTH = 4096
adapter.edit_message = AsyncMock()
consumer = GatewayStreamConsumer(
adapter, "chat-1",
config=StreamConsumerConfig(cursor=""),
)
consumer._message_id = "msg-1"
consumer._last_sent_text = "Hello world"
consumer._fallback_final_send = False
await consumer._send_fallback_final("Hello world")
adapter.edit_message.assert_not_called()
assert consumer._already_sent is True
@pytest.mark.asyncio
async def test_cursor_strip_edit_failure_handled(self):
"""If the cursor-stripping edit itself fails, it must not crash and
must not corrupt _last_sent_text."""
adapter = MagicMock()
adapter.MAX_MESSAGE_LENGTH = 4096
adapter.edit_message = AsyncMock(
return_value=SimpleNamespace(success=False, error="flood_control")
)
consumer = GatewayStreamConsumer(
adapter, "chat-1",
config=StreamConsumerConfig(cursor=""),
)
consumer._message_id = "msg-1"
consumer._last_sent_text = "Hello ▉"
consumer._fallback_final_send = False
await consumer._send_fallback_final("Hello")
# Should still set already_sent despite the cursor-strip edit failure
assert consumer._already_sent is True
# _last_sent_text must NOT be updated when the edit failed
assert consumer._last_sent_text == "Hello ▉"

View file

@ -483,6 +483,32 @@ class TestSendDocument:
assert "not found" in result.error.lower()
connected_adapter._bot.send_document.assert_not_called()
@pytest.mark.asyncio
async def test_send_document_workspace_path_has_docker_hint(self, connected_adapter):
"""Container-local-looking paths get a more actionable Docker hint."""
result = await connected_adapter.send_document(
chat_id="12345",
file_path="/workspace/report.txt",
)
assert result.success is False
assert "docker sandbox" in result.error.lower()
assert "host-visible path" in result.error.lower()
connected_adapter._bot.send_document.assert_not_called()
@pytest.mark.asyncio
async def test_send_document_outputs_path_has_docker_hint(self, connected_adapter):
"""Legacy /outputs paths also get the Docker hint."""
result = await connected_adapter.send_document(
chat_id="12345",
file_path="/outputs/report.txt",
)
assert result.success is False
assert "docker sandbox" in result.error.lower()
assert "host-visible path" in result.error.lower()
connected_adapter._bot.send_document.assert_not_called()
@pytest.mark.asyncio
async def test_send_document_not_connected(self, adapter):
"""If bot is None, returns not connected error."""
@ -665,6 +691,17 @@ class TestSendVideo:
assert result.success is False
assert "not found" in result.error.lower()
@pytest.mark.asyncio
async def test_send_video_workspace_path_has_docker_hint(self, connected_adapter):
result = await connected_adapter.send_video(
chat_id="12345",
video_path="/workspace/video.mp4",
)
assert result.success is False
assert "docker sandbox" in result.error.lower()
assert "host-visible path" in result.error.lower()
@pytest.mark.asyncio
async def test_send_video_not_connected(self, adapter):
result = await adapter.send_video(

View file

@ -148,6 +148,70 @@ class TestDiscordTextBatching:
await asyncio.sleep(0.25)
adapter.handle_message.assert_called_once()
@pytest.mark.asyncio
async def test_shield_protects_handle_message_from_cancel(self):
"""Regression guard: a follow-up chunk arriving while
handle_message is mid-flight must NOT cancel the running
dispatch. _enqueue_text_event fires prior_task.cancel() on
every new chunk; without asyncio.shield around handle_message
the cancel propagates into the agent's streaming request and
aborts the response.
"""
adapter = _make_discord_adapter()
handle_started = asyncio.Event()
release_handle = asyncio.Event()
first_handle_cancelled = asyncio.Event()
first_handle_completed = asyncio.Event()
call_count = [0]
async def slow_handle(event):
call_count[0] += 1
# Only the first call (batch 1) is the one we're protecting.
if call_count[0] == 1:
handle_started.set()
try:
await release_handle.wait()
first_handle_completed.set()
except asyncio.CancelledError:
first_handle_cancelled.set()
raise
# Second call (batch 2) returns immediately — not the subject
# of this test.
adapter.handle_message = slow_handle
# Prime batch 1 and wait for it to land inside handle_message.
adapter._enqueue_text_event(_make_event("batch 1", Platform.DISCORD))
await asyncio.wait_for(handle_started.wait(), timeout=1.0)
# A new chunk arrives — _enqueue_text_event fires
# prior_task.cancel() on batch 1's flush task, which is
# currently awaiting inside handle_message.
adapter._enqueue_text_event(_make_event("batch 2 follow-up", Platform.DISCORD))
# Let the cancel propagate.
await asyncio.sleep(0.05)
# CRITICAL ASSERTION: batch 1's handle_message must NOT have
# been cancelled. Without asyncio.shield this assertion fails
# because CancelledError propagates from the flush task's
# `await self.handle_message(event)` into slow_handle.
assert not first_handle_cancelled.is_set(), (
"handle_message for batch 1 was cancelled by a follow-up "
"chunk — asyncio.shield is missing or broken"
)
# Release batch 1's handle_message and let it complete.
release_handle.set()
await asyncio.wait_for(first_handle_completed.wait(), timeout=1.0)
assert first_handle_completed.is_set()
# Cleanup
for task in list(adapter._pending_text_batch_tasks.values()):
task.cancel()
await asyncio.sleep(0.01)
# =====================================================================
# Matrix text batching

View file

@ -758,7 +758,7 @@ class TestVoiceChannelCommands:
result = await runner._handle_voice_channel_join(event)
assert "voice dependencies are missing" in result.lower()
assert "hermes-agent[messaging]" in result
assert "PyNaCl" in result
# -- _handle_voice_channel_leave --

View file

@ -0,0 +1,473 @@
"""Tests for the webhook adapter's ``deliver_only`` route mode.
``deliver_only`` lets external services (Supabase webhooks, monitoring
alerts, background jobs, other agents) push plain-text notifications to
a user's chat via the webhook adapter WITHOUT invoking the agent. The
rendered prompt template becomes the literal message body.
Covers:
- Agent is NOT invoked (``handle_message`` never called)
- Rendered content is delivered to the target platform adapter
- HTTP returns 200 OK on success, 502 on delivery failure
- Startup validation rejects ``deliver_only`` without a real delivery target
- HMAC auth, rate limiting, and idempotency still apply
"""
import asyncio
import hashlib
import hmac
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, SendResult
from gateway.platforms.webhook import WebhookAdapter, _INSECURE_NO_AUTH
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_adapter(routes, **extra_kw) -> WebhookAdapter:
extra = {"host": "0.0.0.0", "port": 0, "routes": routes}
extra.update(extra_kw)
config = PlatformConfig(enabled=True, extra=extra)
return WebhookAdapter(config)
def _create_app(adapter: WebhookAdapter) -> web.Application:
app = web.Application()
app.router.add_get("/health", adapter._handle_health)
app.router.add_post("/webhooks/{route_name}", adapter._handle_webhook)
return app
def _wire_mock_target(adapter: WebhookAdapter, platform_name: str = "telegram"):
"""Attach a gateway_runner with a mocked target adapter."""
mock_target = AsyncMock()
mock_target.send = AsyncMock(return_value=SendResult(success=True))
mock_runner = MagicMock()
mock_runner.adapters = {Platform(platform_name): mock_target}
mock_runner.config.get_home_channel.return_value = None
adapter.gateway_runner = mock_runner
return mock_target
# ===================================================================
# Core behaviour: agent bypass
# ===================================================================
class TestDeliverOnlyBypassesAgent:
"""The whole point of the feature — handle_message must not be called."""
@pytest.mark.asyncio
async def test_post_delivers_directly_without_agent(self):
routes = {
"match-alert": {
"secret": _INSECURE_NO_AUTH,
"deliver": "telegram",
"deliver_only": True,
"deliver_extra": {"chat_id": "12345"},
"prompt": "{payload.user} matched with {payload.other}!",
}
}
adapter = _make_adapter(routes)
mock_target = _wire_mock_target(adapter)
# Guard: handle_message must NOT be called in deliver_only mode
handle_message_calls: list[MessageEvent] = []
async def _capture(event):
handle_message_calls.append(event)
adapter.handle_message = _capture
app = _create_app(adapter)
body = json.dumps(
{"payload": {"user": "alice", "other": "bob"}}
).encode()
async with TestClient(TestServer(app)) as cli:
resp = await cli.post(
"/webhooks/match-alert",
data=body,
headers={
"Content-Type": "application/json",
"X-GitHub-Delivery": "delivery-1",
},
)
assert resp.status == 200
data = await resp.json()
assert data["status"] == "delivered"
assert data["route"] == "match-alert"
assert data["target"] == "telegram"
# Let any background tasks settle before asserting no agent call
await asyncio.sleep(0.05)
# Agent was NOT invoked
assert handle_message_calls == []
# Target adapter.send() WAS called with the rendered template
mock_target.send.assert_awaited_once()
call_args = mock_target.send.await_args
chat_id_arg, content_arg = call_args.args[0], call_args.args[1]
assert chat_id_arg == "12345"
assert content_arg == "alice matched with bob!"
@pytest.mark.asyncio
async def test_template_rendering_works(self):
"""Dot-notation template variables resolve in deliver_only mode."""
routes = {
"alert": {
"secret": _INSECURE_NO_AUTH,
"deliver": "telegram",
"deliver_only": True,
"deliver_extra": {"chat_id": "chat-1"},
"prompt": "Build {build.number} status: {build.status}",
}
}
adapter = _make_adapter(routes)
mock_target = _wire_mock_target(adapter)
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
resp = await cli.post(
"/webhooks/alert",
json={"build": {"number": 77, "status": "FAILED"}},
headers={"X-GitHub-Delivery": "d-render-1"},
)
assert resp.status == 200
mock_target.send.assert_awaited_once()
content_arg = mock_target.send.await_args.args[1]
assert content_arg == "Build 77 status: FAILED"
@pytest.mark.asyncio
async def test_thread_id_passed_through(self):
"""deliver_extra.thread_id flows through to the target adapter."""
routes = {
"r": {
"secret": _INSECURE_NO_AUTH,
"deliver": "telegram",
"deliver_only": True,
"deliver_extra": {"chat_id": "c-1", "thread_id": "topic-42"},
"prompt": "hi",
}
}
adapter = _make_adapter(routes)
mock_target = _wire_mock_target(adapter)
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
resp = await cli.post(
"/webhooks/r",
json={},
headers={"X-GitHub-Delivery": "d-thread-1"},
)
assert resp.status == 200
assert mock_target.send.await_args.kwargs["metadata"] == {
"thread_id": "topic-42"
}
# ===================================================================
# HTTP status codes
# ===================================================================
class TestDeliverOnlyStatusCodes:
@pytest.mark.asyncio
async def test_delivery_failure_returns_502(self):
"""If the target adapter returns SendResult(success=False), 502."""
routes = {
"r": {
"secret": _INSECURE_NO_AUTH,
"deliver": "telegram",
"deliver_only": True,
"deliver_extra": {"chat_id": "c-1"},
"prompt": "hi",
}
}
adapter = _make_adapter(routes)
mock_target = _wire_mock_target(adapter)
mock_target.send = AsyncMock(
return_value=SendResult(success=False, error="rate limited by tg")
)
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
resp = await cli.post(
"/webhooks/r",
json={},
headers={"X-GitHub-Delivery": "d-fail-1"},
)
assert resp.status == 502
data = await resp.json()
# Generic error — no adapter-level detail leaks
assert data["error"] == "Delivery failed"
assert "rate limited" not in json.dumps(data)
@pytest.mark.asyncio
async def test_delivery_exception_returns_502(self):
"""If adapter.send() raises, we return 502 (not 500)."""
routes = {
"r": {
"secret": _INSECURE_NO_AUTH,
"deliver": "telegram",
"deliver_only": True,
"deliver_extra": {"chat_id": "c-1"},
"prompt": "hi",
}
}
adapter = _make_adapter(routes)
mock_target = _wire_mock_target(adapter)
mock_target.send = AsyncMock(side_effect=RuntimeError("tg exploded"))
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
resp = await cli.post(
"/webhooks/r",
json={},
headers={"X-GitHub-Delivery": "d-exc-1"},
)
assert resp.status == 502
data = await resp.json()
assert data["error"] == "Delivery failed"
# Exception message must not leak
assert "exploded" not in json.dumps(data)
@pytest.mark.asyncio
async def test_target_platform_not_connected_returns_502(self):
"""deliver_only to a platform the gateway doesn't have → 502."""
routes = {
"r": {
"secret": _INSECURE_NO_AUTH,
"deliver": "discord", # not configured in mock runner
"deliver_only": True,
"deliver_extra": {"chat_id": "c-1"},
"prompt": "hi",
}
}
adapter = _make_adapter(routes)
_wire_mock_target(adapter, platform_name="telegram") # only TG wired
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
resp = await cli.post(
"/webhooks/r",
json={},
headers={"X-GitHub-Delivery": "d-no-platform-1"},
)
assert resp.status == 502
# ===================================================================
# Startup validation
# ===================================================================
class TestDeliverOnlyStartupValidation:
@pytest.mark.asyncio
async def test_deliver_only_with_log_deliver_rejected(self):
"""deliver_only=true + deliver=log is nonsense — reject at connect()."""
routes = {
"bad": {
"secret": _INSECURE_NO_AUTH,
"deliver": "log",
"deliver_only": True,
"prompt": "hi",
}
}
adapter = _make_adapter(routes)
with pytest.raises(ValueError, match="deliver_only=true but deliver is 'log'"):
await adapter.connect()
@pytest.mark.asyncio
async def test_deliver_only_with_missing_deliver_rejected(self):
"""deliver_only=true with no deliver field defaults to 'log' → reject."""
routes = {
"bad": {
"secret": _INSECURE_NO_AUTH,
# no deliver field
"deliver_only": True,
"prompt": "hi",
}
}
adapter = _make_adapter(routes)
with pytest.raises(ValueError, match="deliver_only=true"):
await adapter.connect()
@pytest.mark.asyncio
async def test_deliver_only_with_real_target_accepted(self):
"""Sanity check — a valid deliver_only config passes validation."""
routes = {
"good": {
"secret": _INSECURE_NO_AUTH,
"deliver": "telegram",
"deliver_only": True,
"deliver_extra": {"chat_id": "c-1"},
"prompt": "hi",
}
}
adapter = _make_adapter(routes)
# connect() does more than validation (binds a socket) — we just
# want to verify the validation doesn't raise. Call it and tear
# down immediately.
try:
started = await adapter.connect()
if started:
await adapter.disconnect()
except ValueError:
pytest.fail("valid deliver_only config should not raise ValueError")
# ===================================================================
# Security + reliability invariants still hold
# ===================================================================
class TestDeliverOnlySecurityInvariants:
@pytest.mark.asyncio
async def test_hmac_still_enforced(self):
"""deliver_only does NOT bypass HMAC validation."""
secret = "real-secret-123"
routes = {
"r": {
"secret": secret,
"deliver": "telegram",
"deliver_only": True,
"deliver_extra": {"chat_id": "c-1"},
"prompt": "hi",
}
}
adapter = _make_adapter(routes)
mock_target = _wire_mock_target(adapter)
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
# No signature header → reject
resp = await cli.post(
"/webhooks/r",
json={},
headers={"X-GitHub-Delivery": "d-noauth-1"},
)
assert resp.status == 401
# Target never called
mock_target.send.assert_not_awaited()
@pytest.mark.asyncio
async def test_idempotency_still_applies(self):
"""Same delivery_id posted twice → second is suppressed."""
routes = {
"r": {
"secret": _INSECURE_NO_AUTH,
"deliver": "telegram",
"deliver_only": True,
"deliver_extra": {"chat_id": "c-1"},
"prompt": "hi",
}
}
adapter = _make_adapter(routes)
mock_target = _wire_mock_target(adapter)
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
r1 = await cli.post(
"/webhooks/r",
json={},
headers={"X-GitHub-Delivery": "dup-1"},
)
assert r1.status == 200
r2 = await cli.post(
"/webhooks/r",
json={},
headers={"X-GitHub-Delivery": "dup-1"},
)
# Existing webhook adapter treats duplicates as 200 + status=duplicate
assert r2.status == 200
data = await r2.json()
assert data["status"] == "duplicate"
# Target was called exactly once
assert mock_target.send.await_count == 1
@pytest.mark.asyncio
async def test_rate_limit_still_applies(self):
"""Route-level rate limit caps deliver_only POSTs too."""
routes = {
"r": {
"secret": _INSECURE_NO_AUTH,
"deliver": "telegram",
"deliver_only": True,
"deliver_extra": {"chat_id": "c-1"},
"prompt": "hi",
}
}
adapter = _make_adapter(routes, rate_limit=2)
_wire_mock_target(adapter)
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
for i in range(2):
r = await cli.post(
"/webhooks/r",
json={},
headers={"X-GitHub-Delivery": f"rl-{i}"},
)
assert r.status == 200
# Third within the window → 429
r3 = await cli.post(
"/webhooks/r",
json={},
headers={"X-GitHub-Delivery": "rl-3"},
)
assert r3.status == 429
# ===================================================================
# Unit: _direct_deliver dispatch
# ===================================================================
class TestDirectDeliverUnit:
@pytest.mark.asyncio
async def test_dispatches_to_cross_platform_for_messaging_targets(self):
adapter = _make_adapter({})
mock_target = _wire_mock_target(adapter, "telegram")
result = await adapter._direct_deliver(
"hello",
{"deliver": "telegram", "deliver_extra": {"chat_id": "c-1"}},
)
assert result.success is True
mock_target.send.assert_awaited_once_with(
"c-1", "hello", metadata=None
)
@pytest.mark.asyncio
async def test_dispatches_to_github_comment(self):
adapter = _make_adapter({})
with patch.object(
adapter, "_deliver_github_comment",
new=AsyncMock(return_value=SendResult(success=True)),
) as mock_gh:
result = await adapter._direct_deliver(
"review body",
{
"deliver": "github_comment",
"deliver_extra": {"repo": "org/r", "pr_number": "1"},
},
)
assert result.success is True
mock_gh.assert_awaited_once()

View file

@ -14,7 +14,6 @@ from hermes_cli.auth import (
PROVIDER_REGISTRY,
_read_codex_tokens,
_save_codex_tokens,
_write_codex_cli_tokens,
_import_codex_cli_tokens,
get_codex_auth_status,
get_provider_auth_state,
@ -182,98 +181,6 @@ def test_codex_tokens_not_written_to_shared_file(tmp_path, monkeypatch):
assert data["tokens"]["access_token"] == "hermes-at"
def test_write_codex_cli_tokens_creates_file(tmp_path, monkeypatch):
"""_write_codex_cli_tokens creates ~/.codex/auth.json with refreshed tokens."""
codex_home = tmp_path / "codex-cli"
monkeypatch.setenv("CODEX_HOME", str(codex_home))
_write_codex_cli_tokens("new-access", "new-refresh", last_refresh="2026-04-12T00:00:00Z")
auth_path = codex_home / "auth.json"
assert auth_path.exists()
data = json.loads(auth_path.read_text())
assert data["tokens"]["access_token"] == "new-access"
assert data["tokens"]["refresh_token"] == "new-refresh"
assert data["last_refresh"] == "2026-04-12T00:00:00Z"
# Verify file permissions are restricted
assert (auth_path.stat().st_mode & 0o777) == 0o600
def test_write_codex_cli_tokens_preserves_existing(tmp_path, monkeypatch):
"""_write_codex_cli_tokens preserves extra fields in existing auth.json."""
codex_home = tmp_path / "codex-cli"
codex_home.mkdir(parents=True, exist_ok=True)
monkeypatch.setenv("CODEX_HOME", str(codex_home))
existing = {
"tokens": {
"access_token": "old-access",
"refresh_token": "old-refresh",
"extra_field": "preserved",
},
"last_refresh": "2026-01-01T00:00:00Z",
"custom_key": "keep_me",
}
(codex_home / "auth.json").write_text(json.dumps(existing))
_write_codex_cli_tokens("updated-access", "updated-refresh")
data = json.loads((codex_home / "auth.json").read_text())
assert data["tokens"]["access_token"] == "updated-access"
assert data["tokens"]["refresh_token"] == "updated-refresh"
assert data["tokens"]["extra_field"] == "preserved"
assert data["custom_key"] == "keep_me"
# last_refresh not updated since we didn't pass it
assert data["last_refresh"] == "2026-01-01T00:00:00Z"
def test_write_codex_cli_tokens_handles_missing_dir(tmp_path, monkeypatch):
"""_write_codex_cli_tokens creates parent directories if missing."""
codex_home = tmp_path / "does" / "not" / "exist"
monkeypatch.setenv("CODEX_HOME", str(codex_home))
_write_codex_cli_tokens("at", "rt")
assert (codex_home / "auth.json").exists()
data = json.loads((codex_home / "auth.json").read_text())
assert data["tokens"]["access_token"] == "at"
def test_refresh_codex_auth_tokens_writes_back_to_cli(tmp_path, monkeypatch):
"""After refreshing, _refresh_codex_auth_tokens writes back to ~/.codex/auth.json."""
from hermes_cli.auth import _refresh_codex_auth_tokens
hermes_home = tmp_path / "hermes"
codex_home = tmp_path / "codex-cli"
hermes_home.mkdir(parents=True, exist_ok=True)
codex_home.mkdir(parents=True, exist_ok=True)
(hermes_home / "auth.json").write_text(json.dumps({"version": 1, "providers": {}}))
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
monkeypatch.setenv("CODEX_HOME", str(codex_home))
# Write initial CLI tokens
(codex_home / "auth.json").write_text(json.dumps({
"tokens": {"access_token": "old-at", "refresh_token": "old-rt"},
}))
# Mock the pure refresh to return new tokens
monkeypatch.setattr("hermes_cli.auth.refresh_codex_oauth_pure", lambda *a, **kw: {
"access_token": "refreshed-at",
"refresh_token": "refreshed-rt",
"last_refresh": "2026-04-12T01:00:00Z",
})
_refresh_codex_auth_tokens(
{"access_token": "old-at", "refresh_token": "old-rt"},
timeout_seconds=10,
)
# Verify CLI file was updated
cli_data = json.loads((codex_home / "auth.json").read_text())
assert cli_data["tokens"]["access_token"] == "refreshed-at"
assert cli_data["tokens"]["refresh_token"] == "refreshed-rt"
def test_resolve_returns_hermes_auth_store_source(tmp_path, monkeypatch):
hermes_home = tmp_path / "hermes"
_setup_hermes_auth(hermes_home)

View file

@ -124,29 +124,23 @@ class TestCmdUpdateBranchFallback:
if call.args and call.args[0][0] == "/usr/bin/npm"
]
# cmd_update runs npm commands in three locations:
# 1. repo root — slash-command / TUI bridge deps
# 2. ui-tui/ — Ink TUI deps
# 3. web/ — install + "npm run build" for the web frontend
full_flags = [
"/usr/bin/npm",
"install",
"--silent",
"--no-fund",
"--no-audit",
"--progress=false",
]
assert npm_calls == [
(
[
"/usr/bin/npm",
"install",
"--silent",
"--no-fund",
"--no-audit",
"--progress=false",
],
PROJECT_ROOT,
),
(
[
"/usr/bin/npm",
"install",
"--silent",
"--no-fund",
"--no-audit",
"--progress=false",
],
PROJECT_ROOT / "ui-tui",
),
(full_flags, PROJECT_ROOT),
(full_flags, PROJECT_ROOT / "ui-tui"),
(["/usr/bin/npm", "install", "--silent"], PROJECT_ROOT / "web"),
(["/usr/bin/npm", "run", "build"], PROJECT_ROOT / "web"),
]
def test_update_non_interactive_skips_migration_prompt(self, mock_args, capsys):

View file

@ -459,7 +459,7 @@ class TestCustomProviderCompatibility:
migrate_config(interactive=False, quiet=True)
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
assert raw["_config_version"] == 18
assert raw["_config_version"] == 19
assert raw["providers"]["openai-direct"] == {
"api": "https://api.openai.com/v1",
"api_key": "test-key",
@ -606,7 +606,7 @@ class TestInterimAssistantMessageConfig:
migrate_config(interactive=False, quiet=True)
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
assert raw["_config_version"] == 18
assert raw["_config_version"] == 19
assert raw["display"]["tool_progress"] == "off"
assert raw["display"]["interim_assistant_messages"] is True
@ -626,6 +626,6 @@ class TestDiscordChannelPromptsConfig:
migrate_config(interactive=False, quiet=True)
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
assert raw["_config_version"] == 18
assert raw["_config_version"] == 19
assert raw["discord"]["auto_thread"] is True
assert raw["discord"]["channel_prompts"] == {}

View file

@ -54,12 +54,12 @@ class TestCronCommandLifecycle:
deliver=None,
repeat=None,
skill=None,
skills=["find-nearby", "blogwatcher"],
skills=["maps", "blogwatcher"],
clear_skills=False,
)
)
updated = get_job(job["id"])
assert updated["skills"] == ["find-nearby", "blogwatcher"]
assert updated["skills"] == ["maps", "blogwatcher"]
assert updated["name"] == "Edited Job"
assert updated["prompt"] == "Revised prompt"
assert updated["schedule_display"] == "every 120m"
@ -95,7 +95,7 @@ class TestCronCommandLifecycle:
deliver=None,
repeat=None,
skill=None,
skills=["blogwatcher", "find-nearby"],
skills=["blogwatcher", "maps"],
)
)
out = capsys.readouterr().out
@ -103,5 +103,5 @@ class TestCronCommandLifecycle:
jobs = list_jobs()
assert len(jobs) == 1
assert jobs[0]["skills"] == ["blogwatcher", "find-nearby"]
assert jobs[0]["skills"] == ["blogwatcher", "maps"]
assert jobs[0]["name"] == "Skill combo"

View file

@ -130,7 +130,7 @@ class TestGeminiModelCatalog:
models = _PROVIDER_MODELS["gemini"]
assert "gemini-2.5-pro" in models
assert "gemini-2.5-flash" in models
assert "gemma-4-31b-it" in models
assert "gemma-4-31b-it" not in models
def test_provider_models_has_3x(self):
models = _PROVIDER_MODELS["gemini"]
@ -207,6 +207,37 @@ class TestGeminiAgentInit:
assert agent.api_mode == "chat_completions"
assert agent.provider == "gemini"
def test_gemini_uses_bearer_auth(self, monkeypatch):
"""Gemini OpenAI-compatible endpoint should receive the real API key."""
monkeypatch.setenv("GOOGLE_API_KEY", "AIzaSy_REAL_KEY")
real_key = "AIzaSy_REAL_KEY"
with patch("run_agent.OpenAI") as mock_openai:
mock_openai.return_value = MagicMock()
from run_agent import AIAgent
AIAgent(
model="gemini-2.5-flash",
provider="gemini",
api_key=real_key,
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
)
call_kwargs = mock_openai.call_args[1]
assert call_kwargs.get("api_key") == real_key
headers = call_kwargs.get("default_headers", {})
assert "x-goog-api-key" not in headers
def test_gemini_resolve_provider_client_auth(self, monkeypatch):
"""resolve_provider_client('gemini') should pass the real API key through."""
monkeypatch.setenv("GEMINI_API_KEY", "AIzaSy_TEST_KEY")
real_key = "AIzaSy_TEST_KEY"
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
mock_openai.return_value = MagicMock()
from agent.auxiliary_client import resolve_provider_client
resolve_provider_client("gemini")
call_kwargs = mock_openai.call_args[1]
assert call_kwargs.get("api_key") == real_key
headers = call_kwargs.get("default_headers", {})
assert "x-goog-api-key" not in headers
# ── models.dev Integration ──
@ -261,9 +292,32 @@ class TestGeminiModelsDev:
result = list_agentic_models("gemini")
assert "gemini-3-flash-preview" in result
assert "gemini-2.5-pro" in result
assert "gemma-4-31b-it" in result
assert "gemma-4-31b-it" not in result
# Filtered out:
assert "gemini-embedding-001" not in result # no tool_call
assert "gemini-2.5-flash-preview-tts" not in result # no tool_call
assert "gemini-live-2.5-flash" not in result # noise: live-
assert "gemini-2.5-flash-preview-04-17" not in result # noise: dated preview
def test_list_provider_models_hides_low_tpm_google_gemmas(self):
mock_data = {
"google": {
"models": {
"gemini-2.5-pro": {},
"gemma-4-31b-it": {},
"gemma-3-27b-it": {},
"gemini-1.5-pro": {},
"gemini-2.0-flash": {},
}
}
}
with patch("agent.models_dev.fetch_models_dev", return_value=mock_data):
from agent.models_dev import list_provider_models
result = list_provider_models("gemini")
assert "gemini-2.5-pro" in result
assert "gemma-4-31b-it" not in result
assert "gemma-3-27b-it" not in result
assert "gemini-1.5-pro" not in result
assert "gemini-2.0-flash" not in result

View file

@ -450,9 +450,9 @@ class TestValidateApiNotFound:
assert result["recognized"] is True
def test_dissimilar_model_shows_suggestions_not_autocorrect(self):
"""Models too different for auto-correction still get suggestions."""
"""Models too different for auto-correction are rejected with suggestions."""
result = _validate("anthropic/claude-nonexistent")
assert result["accepted"] is True
assert result["accepted"] is False
assert result.get("corrected_model") is None
assert "not found" in result["message"]
@ -532,11 +532,11 @@ class TestValidateCodexAutoCorrection:
assert result["message"] is None
def test_very_different_name_falls_to_suggestions(self):
"""Names too different for auto-correction get the suggestion list."""
"""Names too different for auto-correction are rejected with a suggestion list."""
codex_models = ["gpt-5.4-mini", "gpt-5.4", "gpt-5.3-codex"]
with patch("hermes_cli.models.provider_model_ids", return_value=codex_models):
result = validate_requested_model("totally-wrong", "openai-codex")
assert result["accepted"] is True
assert result["accepted"] is False
assert result["recognized"] is False
assert result.get("corrected_model") is None
assert "not found" in result["message"]

View file

@ -0,0 +1,29 @@
"""Tests for agent-settings copy in the interactive setup wizard."""
from hermes_cli.setup import setup_agent_settings
def test_setup_agent_settings_uses_displayed_max_iterations_value(tmp_path, monkeypatch, capsys):
"""The helper text should match the value shown in the prompt."""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
config = {
"agent": {"max_turns": 90},
"display": {"tool_progress": "all"},
"compression": {"threshold": 0.50},
"session_reset": {"mode": "both", "idle_minutes": 1440, "at_hour": 4},
}
prompt_answers = iter(["60", "all", "0.5"])
monkeypatch.setattr("hermes_cli.setup.get_env_value", lambda key: "60" if key == "HERMES_MAX_ITERATIONS" else "")
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: next(prompt_answers))
monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: 4)
monkeypatch.setattr("hermes_cli.setup.save_env_value", lambda *args, **kwargs: None)
monkeypatch.setattr("hermes_cli.setup.save_config", lambda *args, **kwargs: None)
setup_agent_settings(config)
out = capsys.readouterr().out
assert "Press Enter to keep 60." in out
assert "Default is 90" not in out

View file

@ -0,0 +1,325 @@
"""Tests for SIGHUP protection and stdout mirroring in ``hermes update``.
Covers ``_UpdateOutputStream``, ``_install_hangup_protection``, and
``_finalize_update_output`` in ``hermes_cli/main.py``. These exist so
that ``hermes update`` survives a terminal disconnect mid-install
(SSH drop, shell close) without leaving the venv half-installed.
"""
from __future__ import annotations
import io
import os
import signal
import sys
from pathlib import Path
from unittest.mock import patch
import pytest
from hermes_cli.main import (
_UpdateOutputStream,
_finalize_update_output,
_install_hangup_protection,
)
# -----------------------------------------------------------------------------
# _UpdateOutputStream
# -----------------------------------------------------------------------------
class TestUpdateOutputStream:
def test_write_mirrors_to_both_original_and_log(self):
original = io.StringIO()
log = io.StringIO()
stream = _UpdateOutputStream(original, log)
stream.write("hello world\n")
assert original.getvalue() == "hello world\n"
assert log.getvalue() == "hello world\n"
def test_write_continues_after_broken_original(self):
"""When the terminal disconnects, original.write raises BrokenPipeError.
The wrapper must catch it, flip the broken flag, and keep writing to
the log from then on.
"""
log = io.StringIO()
class _BrokenStream:
def write(self, data):
raise BrokenPipeError("terminal gone")
def flush(self):
raise BrokenPipeError("terminal gone")
stream = _UpdateOutputStream(_BrokenStream(), log)
# First write triggers the broken-pipe path.
stream.write("first line\n")
# Subsequent writes take the fast broken path (no exception).
stream.write("second line\n")
assert log.getvalue() == "first line\nsecond line\n"
assert stream._original_broken is True
def test_write_tolerates_oserror_and_valueerror(self):
"""OSError (EIO) and ValueError (closed file) should also be absorbed."""
log = io.StringIO()
class _RaisingStream:
def __init__(self, exc):
self._exc = exc
def write(self, data):
raise self._exc
def flush(self):
raise self._exc
for exc in (OSError("EIO"), ValueError("closed file")):
stream = _UpdateOutputStream(_RaisingStream(exc), log)
stream.write("x\n")
assert stream._original_broken is True
def test_log_failure_does_not_abort_write(self):
"""Even if the log file write raises, the original write must still happen."""
class _BrokenLog:
def write(self, data):
raise OSError("disk full")
def flush(self):
raise OSError("disk full")
original = io.StringIO()
stream = _UpdateOutputStream(original, _BrokenLog())
stream.write("data\n")
assert original.getvalue() == "data\n"
def test_flush_tolerates_broken_original(self):
class _BrokenStream:
def write(self, data):
return len(data)
def flush(self):
raise BrokenPipeError("gone")
log = io.StringIO()
stream = _UpdateOutputStream(_BrokenStream(), log)
stream.flush() # must not raise
assert stream._original_broken is True
def test_isatty_delegates_to_original(self):
class _TtyStream:
def isatty(self):
return True
def write(self, data):
return len(data)
def flush(self):
return None
stream = _UpdateOutputStream(_TtyStream(), io.StringIO())
assert stream.isatty() is True
def test_isatty_returns_false_after_broken(self):
class _BrokenStream:
def isatty(self):
return True
def write(self, data):
raise BrokenPipeError()
def flush(self):
return None
stream = _UpdateOutputStream(_BrokenStream(), io.StringIO())
stream.write("x") # marks broken
assert stream.isatty() is False
def test_getattr_delegates_unknown_attrs(self):
class _StreamWithEncoding:
encoding = "utf-8"
def write(self, data):
return len(data)
def flush(self):
return None
stream = _UpdateOutputStream(_StreamWithEncoding(), io.StringIO())
assert stream.encoding == "utf-8"
# -----------------------------------------------------------------------------
# _install_hangup_protection
# -----------------------------------------------------------------------------
class TestInstallHangupProtection:
def test_gateway_mode_is_noop(self):
"""In gateway mode the process is already detached — don't touch stdio or signals."""
prev_out, prev_err = sys.stdout, sys.stderr
prev_sighup = signal.getsignal(signal.SIGHUP) if hasattr(signal, "SIGHUP") else None
state = _install_hangup_protection(gateway_mode=True)
try:
assert sys.stdout is prev_out
assert sys.stderr is prev_err
assert state["log_file"] is None
assert state["installed"] is False
if hasattr(signal, "SIGHUP"):
assert signal.getsignal(signal.SIGHUP) == prev_sighup
finally:
_finalize_update_output(state)
@pytest.mark.skipif(
not hasattr(signal, "SIGHUP"), reason="SIGHUP not available on this platform"
)
def test_installs_sighup_ignore(self, tmp_path, monkeypatch):
"""SIGHUP should be set to SIG_IGN so SSH disconnect doesn't kill the update."""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
# Clear cached get_hermes_home if present
import hermes_cli.config as _cfg
if hasattr(_cfg, "_HERMES_HOME_CACHE"):
_cfg._HERMES_HOME_CACHE = None # type: ignore[attr-defined]
original_handler = signal.getsignal(signal.SIGHUP)
state = _install_hangup_protection(gateway_mode=False)
try:
assert signal.getsignal(signal.SIGHUP) == signal.SIG_IGN
finally:
_finalize_update_output(state)
# Restore whatever was there before so we don't leak to other tests.
signal.signal(signal.SIGHUP, original_handler)
def test_wraps_stdout_and_stderr_with_mirror(self, tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
# Nuke any cached home path
import hermes_cli.config as _cfg
if hasattr(_cfg, "_HERMES_HOME_CACHE"):
_cfg._HERMES_HOME_CACHE = None # type: ignore[attr-defined]
prev_out, prev_err = sys.stdout, sys.stderr
state = _install_hangup_protection(gateway_mode=False)
try:
# On Windows (no SIGHUP) we still wrap stdio and create the log.
assert state["installed"] is True
assert isinstance(sys.stdout, _UpdateOutputStream)
assert isinstance(sys.stderr, _UpdateOutputStream)
assert state["log_file"] is not None
sys.stdout.write("checking mirror\n")
sys.stdout.flush()
log_path = tmp_path / "logs" / "update.log"
assert log_path.exists()
contents = log_path.read_text(encoding="utf-8")
assert "checking mirror" in contents
assert "hermes update started" in contents
finally:
_finalize_update_output(state)
# Sanity-check restoration
assert sys.stdout is prev_out
assert sys.stderr is prev_err
def test_logs_dir_created_if_missing(self, tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
import hermes_cli.config as _cfg
if hasattr(_cfg, "_HERMES_HOME_CACHE"):
_cfg._HERMES_HOME_CACHE = None # type: ignore[attr-defined]
# No logs/ dir yet.
assert not (tmp_path / "logs").exists()
state = _install_hangup_protection(gateway_mode=False)
try:
assert (tmp_path / "logs").is_dir()
assert (tmp_path / "logs" / "update.log").exists()
finally:
_finalize_update_output(state)
def test_non_fatal_if_log_setup_fails(self, monkeypatch):
"""If get_hermes_home() raises, stdio must be left untouched but SIGHUP still handled."""
prev_out, prev_err = sys.stdout, sys.stderr
def _boom():
raise RuntimeError("no home for you")
# Patch the import inside _install_hangup_protection.
monkeypatch.setattr(
"hermes_cli.config.get_hermes_home", _boom, raising=True
)
original_handler = (
signal.getsignal(signal.SIGHUP) if hasattr(signal, "SIGHUP") else None
)
state = _install_hangup_protection(gateway_mode=False)
try:
assert sys.stdout is prev_out
assert sys.stderr is prev_err
assert state["installed"] is False
# SIGHUP must still be installed even when log setup fails.
if hasattr(signal, "SIGHUP"):
assert signal.getsignal(signal.SIGHUP) == signal.SIG_IGN
finally:
_finalize_update_output(state)
if hasattr(signal, "SIGHUP") and original_handler is not None:
signal.signal(signal.SIGHUP, original_handler)
# -----------------------------------------------------------------------------
# _finalize_update_output
# -----------------------------------------------------------------------------
class TestFinalizeUpdateOutput:
def test_none_state_is_noop(self):
_finalize_update_output(None) # must not raise
def test_restores_streams_and_closes_log(self, tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
import hermes_cli.config as _cfg
if hasattr(_cfg, "_HERMES_HOME_CACHE"):
_cfg._HERMES_HOME_CACHE = None # type: ignore[attr-defined]
prev_out = sys.stdout
state = _install_hangup_protection(gateway_mode=False)
log_file = state["log_file"]
assert sys.stdout is not prev_out
assert log_file is not None
_finalize_update_output(state)
assert sys.stdout is prev_out
# The log file handle should be closed.
assert log_file.closed is True
def test_skipped_install_leaves_stdio_alone(self):
"""When install failed (state['installed']=False) finalize should not
touch sys.stdout / sys.stderr (they were never wrapped)."""
# Build a synthetic state that mimics a failed install.
sentinel_out = object()
state = {
"prev_stdout": sentinel_out,
"prev_stderr": sentinel_out,
"log_file": None,
"installed": False,
}
before_out, before_err = sys.stdout, sys.stderr
_finalize_update_output(state)
assert sys.stdout is before_out
assert sys.stderr is before_err

View file

@ -460,10 +460,3 @@ class TestPrefetchCacheAccessors:
assert mgr.pop_context_result("cli:test") == payload
assert mgr.pop_context_result("cli:test") == {}
def test_set_and_pop_dialectic_result(self):
mgr = _make_manager(write_frequency="turn")
mgr.set_dialectic_result("cli:test", "Resume with toolset cleanup")
assert mgr.pop_dialectic_result("cli:test") == "Resume with toolset cleanup"
assert mgr.pop_dialectic_result("cli:test") == ""

View file

@ -26,6 +26,9 @@ class TestCmdStatus:
write_frequency = "async"
session_strategy = "per-session"
context_tokens = 800
dialectic_reasoning_level = "low"
reasoning_level_cap = "high"
reasoning_heuristic = True
def resolve_session_name(self):
return "hermes"

View file

@ -568,15 +568,15 @@ class TestToolsModeInitBehavior:
with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=mock_manager), \
patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=mock_manager) as mock_manager_cls, \
patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
provider.initialize(session_id="test-session-001", **init_kwargs)
return provider, cfg
return provider, cfg, mock_manager_cls
def test_tools_lazy_default(self):
"""tools + initOnSessionStart=false → session NOT initialized after initialize()."""
provider, _ = self._make_provider_with_config(
provider, _, _ = self._make_provider_with_config(
recall_mode="tools", init_on_session_start=False,
)
assert provider._session_initialized is False
@ -585,7 +585,7 @@ class TestToolsModeInitBehavior:
def test_tools_eager_init(self):
"""tools + initOnSessionStart=true → session IS initialized after initialize()."""
provider, _ = self._make_provider_with_config(
provider, _, _ = self._make_provider_with_config(
recall_mode="tools", init_on_session_start=True,
)
assert provider._session_initialized is True
@ -593,33 +593,34 @@ class TestToolsModeInitBehavior:
def test_tools_eager_prefetch_still_empty(self):
"""tools mode with eager init still returns empty from prefetch() (no auto-injection)."""
provider, _ = self._make_provider_with_config(
provider, _, _ = self._make_provider_with_config(
recall_mode="tools", init_on_session_start=True,
)
assert provider.prefetch("test query") == ""
def test_tools_lazy_prefetch_empty(self):
"""tools mode with lazy init also returns empty from prefetch()."""
provider, _ = self._make_provider_with_config(
provider, _, _ = self._make_provider_with_config(
recall_mode="tools", init_on_session_start=False,
)
assert provider.prefetch("test query") == ""
def test_explicit_peer_name_not_overridden_by_user_id(self):
"""Explicit peerName in config must not be replaced by gateway user_id."""
_, cfg = self._make_provider_with_config(
_, cfg, _ = self._make_provider_with_config(
recall_mode="tools", init_on_session_start=True,
peer_name="Kathie", user_id="8439114563",
)
assert cfg.peer_name == "Kathie"
def test_user_id_used_when_no_peer_name(self):
"""Gateway user_id is used as peer_name when no explicit peerName configured."""
_, cfg = self._make_provider_with_config(
"""Gateway user_id is passed separately from config peer_name."""
_, cfg, mock_manager_cls = self._make_provider_with_config(
recall_mode="tools", init_on_session_start=True,
peer_name=None, user_id="8439114563",
)
assert cfg.peer_name == "8439114563"
assert cfg.peer_name is None
assert mock_manager_cls.call_args.kwargs["runtime_user_peer_name"] == "8439114563"
class TestPerSessionMigrateGuard:
@ -815,6 +816,27 @@ class TestDialecticInputGuard:
# ---------------------------------------------------------------------------
def _settle_prewarm(provider):
"""Wait for the session-start prewarm dialectic thread, then return the
provider to a clean 'nothing fired yet' state so cadence/first-turn/
trivial-prompt tests can assert from a known baseline."""
if provider._prefetch_thread:
provider._prefetch_thread.join(timeout=3.0)
with provider._prefetch_lock:
provider._prefetch_result = ""
provider._prefetch_result_fired_at = -999
provider._prefetch_thread = None
provider._prefetch_thread_started_at = 0.0
provider._last_dialectic_turn = -999
provider._dialectic_empty_streak = 0
if getattr(provider, "_manager", None) is not None:
try:
provider._manager.dialectic_query.reset_mock()
provider._manager.prefetch_context.reset_mock()
except AttributeError:
pass
class TestDialecticCadenceDefaults:
"""Regression tests for dialectic_cadence default value."""
@ -840,12 +862,15 @@ class TestDialecticCadenceDefaults:
patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
provider.initialize(session_id="test-session-001")
_settle_prewarm(provider)
return provider
def test_default_is_3(self):
"""Default dialectic_cadence should be 3 to avoid per-turn LLM calls."""
def test_unset_falls_back_to_1(self):
"""Unset dialecticCadence falls back to 1 (every turn) for backwards
compatibility with existing configs that predate the setting. The
setup wizard writes 2 explicitly on new configs."""
provider = self._make_provider()
assert provider._dialectic_cadence == 3
assert provider._dialectic_cadence == 1
def test_config_override(self):
"""dialecticCadence from config overrides the default."""
@ -908,6 +933,7 @@ class TestDialecticDepth:
patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
provider.initialize(session_id="test-session-001")
_settle_prewarm(provider)
return provider
def test_default_depth_is_1(self):
@ -1027,46 +1053,6 @@ class TestDialecticDepth:
assert provider._manager.dialectic_query.call_count == 2
assert "Synthesis" in result
def test_first_turn_runs_dialectic_synchronously(self):
"""First turn should fire the dialectic synchronously (cold start)."""
from unittest.mock import MagicMock, patch
provider = self._make_provider(cfg_extra={"dialectic_depth": 1})
provider._manager = MagicMock()
provider._manager.dialectic_query.return_value = "cold start synthesis"
provider._manager.get_prefetch_context.return_value = None
provider._manager.pop_context_result.return_value = None
provider._session_key = "test"
provider._base_context_cache = "" # cold start
provider._last_dialectic_turn = -999 # never fired
result = provider.prefetch("hello world")
assert "cold start synthesis" in result
assert provider._manager.dialectic_query.call_count == 1
# After first-turn sync, _last_dialectic_turn should be updated
assert provider._last_dialectic_turn != -999
def test_first_turn_dialectic_does_not_double_fire(self):
"""After first-turn sync dialectic, queue_prefetch should skip (cadence)."""
from unittest.mock import MagicMock
provider = self._make_provider(cfg_extra={"dialectic_depth": 1})
provider._manager = MagicMock()
provider._manager.dialectic_query.return_value = "cold start synthesis"
provider._manager.get_prefetch_context.return_value = None
provider._manager.pop_context_result.return_value = None
provider._session_key = "test"
provider._base_context_cache = ""
provider._last_dialectic_turn = -999
provider._turn_count = 0
# First turn fires sync dialectic
provider.prefetch("hello")
assert provider._manager.dialectic_query.call_count == 1
# Now queue_prefetch on same turn should skip (cadence: 0 - 0 < 3)
provider._manager.dialectic_query.reset_mock()
provider.queue_prefetch("hello")
assert provider._manager.dialectic_query.call_count == 0
def test_run_dialectic_depth_bails_early_on_strong_signal(self):
"""Depth 2 skips pass 1 when pass 0 returns strong signal."""
from unittest.mock import MagicMock
@ -1083,6 +1069,584 @@ class TestDialecticDepth:
assert provider._manager.dialectic_query.call_count == 1
# ---------------------------------------------------------------------------
# Trivial-prompt heuristic + dialectic cadence silent-failure guards
# ---------------------------------------------------------------------------
class TestTrivialPromptHeuristic:
"""Trivial prompts ('ok', 'y', slash commands) must short-circuit injection."""
@staticmethod
def _make_provider():
from unittest.mock import patch, MagicMock
from plugins.memory.honcho.client import HonchoClientConfig
cfg = HonchoClientConfig(api_key="test-key", enabled=True, recall_mode="hybrid")
provider = HonchoMemoryProvider()
mock_manager = MagicMock()
mock_session = MagicMock()
mock_session.messages = []
mock_manager.get_or_create.return_value = mock_session
with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=mock_manager), \
patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
provider.initialize(session_id="test-session-trivial")
_settle_prewarm(provider)
return provider
def test_classifier_catches_common_trivial_forms(self):
for t in ("ok", "OK", " ok ", "y", "yes", "sure", "thanks", "lgtm", "/help", "", " "):
assert HonchoMemoryProvider._is_trivial_prompt(t), f"expected trivial: {t!r}"
def test_classifier_lets_substantive_prompts_through(self):
for t in ("hello world", "what's my name", "explain this", "ok so what's next"):
assert not HonchoMemoryProvider._is_trivial_prompt(t), f"expected non-trivial: {t!r}"
def test_prefetch_skips_on_trivial_prompt(self):
provider = self._make_provider()
provider._session_key = "test"
provider._base_context_cache = "cached base"
provider._last_dialectic_turn = 0
provider._turn_count = 5
assert provider.prefetch("ok") == ""
assert provider.prefetch("/help") == ""
# Dialectic should not have fired
assert provider._manager.dialectic_query.call_count == 0
def test_queue_prefetch_skips_on_trivial_prompt(self):
provider = self._make_provider()
provider._session_key = "test"
provider._turn_count = 10
provider._last_dialectic_turn = -999 # would otherwise fire
# initialize() pre-warms; clear call counts before the assertion.
provider._manager.prefetch_context.reset_mock()
provider._manager.dialectic_query.reset_mock()
provider.queue_prefetch("y")
# Trivial prompts short-circuit both context refresh and dialectic fire.
assert provider._manager.prefetch_context.call_count == 0
assert provider._manager.dialectic_query.call_count == 0
class TestDialecticCadenceAdvancesOnSuccess:
"""Cadence tracker advances only when the dialectic call returns a
non-empty result. Empty results (transient API error, sparse representation)
must retry on the next eligible turn instead of waiting the full cadence."""
@staticmethod
def _make_provider():
from unittest.mock import patch, MagicMock
from plugins.memory.honcho.client import HonchoClientConfig
cfg = HonchoClientConfig(
api_key="test-key", enabled=True, recall_mode="hybrid", dialectic_depth=1,
)
provider = HonchoMemoryProvider()
mock_manager = MagicMock()
mock_session = MagicMock()
mock_session.messages = []
mock_manager.get_or_create.return_value = mock_session
with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=mock_manager), \
patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
provider.initialize(session_id="test-session-retry")
_settle_prewarm(provider)
return provider
def test_empty_dialectic_result_does_not_advance_cadence(self):
import time as _time
provider = self._make_provider()
provider._session_key = "test"
provider._manager.dialectic_query.return_value = "" # silent failure
provider._turn_count = 5
provider._last_dialectic_turn = 0 # would fire (5 - 0 = 5 ≥ 3)
provider.queue_prefetch("hello")
# wait for the background thread to settle
if provider._prefetch_thread:
provider._prefetch_thread.join(timeout=2.0)
# Dialectic call was attempted
assert provider._manager.dialectic_query.call_count == 1
# But cadence tracker did NOT advance — next turn should retry
assert provider._last_dialectic_turn == 0
def test_non_empty_dialectic_result_advances_cadence(self):
provider = self._make_provider()
provider._session_key = "test"
provider._manager.dialectic_query.return_value = "real synthesis output"
provider._turn_count = 5
provider._last_dialectic_turn = 0
provider.queue_prefetch("hello")
if provider._prefetch_thread:
provider._prefetch_thread.join(timeout=2.0)
assert provider._last_dialectic_turn == 5
def test_in_flight_thread_is_not_stacked(self):
import threading as _threading
import time as _time
provider = self._make_provider()
provider._session_key = "test"
provider._turn_count = 10
provider._last_dialectic_turn = 0
# Simulate a prior thread still running (fresh, not stale)
hold = _threading.Event()
def _block():
hold.wait(timeout=5.0)
fresh = _threading.Thread(target=_block, daemon=True)
fresh.start()
provider._prefetch_thread = fresh
provider._prefetch_thread_started_at = _time.monotonic() # fresh start
provider.queue_prefetch("hello")
# Should have short-circuited — no new dialectic call
assert provider._manager.dialectic_query.call_count == 0
hold.set()
fresh.join(timeout=2.0)
class TestSessionStartDialecticPrewarm:
"""Session-start prewarm fires a depth-aware dialectic whose result is
consumed by turn 1 no duplicate .chat() and no dead-cache orphaning."""
@staticmethod
def _make_provider(cfg_extra=None, dialectic_result="prewarm synthesis"):
from unittest.mock import patch, MagicMock
from plugins.memory.honcho.client import HonchoClientConfig
defaults = dict(api_key="test-key", enabled=True, recall_mode="hybrid")
if cfg_extra:
defaults.update(cfg_extra)
cfg = HonchoClientConfig(**defaults)
provider = HonchoMemoryProvider()
mock_manager = MagicMock()
mock_manager.get_or_create.return_value = MagicMock(messages=[])
mock_manager.get_prefetch_context.return_value = None
mock_manager.pop_context_result.return_value = None
mock_manager.dialectic_query.return_value = dialectic_result
with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=mock_manager), \
patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
provider.initialize(session_id="test-prewarm")
return provider
def test_prewarm_populates_prefetch_result(self):
p = self._make_provider()
# Wait for prewarm thread to land
if p._prefetch_thread:
p._prefetch_thread.join(timeout=3.0)
with p._prefetch_lock:
assert p._prefetch_result == "prewarm synthesis"
assert p._last_dialectic_turn == 0
def test_turn1_consumes_prewarm_without_duplicate_dialectic(self):
"""With prewarm result already in _prefetch_result, turn 1 prefetch
should NOT fire another dialectic."""
p = self._make_provider()
if p._prefetch_thread:
p._prefetch_thread.join(timeout=3.0)
p._manager.dialectic_query.reset_mock()
p._session_key = "test-prewarm"
p._base_context_cache = ""
p._turn_count = 1
result = p.prefetch("hello world")
assert "prewarm synthesis" in result
# The sync first-turn path must NOT have fired another .chat()
assert p._manager.dialectic_query.call_count == 0
def test_turn1_falls_back_to_sync_when_prewarm_missing(self):
"""If the prewarm produced nothing (empty graph, API blip), turn 1
still fires its own sync dialectic."""
p = self._make_provider(dialectic_result="") # prewarm returns empty
if p._prefetch_thread:
p._prefetch_thread.join(timeout=3.0)
with p._prefetch_lock:
assert p._prefetch_result == "" # prewarm landed nothing
# Switch dialectic_query to return something on the sync first-turn call
p._manager.dialectic_query.return_value = "sync recovery"
p._manager.dialectic_query.reset_mock()
p._session_key = "test-prewarm"
p._base_context_cache = ""
p._turn_count = 1
result = p.prefetch("hello world")
assert "sync recovery" in result
assert p._manager.dialectic_query.call_count == 1
class TestDialecticLiveness:
"""Liveness + observability: stale-thread recovery, stale-result discard,
empty-streak backoff, and the snapshot method used for diagnostics."""
@staticmethod
def _make_provider(cfg_extra=None):
from unittest.mock import patch, MagicMock
from plugins.memory.honcho.client import HonchoClientConfig
defaults = dict(api_key="test-key", enabled=True, recall_mode="hybrid", timeout=2.0)
if cfg_extra:
defaults.update(cfg_extra)
cfg = HonchoClientConfig(**defaults)
provider = HonchoMemoryProvider()
mock_manager = MagicMock()
mock_manager.get_or_create.return_value = MagicMock(messages=[])
mock_manager.get_prefetch_context.return_value = None
mock_manager.pop_context_result.return_value = None
mock_manager.dialectic_query.return_value = "" # default: silent
with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=mock_manager), \
patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
provider.initialize(session_id="test-liveness")
_settle_prewarm(provider)
return provider
def test_stale_thread_is_treated_as_dead(self):
"""A thread older than timeout × multiplier no longer blocks new fires."""
import threading as _threading
p = self._make_provider()
p._session_key = "test"
p._turn_count = 10
p._last_dialectic_turn = 0
p._manager.dialectic_query.return_value = "fresh synthesis"
# Plant an alive thread with an old timestamp (stale)
hold = _threading.Event()
stuck = _threading.Thread(target=lambda: hold.wait(timeout=10.0), daemon=True)
stuck.start()
p._prefetch_thread = stuck
# timeout=2.0, multiplier=2.0, so anything older than 4s is stale
p._prefetch_thread_started_at = 0.0 # very old (1970 monotonic baseline)
p.queue_prefetch("hello")
# New thread should have been spawned since stuck one is stale
assert p._prefetch_thread is not stuck, "stale thread must be recycled"
if p._prefetch_thread:
p._prefetch_thread.join(timeout=2.0)
assert p._manager.dialectic_query.call_count == 1
hold.set()
stuck.join(timeout=2.0)
def test_stale_pending_result_is_discarded_on_read(self):
"""A pending dialectic result from many turns ago is discarded
instead of injected against a fresh conversational pivot."""
p = self._make_provider(cfg_extra={"raw": {"dialecticCadence": 2}})
p._session_key = "test"
p._base_context_cache = "base ctx"
with p._prefetch_lock:
p._prefetch_result = "ancient synthesis"
p._prefetch_result_fired_at = 1
# cadence=2, multiplier=2 → stale after 4 turns since fire
p._turn_count = 10
p._last_dialectic_turn = 1 # prevents sync first-turn path
result = p.prefetch("what's new")
assert "ancient synthesis" not in result, "stale pending must be discarded"
# Cache slot cleared
with p._prefetch_lock:
assert p._prefetch_result == ""
assert p._prefetch_result_fired_at == -999
def test_fresh_pending_result_is_kept(self):
"""A pending result within the staleness window is injected normally."""
p = self._make_provider(cfg_extra={"raw": {"dialecticCadence": 3}})
p._session_key = "test"
p._base_context_cache = ""
with p._prefetch_lock:
p._prefetch_result = "recent synthesis"
p._prefetch_result_fired_at = 8
p._turn_count = 9 # 1 turn since fire, well within cadence × 2 = 6
p._last_dialectic_turn = 8
result = p.prefetch("what's new")
assert "recent synthesis" in result
def test_empty_streak_widens_effective_cadence(self):
"""After N empty returns, the gate waits cadence + N turns."""
p = self._make_provider(cfg_extra={"raw": {"dialecticCadence": 1}})
p._dialectic_empty_streak = 3
# cadence=1, streak=3 → effective = 4
assert p._effective_cadence() == 4
def test_backoff_is_capped(self):
"""Effective cadence is capped at cadence × _BACKOFF_MAX."""
p = self._make_provider(cfg_extra={"raw": {"dialecticCadence": 2}})
p._dialectic_empty_streak = 100
# cadence=2, ceiling = 2 × 8 = 16
assert p._effective_cadence() == 16
def test_success_resets_empty_streak(self):
"""A non-empty result zeroes the streak so healthy operation restores
the base cadence immediately."""
p = self._make_provider(cfg_extra={"raw": {"dialecticCadence": 1}})
p._session_key = "test"
p._dialectic_empty_streak = 5
p._turn_count = 10
p._last_dialectic_turn = 0
p._manager.dialectic_query.return_value = "real output"
p.queue_prefetch("hello")
if p._prefetch_thread:
p._prefetch_thread.join(timeout=2.0)
assert p._dialectic_empty_streak == 0
assert p._last_dialectic_turn == 10
def test_empty_result_increments_streak(self):
p = self._make_provider(cfg_extra={"raw": {"dialecticCadence": 1}})
p._session_key = "test"
p._turn_count = 5
p._last_dialectic_turn = 0
p._manager.dialectic_query.return_value = "" # empty
p.queue_prefetch("hello")
if p._prefetch_thread:
p._prefetch_thread.join(timeout=2.0)
assert p._dialectic_empty_streak == 1
assert p._last_dialectic_turn == 0 # cadence not advanced
def test_liveness_snapshot_shape(self):
p = self._make_provider()
snap = p.liveness_snapshot()
for key in (
"turn_count", "last_dialectic_turn", "pending_result_fired_at",
"empty_streak", "effective_cadence", "thread_alive", "thread_age_seconds",
):
assert key in snap
class TestDialecticLifecycleSmoke:
"""End-to-end smoke walking a multi-turn session through prewarm,
turn 1 consume, trivial skip, cadence fire, empty-result retry,
heuristic bump, and session-end flush."""
@staticmethod
def _make_provider(cfg_extra=None):
from unittest.mock import patch, MagicMock
from plugins.memory.honcho.client import HonchoClientConfig
defaults = dict(
api_key="test-key", enabled=True, recall_mode="hybrid",
dialectic_reasoning_level="low", reasoning_heuristic=True,
reasoning_level_cap="high", dialectic_depth=1,
)
if cfg_extra:
defaults.update(cfg_extra)
cfg = HonchoClientConfig(**defaults)
provider = HonchoMemoryProvider()
mock_manager = MagicMock()
mock_session = MagicMock()
mock_session.messages = []
mock_manager.get_or_create.return_value = mock_session
mock_manager.get_prefetch_context.return_value = None
mock_manager.pop_context_result.return_value = None
with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=mock_manager), \
patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
return provider, mock_manager, cfg
def _await_thread(self, provider):
if provider._prefetch_thread:
provider._prefetch_thread.join(timeout=3.0)
def test_full_multi_turn_session(self):
"""Walks init → turns 1..8 → session end. Asserts at every step that
the plugin did exactly what it should and nothing more.
Uses dialecticCadence=3 so we can exercise skip-turns between fires
and the silent-failure retry path without their gates tripping each
other. Trivial + slash skips apply independent of cadence.
"""
from unittest.mock import patch, MagicMock
provider, mgr, cfg = self._make_provider(
cfg_extra={"raw": {"dialecticCadence": 3}}
)
# Program the dialectic responses in the exact order they'll be requested.
# An extra or missing call fails the test — strong smoke signal.
responses = iter([
"prewarm: user is eri, works on hermes", # session-start prewarm
"cadence fire: long query synthesis", # turn 4 queue_prefetch
"", # turn 7 fire: silent failure
"retry success: fresh synthesis", # turn 8 queue_prefetch retry
])
mgr.dialectic_query.side_effect = lambda *a, **kw: next(responses)
# ---- init: prewarm fires ----
with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=mgr), \
patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
provider.initialize(session_id="smoke-test")
self._await_thread(provider)
with provider._prefetch_lock:
assert provider._prefetch_result.startswith("prewarm"), \
"session-start prewarm must land in _prefetch_result"
assert provider._last_dialectic_turn == 0, "prewarm marks turn 0"
assert mgr.dialectic_query.call_count == 1
# ---- turn 1: consume prewarm, no duplicate dialectic ----
provider.on_turn_start(1, "hey")
inject1 = provider.prefetch("hey")
assert "prewarm" in inject1, "turn 1 must surface prewarm"
provider.sync_turn("hey", "hi there")
provider.queue_prefetch("hey") # cadence gate: (1-0)<3 → skip
self._await_thread(provider)
assert mgr.dialectic_query.call_count == 1, \
"turn 1 must not fire — prewarm covered it and cadence skips"
# ---- turn 2: trivial 'ok' → skip everything ----
mgr.prefetch_context.reset_mock()
provider.on_turn_start(2, "ok")
assert provider.prefetch("ok") == "", "trivial prompt must short-circuit injection"
provider.sync_turn("ok", "cool")
provider.queue_prefetch("ok")
self._await_thread(provider)
assert mgr.dialectic_query.call_count == 1, "trivial must not fire dialectic"
assert mgr.prefetch_context.call_count == 0, "trivial must not fire context refresh"
# ---- turn 3: slash '/help' → also skip ----
provider.on_turn_start(3, "/help")
assert provider.prefetch("/help") == ""
provider.queue_prefetch("/help")
assert mgr.dialectic_query.call_count == 1
# ---- turn 4: long query → cadence fires + heuristic bumps ----
long_q = "walk me through " + ("x " * 100) # ~200 chars → heuristic +1
provider.on_turn_start(4, long_q)
provider.prefetch(long_q)
provider.sync_turn(long_q, "sure")
provider.queue_prefetch(long_q) # (4-0)≥3 → fires
self._await_thread(provider)
assert mgr.dialectic_query.call_count == 2, "turn 4 cadence fire"
_, kwargs = mgr.dialectic_query.call_args
assert kwargs.get("reasoning_level") in ("medium", "high"), \
f"long query must bump reasoning level above 'low'; got {kwargs.get('reasoning_level')}"
assert provider._last_dialectic_turn == 4, "cadence tracker advances on success"
# ---- turns 56: cadence cooldown, no fires ----
for t in (5, 6):
provider.on_turn_start(t, "tell me more")
provider.queue_prefetch("tell me more")
self._await_thread(provider)
assert mgr.dialectic_query.call_count == 2, "turns 56 blocked by cadence window"
# ---- turn 7: fires but silent failure (empty dialectic) ----
provider.on_turn_start(7, "and then what")
provider.queue_prefetch("and then what") # (7-4)≥3 → fires
self._await_thread(provider)
assert mgr.dialectic_query.call_count == 3, "turn 7 fires"
assert provider._last_dialectic_turn == 4, \
"silent failure must NOT burn the cadence window"
# ---- turn 8: retries because cadence didn't advance ----
provider.on_turn_start(8, "try again")
provider.queue_prefetch("try again") # (8-4)≥3 → fires again
self._await_thread(provider)
assert mgr.dialectic_query.call_count == 4, \
"turn 8 retries because turn 7's empty result didn't advance cadence"
assert provider._last_dialectic_turn == 8, "retry success advances"
# ---- session end: flush messages ----
provider.on_session_end([])
mgr.flush_all.assert_called()
class TestReasoningHeuristic:
"""Char-count heuristic that scales the auto-injected reasoning level by
query length, clamped at reasoning_level_cap."""
@staticmethod
def _make_provider(cfg_extra=None):
from unittest.mock import patch, MagicMock
from plugins.memory.honcho.client import HonchoClientConfig
defaults = dict(
api_key="test-key", enabled=True, recall_mode="hybrid",
dialectic_reasoning_level="low", reasoning_heuristic=True,
reasoning_level_cap="high",
)
if cfg_extra:
defaults.update(cfg_extra)
cfg = HonchoClientConfig(**defaults)
provider = HonchoMemoryProvider()
mock_manager = MagicMock()
mock_manager.get_or_create.return_value = MagicMock(messages=[])
with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=mock_manager), \
patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
provider.initialize(session_id="test-heuristic")
_settle_prewarm(provider)
return provider
def test_short_query_stays_at_base(self):
p = self._make_provider()
assert p._apply_reasoning_heuristic("low", "hey") == "low"
def test_medium_query_bumps_one_level(self):
p = self._make_provider()
q = "x" * 150
assert p._apply_reasoning_heuristic("low", q) == "medium"
def test_long_query_bumps_two_levels(self):
p = self._make_provider()
q = "x" * 500
assert p._apply_reasoning_heuristic("low", q) == "high"
def test_bump_respects_cap(self):
p = self._make_provider(cfg_extra={"reasoning_level_cap": "medium"})
q = "x" * 500 # would hit 'high' without the cap
assert p._apply_reasoning_heuristic("low", q) == "medium"
def test_max_never_auto_selected_with_default_cap(self):
p = self._make_provider(cfg_extra={"dialectic_reasoning_level": "high"})
q = "x" * 500 # base=high, bump would push to 'max'
assert p._apply_reasoning_heuristic("high", q) == "high"
def test_heuristic_disabled_returns_base(self):
p = self._make_provider(cfg_extra={"reasoning_heuristic": False})
q = "x" * 500
assert p._apply_reasoning_heuristic("low", q) == "low"
def test_resolve_pass_level_applies_heuristic_at_base_mapping(self):
"""Depth=1, pass 0 maps to 'base' → heuristic applies."""
p = self._make_provider()
q = "x" * 150
assert p._resolve_pass_level(0, query=q) == "medium"
def test_resolve_pass_level_does_not_touch_explicit_per_pass(self):
"""dialecticDepthLevels wins absolutely — no heuristic scaling."""
p = self._make_provider(cfg_extra={"dialectic_depth_levels": ["minimal"]})
q = "x" * 500 # heuristic would otherwise bump to 'high'
assert p._resolve_pass_level(0, query=q) == "minimal"
def test_resolve_pass_level_does_not_touch_lighter_passes(self):
"""Depth 3 pass 0 is hardcoded 'minimal' — heuristic must not bump it."""
p = self._make_provider(cfg_extra={"dialectic_depth": 3})
q = "x" * 500
assert p._resolve_pass_level(0, query=q) == "minimal"
# But the 'base' pass (idx 1 for depth 3) does get heuristic
assert p._resolve_pass_level(1, query=q) == "high"
# ---------------------------------------------------------------------------
# set_peer_card None guard
# ---------------------------------------------------------------------------

View file

@ -23,6 +23,10 @@ def _make_agent(monkeypatch):
class _Stub:
_interrupt_requested = False
_interrupt_message = None
# Bind to this thread's ident so interrupt() targets a real tid.
_execution_thread_id = threading.current_thread().ident
_interrupt_thread_signal_pending = False
log_prefix = ""
quiet_mode = True
verbose_logging = False
@ -40,6 +44,15 @@ def _make_agent(monkeypatch):
_current_tool = None
_last_activity = 0
_print_fn = print
# Worker-thread tracking state mirrored from AIAgent.__init__ so the
# real interrupt() method can fan out to concurrent-tool workers.
_active_children: list = []
def __init__(self):
# Instance-level (not class-level) so each test gets a fresh set.
self._tool_worker_threads: set = set()
self._tool_worker_threads_lock = threading.Lock()
self._active_children_lock = threading.Lock()
def _touch_activity(self, desc):
self._last_activity = time.time()
@ -60,8 +73,10 @@ def _make_agent(monkeypatch):
return False
stub = _Stub()
# Bind the real methods
# Bind the real methods under test
stub._execute_tool_calls_concurrent = _ra.AIAgent._execute_tool_calls_concurrent.__get__(stub)
stub.interrupt = _ra.AIAgent.interrupt.__get__(stub)
stub.clear_interrupt = _ra.AIAgent.clear_interrupt.__get__(stub)
stub._invoke_tool = MagicMock(side_effect=lambda *a, **kw: '{"ok": true}')
return stub
@ -137,3 +152,109 @@ def test_concurrent_preflight_interrupt_skips_all(monkeypatch):
assert "skipped due to user interrupt" in messages[1]["content"]
# _invoke_tool should never have been called
agent._invoke_tool.assert_not_called()
def test_running_concurrent_worker_sees_is_interrupted(monkeypatch):
"""Regression guard for the "interrupt-doesn't-reach-hung-tool" class of
bug Physikal reported in April 2026.
Before this fix, `AIAgent.interrupt()` called `_set_interrupt(True,
_execution_thread_id)` which only flagged the agent's *main* thread.
Tools running inside `_execute_tool_calls_concurrent` execute on
ThreadPoolExecutor worker threads whose tids are NOT the agent's, so
`is_interrupted()` (which checks the *current* thread's tid) returned
False inside those tools no matter how many times the gateway called
`.interrupt()`. Hung ssh / long curl / big make-build tools would run
to their own timeout.
This test runs a fake tool in the concurrent path that polls
`is_interrupted()` like a real terminal command does, then calls
`agent.interrupt()` from another thread, and asserts the poll sees True
within one second.
"""
from tools.interrupt import is_interrupted
agent = _make_agent(monkeypatch)
# Counter plus observation hooks so we can prove the worker saw the flip.
observed = {"saw_true": False, "poll_count": 0, "worker_tid": None}
worker_started = threading.Event()
def polling_tool(name, args, task_id, call_id=None):
observed["worker_tid"] = threading.current_thread().ident
worker_started.set()
deadline = time.monotonic() + 5.0
while time.monotonic() < deadline:
observed["poll_count"] += 1
if is_interrupted():
observed["saw_true"] = True
return '{"interrupted": true}'
time.sleep(0.05)
return '{"timed_out": true}'
agent._invoke_tool = MagicMock(side_effect=polling_tool)
tc1 = _FakeToolCall("hung_fake_tool_1", call_id="tc1")
tc2 = _FakeToolCall("hung_fake_tool_2", call_id="tc2")
msg = _FakeAssistantMsg([tc1, tc2])
messages = []
def _interrupt_after_start():
# Wait until at least one worker is running so its tid is tracked.
worker_started.wait(timeout=2.0)
time.sleep(0.2) # let the other worker enter too
agent.interrupt("stop requested by test")
t = threading.Thread(target=_interrupt_after_start)
t.start()
start = time.monotonic()
agent._execute_tool_calls_concurrent(msg, messages, "test_task")
elapsed = time.monotonic() - start
t.join(timeout=2.0)
# The worker must have actually polled is_interrupted — otherwise the
# test isn't exercising what it claims to.
assert observed["poll_count"] > 0, (
"polling_tool never ran — test scaffold issue"
)
# The worker must see the interrupt within ~1 s of agent.interrupt()
# being called. Before the fix this loop ran until its 5 s own-timeout.
assert observed["saw_true"], (
f"is_interrupted() never returned True inside the concurrent worker "
f"after agent.interrupt() — interrupt-propagation hole regressed. "
f"worker_tid={observed['worker_tid']!r} poll_count={observed['poll_count']}"
)
assert elapsed < 3.0, (
f"concurrent execution took {elapsed:.2f}s after interrupt — the fan-out "
f"to worker tids didn't shortcut the tool's poll loop as expected"
)
# Also verify cleanup: no stale worker tids should remain after all
# tools finished.
assert agent._tool_worker_threads == set(), (
f"worker tids leaked after run: {agent._tool_worker_threads}"
)
def test_clear_interrupt_clears_worker_tids(monkeypatch):
"""After clear_interrupt(), stale worker-tid bits must be cleared so the
next turn's tools — which may be scheduled onto recycled tids — don't
see a false interrupt."""
from tools.interrupt import is_interrupted, set_interrupt
agent = _make_agent(monkeypatch)
# Simulate a worker having registered but not yet exited cleanly (e.g. a
# hypothetical bug in the tear-down). Put a fake tid in the set and
# flag it interrupted.
fake_tid = threading.current_thread().ident # use real tid so is_interrupted can see it
with agent._tool_worker_threads_lock:
agent._tool_worker_threads.add(fake_tid)
set_interrupt(True, fake_tid)
assert is_interrupted() is True # sanity
agent.clear_interrupt()
assert is_interrupted() is False, (
"clear_interrupt() did not clear the interrupt bit for a tracked "
"worker tid — stale interrupt can leak into the next turn"
)

View file

@ -0,0 +1,39 @@
"""Regression tests for memory provider selection during AIAgent init."""
from types import SimpleNamespace
from unittest.mock import patch
def test_blank_memory_provider_does_not_auto_enable_honcho():
"""Blank memory.provider should remain opt-out even if Honcho fallback looks configured."""
cfg = {"memory": {"provider": ""}, "agent": {}}
honcho_cfg = SimpleNamespace(enabled=True, api_key="stale-key", base_url=None)
with (
patch("hermes_cli.config.load_config", return_value=cfg),
patch("hermes_cli.config.save_config") as save_config,
patch(
"plugins.memory.honcho.client.HonchoClientConfig.from_global_config",
return_value=honcho_cfg,
) as from_global_config,
patch("plugins.memory.load_memory_provider") as load_memory_provider,
patch("agent.model_metadata.get_model_context_length", return_value=204_800),
patch("run_agent.get_tool_definitions", return_value=[]),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
):
from run_agent import AIAgent
agent = AIAgent(
api_key="test-key-1234567890",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=False,
)
assert agent._memory_manager is None
from_global_config.assert_not_called()
load_memory_provider.assert_not_called()
save_config.assert_not_called()

View file

@ -317,6 +317,60 @@ class TestStripThinkBlocks:
result = agent._strip_think_blocks("<thought>orphaned reasoning without close")
assert "<thought>" not in result
# ─── Unterminated-block coverage (#8878, #9568, #10408) ──────────────
# Reasoning models served via NIM / MiniMax M2.7 frequently drop the
# closing tag, leaking raw reasoning into assistant content. The open
# tag appears at a block boundary (start of text or after a newline);
# everything from that tag to end-of-string is stripped.
def test_unterminated_think_block_content_stripped(self, agent):
"""Content after unterminated <think> is fully stripped."""
result = agent._strip_think_blocks("<think>orphaned reasoning without close")
assert "orphaned reasoning" not in result
assert result.strip() == ""
def test_unterminated_thought_block_content_stripped(self, agent):
"""Gemma-style <thought> with no close is fully stripped."""
result = agent._strip_think_blocks("<thought>orphaned reasoning without close")
assert "orphaned reasoning" not in result
assert result.strip() == ""
def test_unterminated_multiline_block_stripped(self, agent):
"""Multi-line unterminated blocks are stripped in full."""
result = agent._strip_think_blocks(
"<think>\nmulti\nline\nreasoning\nthat never closes"
)
assert "multi" not in result
assert "never closes" not in result
def test_unterminated_block_after_answer_preserves_prefix(self, agent):
"""Visible answer before a line-starting unterminated tag is kept."""
result = agent._strip_think_blocks(
"Answer is 42.\n<think>actually let me reconsider"
)
assert "Answer is 42." in result
assert "reconsider" not in result
def test_inline_think_mention_in_prose_not_over_stripped(self, agent):
"""Mid-line `<think>` mentioned in prose must not swallow the rest
of the content (the block-boundary check prevents this)."""
text = "Use the <think> tag like this in your prose."
result = agent._strip_think_blocks(text)
# Block-boundary check prevents unterminated-strip from firing
assert "prose" in result
assert "Use the" in result
def test_mixed_case_closed_pair_stripped(self, agent):
"""Mixed-case variants <THINK>…</THINK>, <Thinking>…</Thinking> are
handled by case-insensitive closed-pair regex, so the trailing
content is preserved."""
result = agent._strip_think_blocks("<THINK>upper</THINK>final")
assert "upper" not in result
assert "final" in result
result = agent._strip_think_blocks("<Thinking>mixed</Thinking>final")
assert "mixed" not in result
assert "final" in result
class TestExtractReasoning:
def test_reasoning_field(self, agent):
@ -1088,6 +1142,41 @@ class TestBuildAssistantMessage:
result = agent._build_assistant_message(msg, "tool_calls")
assert "extra_content" not in result["tool_calls"][0]
def test_think_blocks_stripped_from_content(self, agent):
"""Inline <think> blocks are stripped from stored content (#8878, #9568).
The reasoning is captured into ``msg['reasoning']`` via the inline
fallback in ``_extract_reasoning``; the raw tags in ``content`` are
redundant and leak to messaging platforms / pollute titles /
inflate context if left in place.
"""
msg = _mock_assistant_msg(
content="<think>internal reasoning</think>The actual answer."
)
result = agent._build_assistant_message(msg, "stop")
assert "<think>" not in result["content"]
assert "internal reasoning" not in result["content"]
assert "The actual answer." in result["content"]
# Reasoning preserved separately via inline extraction fallback
assert result["reasoning"] == "internal reasoning"
def test_think_blocks_stripped_preserves_normal_content(self, agent):
"""Content without reasoning tags passes through unchanged."""
msg = _mock_assistant_msg(content="No thinking here.")
result = agent._build_assistant_message(msg, "stop")
assert result["content"] == "No thinking here."
def test_unterminated_think_block_stripped(self, agent):
"""Unterminated <think> block (MiniMax / NIM dropped close tag) is
fully stripped from stored content."""
msg = _mock_assistant_msg(
content="<think>reasoning that never closes on this NIM endpoint"
)
result = agent._build_assistant_message(msg, "stop")
assert "<think>" not in result["content"]
assert "reasoning that never closes" not in result["content"]
assert result["content"] == ""
class TestFormatToolsForSystemMessage:
def test_no_tools_returns_empty_array(self, agent):
@ -1196,6 +1285,7 @@ class TestExecuteToolCalls:
tc = _mock_tool_call(name="web_search", arguments='{"q":"test"}', call_id="c1")
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc])
messages = []
agent.platform = "cli"
agent.tool_progress_callback = None
with patch("run_agent.handle_function_call", return_value="search result"), \
@ -1207,6 +1297,21 @@ class TestExecuteToolCalls:
assert len(messages) == 1
assert messages[0]["role"] == "tool"
def test_quiet_tool_output_suppressed_without_progress_callback_for_non_cli_agent(self, agent):
tc = _mock_tool_call(name="web_search", arguments='{"q":"test"}', call_id="c1")
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc])
messages = []
agent.platform = None
agent.tool_progress_callback = None
with patch("run_agent.handle_function_call", return_value="search result"), \
patch.object(agent, "_safe_print") as mock_print:
agent._execute_tool_calls(mock_msg, messages, "task-1")
mock_print.assert_not_called()
assert len(messages) == 1
assert messages[0]["role"] == "tool"
def test_vprint_suppressed_in_parseable_quiet_mode(self, agent):
agent.suppress_status_output = True
@ -1787,6 +1892,30 @@ class TestRunConversation:
assert all("message_count" in c and "messages" not in c for c in pre_request_calls)
assert all("usage" in c and "response" not in c for c in post_request_calls)
def test_content_with_tool_calls_stays_silent_for_non_cli_quiet_mode(self, agent):
self._setup_agent(agent)
agent.platform = None
tc = _mock_tool_call(name="web_search", arguments="{}", call_id="c1")
resp1 = _mock_response(
content="I'll search for that.",
finish_reason="tool_calls",
tool_calls=[tc],
)
resp2 = _mock_response(content="Done searching", finish_reason="stop")
agent.client.chat.completions.create.side_effect = [resp1, resp2]
with (
patch("run_agent.handle_function_call", return_value="search result"),
patch.object(agent, "_safe_print") as mock_print,
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
):
result = agent.run_conversation("search something")
assert result["final_response"] == "Done searching"
mock_print.assert_not_called()
def test_interrupt_breaks_loop(self, agent):
self._setup_agent(agent)

View file

@ -0,0 +1,228 @@
"""Tests for AIAgent.steer() — mid-run user message injection.
/steer lets the user add a note to the agent's next tool result without
interrupting the current tool call. The agent sees the note inline with
tool output on its next iteration, preserving message-role alternation
and prompt-cache integrity.
"""
from __future__ import annotations
import threading
import pytest
from run_agent import AIAgent
def _bare_agent() -> AIAgent:
"""Build an AIAgent without running __init__, then install the steer
state manually matches the existing object.__new__ stub pattern
used elsewhere in the test suite.
"""
agent = object.__new__(AIAgent)
agent._pending_steer = None
agent._pending_steer_lock = threading.Lock()
return agent
class TestSteerAcceptance:
def test_accepts_non_empty_text(self):
agent = _bare_agent()
assert agent.steer("go ahead and check the logs") is True
assert agent._pending_steer == "go ahead and check the logs"
def test_rejects_empty_string(self):
agent = _bare_agent()
assert agent.steer("") is False
assert agent._pending_steer is None
def test_rejects_whitespace_only(self):
agent = _bare_agent()
assert agent.steer(" \n\t ") is False
assert agent._pending_steer is None
def test_rejects_none(self):
agent = _bare_agent()
assert agent.steer(None) is False # type: ignore[arg-type]
assert agent._pending_steer is None
def test_strips_surrounding_whitespace(self):
agent = _bare_agent()
assert agent.steer(" hello world \n") is True
assert agent._pending_steer == "hello world"
def test_concatenates_multiple_steers_with_newlines(self):
agent = _bare_agent()
agent.steer("first note")
agent.steer("second note")
agent.steer("third note")
assert agent._pending_steer == "first note\nsecond note\nthird note"
class TestSteerDrain:
def test_drain_returns_and_clears(self):
agent = _bare_agent()
agent.steer("hello")
assert agent._drain_pending_steer() == "hello"
assert agent._pending_steer is None
def test_drain_on_empty_returns_none(self):
agent = _bare_agent()
assert agent._drain_pending_steer() is None
class TestSteerInjection:
def test_appends_to_last_tool_result(self):
agent = _bare_agent()
agent.steer("please also check auth.log")
messages = [
{"role": "user", "content": "what's in /var/log?"},
{"role": "assistant", "tool_calls": [{"id": "a"}, {"id": "b"}]},
{"role": "tool", "content": "ls output A", "tool_call_id": "a"},
{"role": "tool", "content": "ls output B", "tool_call_id": "b"},
]
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=2)
# The LAST tool result is modified; earlier ones are untouched.
assert messages[2]["content"] == "ls output A"
assert "ls output B" in messages[3]["content"]
assert "[USER STEER" in messages[3]["content"]
assert "please also check auth.log" in messages[3]["content"]
# And pending_steer is consumed.
assert agent._pending_steer is None
def test_no_op_when_no_steer_pending(self):
agent = _bare_agent()
messages = [
{"role": "assistant", "tool_calls": [{"id": "a"}]},
{"role": "tool", "content": "output", "tool_call_id": "a"},
]
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
assert messages[-1]["content"] == "output" # unchanged
def test_no_op_when_num_tool_msgs_zero(self):
agent = _bare_agent()
agent.steer("steer")
messages = [{"role": "user", "content": "hi"}]
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=0)
# Steer should remain pending (nothing to drain into)
assert agent._pending_steer == "steer"
def test_marker_is_unambiguous_about_origin(self):
"""The injection marker must make clear the text is from the user
and not tool output this is the cache-safe way to signal
provenance without violating message-role alternation.
"""
agent = _bare_agent()
agent.steer("stop after next step")
messages = [{"role": "tool", "content": "x", "tool_call_id": "1"}]
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
content = messages[-1]["content"]
assert "USER STEER" in content
assert "not tool output" in content.lower() or "injected mid-run" in content.lower()
def test_multimodal_content_list_preserved(self):
"""Anthropic-style list content should be preserved, with the steer
appended as a text block."""
agent = _bare_agent()
agent.steer("extra note")
original_blocks = [{"type": "text", "text": "existing output"}]
messages = [
{"role": "tool", "content": list(original_blocks), "tool_call_id": "1"}
]
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
new_content = messages[-1]["content"]
assert isinstance(new_content, list)
assert len(new_content) == 2
assert new_content[0] == {"type": "text", "text": "existing output"}
assert new_content[1]["type"] == "text"
assert "extra note" in new_content[1]["text"]
def test_restashed_when_no_tool_result_in_batch(self):
"""If the 'batch' contains no tool-role messages (e.g. all skipped
after an interrupt), the steer should be put back into the pending
slot so the caller's fallback path can deliver it."""
agent = _bare_agent()
agent.steer("ping")
messages = [
{"role": "user", "content": "x"},
{"role": "assistant", "content": "y"},
]
# Claim there were N tool msgs, but the tail has none — simulates
# the interrupt-cancelled case.
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=2)
# Messages untouched
assert messages[-1]["content"] == "y"
# And the steer is back in pending so the fallback can grab it
assert agent._pending_steer == "ping"
class TestSteerThreadSafety:
def test_concurrent_steer_calls_preserve_all_text(self):
agent = _bare_agent()
N = 200
def worker(idx: int) -> None:
agent.steer(f"note-{idx}")
threads = [threading.Thread(target=worker, args=(i,)) for i in range(N)]
for t in threads:
t.start()
for t in threads:
t.join()
text = agent._drain_pending_steer()
assert text is not None
# Every single note must be preserved — none dropped by the lock.
lines = text.split("\n")
assert len(lines) == N
assert set(lines) == {f"note-{i}" for i in range(N)}
class TestSteerClearedOnInterrupt:
def test_clear_interrupt_drops_pending_steer(self):
"""A hard interrupt supersedes any pending steer — the agent's
next tool iteration won't happen, so delivering the steer later
would be surprising."""
agent = _bare_agent()
# Minimal surface needed by clear_interrupt()
agent._interrupt_requested = True
agent._interrupt_message = None
agent._interrupt_thread_signal_pending = False
agent._execution_thread_id = None
agent._tool_worker_threads = None
agent._tool_worker_threads_lock = None
agent.steer("will be dropped")
assert agent._pending_steer == "will be dropped"
agent.clear_interrupt()
assert agent._pending_steer is None
class TestSteerCommandRegistry:
def test_steer_in_command_registry(self):
"""The /steer slash command must be registered so it reaches all
platforms (CLI, gateway, TUI autocomplete, Telegram/Slack menus).
"""
from hermes_cli.commands import resolve_command, ACTIVE_SESSION_BYPASS_COMMANDS
cmd = resolve_command("steer")
assert cmd is not None
assert cmd.name == "steer"
assert cmd.category == "Session"
assert cmd.args_hint == "<prompt>"
def test_steer_in_bypass_set(self):
"""When the agent is running, /steer MUST bypass the Level-1
base-adapter queue so it reaches the gateway runner's /steer
handler. Otherwise it would be queued as user text and only
delivered at turn end defeating the whole point.
"""
from hermes_cli.commands import ACTIVE_SESSION_BYPASS_COMMANDS, should_bypass_active_session
assert "steer" in ACTIVE_SESSION_BYPASS_COMMANDS
assert should_bypass_active_session("steer") is True
if __name__ == "__main__": # pragma: no cover
pytest.main([__file__, "-v"])

View file

@ -141,6 +141,50 @@ class TestStreamingAccumulator:
assert tc[0].function.name == "terminal"
assert tc[0].function.arguments == '{"command": "ls"}'
@patch("run_agent.AIAgent._create_request_openai_client")
@patch("run_agent.AIAgent._close_request_openai_client")
def test_tool_name_not_duplicated_when_resent_per_chunk(self, mock_close, mock_create):
"""MiniMax M2.7 via NVIDIA NIM resends the full name in every chunk.
Bug #8259: the old += accumulation produced "read_fileread_file".
Assignment (matching OpenAI Node SDK / LiteLLM) prevents this.
"""
from run_agent import AIAgent
chunks = [
_make_stream_chunk(tool_calls=[
_make_tool_call_delta(index=0, tc_id="call_nim", name="read_file")
]),
_make_stream_chunk(tool_calls=[
_make_tool_call_delta(index=0, tc_id="call_nim", name="read_file", arguments='{"path":')
]),
_make_stream_chunk(tool_calls=[
_make_tool_call_delta(index=0, tc_id="call_nim", name="read_file", arguments=' "x.py"}')
]),
_make_stream_chunk(finish_reason="tool_calls"),
]
mock_client = MagicMock()
mock_client.chat.completions.create.return_value = iter(chunks)
mock_create.return_value = mock_client
agent = AIAgent(
model="test/model",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
)
agent.api_mode = "chat_completions"
agent._interrupt_requested = False
response = agent._interruptible_streaming_api_call({})
tc = response.choices[0].message.tool_calls
assert tc is not None
assert len(tc) == 1
assert tc[0].function.name == "read_file"
assert tc[0].function.arguments == '{"path": "x.py"}'
@patch("run_agent.AIAgent._create_request_openai_client")
@patch("run_agent.AIAgent._close_request_openai_client")
def test_tool_call_extra_content_preserved(self, mock_close, mock_create):
@ -952,3 +996,138 @@ class TestAnthropicStreamCallbacks:
agent._interruptible_streaming_api_call({})
assert touch_calls.count("receiving stream response") == len(events)
class TestPartialToolCallWarning:
"""Regression: when a stream dies mid tool-call argument generation after
text was already delivered, the partial-stream stub at run_agent.py
line ~6107 used to silently set ``tool_calls=None`` and return
``finish_reason=stop``, losing the attempted action with zero user-facing
signal. Live-observed Apr 2026 with MiniMax M2.7 on a 6-minute audit
task agent streamed commentary, emitted a write_file tool call,
MiniMax stalled for 240 s mid-arguments, stale-stream detector killed
the connection, the stub returned, session ended with no file written
and no error shown.
Fix: when the stream accumulator captured any tool-call names before the
error, the stub now appends a user-visible warning to content AND fires
it as a stream delta so the user sees it immediately.
"""
@patch("run_agent.AIAgent._create_request_openai_client")
@patch("run_agent.AIAgent._close_request_openai_client")
def test_partial_tool_call_surfaces_warning(self, mock_close, mock_create):
"""Stream with text + partial tool-call name + mid-stream error
produces a stub whose content contains the user-visible warning
and whose tool_calls is None."""
from run_agent import AIAgent
class _StallError(RuntimeError):
pass
def _stalling_stream():
yield _make_stream_chunk(content="Let me write the audit: ")
yield _make_stream_chunk(tool_calls=[
_make_tool_call_delta(index=0, tc_id="call_1", name="write_file"),
])
yield _make_stream_chunk(tool_calls=[
_make_tool_call_delta(index=0, arguments='{"path": "/tmp/x", '),
])
raise _StallError("simulated upstream stall")
mock_client = MagicMock()
mock_client.chat.completions.create.side_effect = lambda *a, **kw: _stalling_stream()
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
)
agent.api_mode = "chat_completions"
agent._interrupt_requested = False
fired_deltas: list = []
agent._fire_stream_delta = lambda text: fired_deltas.append(text)
agent._current_streamed_assistant_text = "Let me write the audit: "
import os as _os
_prev = _os.environ.get("HERMES_STREAM_RETRIES")
_os.environ["HERMES_STREAM_RETRIES"] = "0"
try:
response = agent._interruptible_streaming_api_call({})
finally:
if _prev is None:
_os.environ.pop("HERMES_STREAM_RETRIES", None)
else:
_os.environ["HERMES_STREAM_RETRIES"] = _prev
content = response.choices[0].message.content or ""
assert "Let me write the audit:" in content, (
f"Partial text not preserved in stub: {content!r}"
)
assert "Stream stalled mid tool-call" in content, (
f"Stub content is missing the dropped-tool-call warning; users "
f"get silent failure. Got content={content!r}"
)
assert "write_file" in content, (
f"Warning should name the dropped tool. Got: {content!r}"
)
assert response.choices[0].message.tool_calls is None
assert any("Stream stalled mid tool-call" in d for d in fired_deltas), (
f"Warning was not surfaced as a live stream delta. "
f"fired_deltas={fired_deltas}"
)
@patch("run_agent.AIAgent._create_request_openai_client")
@patch("run_agent.AIAgent._close_request_openai_client")
def test_partial_text_only_no_warning(self, mock_close, mock_create):
"""Text-only partial stream (no tool call mid-flight) keeps the
pre-fix behaviour: bare recovered text, no warning noise."""
from run_agent import AIAgent
class _StallError(RuntimeError):
pass
def _stalling_stream():
yield _make_stream_chunk(content="Here's my answer so far")
raise _StallError("simulated upstream stall")
mock_client = MagicMock()
mock_client.chat.completions.create.side_effect = lambda *a, **kw: _stalling_stream()
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
)
agent.api_mode = "chat_completions"
agent._interrupt_requested = False
agent._current_streamed_assistant_text = "Here's my answer so far"
import os as _os
_prev = _os.environ.get("HERMES_STREAM_RETRIES")
_os.environ["HERMES_STREAM_RETRIES"] = "0"
try:
response = agent._interruptible_streaming_api_call({})
finally:
if _prev is None:
_os.environ.pop("HERMES_STREAM_RETRIES", None)
else:
_os.environ["HERMES_STREAM_RETRIES"] = _prev
content = response.choices[0].message.content or ""
assert content == "Here's my answer so far", (
f"Pre-fix behaviour regressed for text-only partial streams: {content!r}"
)
assert "Stream stalled" not in content, (
f"Unexpected warning on text-only partial stream: {content!r}"
)

View file

@ -479,6 +479,141 @@ class TestFTS5Search:
assert s('my-app.config.ts') == '"my-app.config.ts"'
# =========================================================================
# CJK (Chinese/Japanese/Korean) LIKE fallback
# =========================================================================
class TestCJKSearchFallback:
"""Regression tests for CJK search (see #11511).
SQLite FTS5's default tokenizer treats contiguous CJK runs as a single
token ("和其他agent的聊天记录" one token), so substring queries like
"记忆断裂" return 0 rows despite the data being present. SessionDB falls
back to LIKE substring matching whenever FTS5 returns no results and
the query contains CJK characters.
"""
def test_cjk_detection_covers_all_ranges(self):
from hermes_state import SessionDB
f = SessionDB._contains_cjk
# Chinese (CJK Unified Ideographs)
assert f("记忆断裂") is True
# Japanese Hiragana + Katakana
assert f("こんにちは") is True
assert f("カタカナ") is True
# Korean Hangul syllables (both early and late — guards against
# the \ud7a0-\ud7af typo seen in one of the duplicate PRs)
assert f("안녕하세요") is True
assert f("기억") is True
# Non-CJK
assert f("hello world") is False
assert f("日本語mixedwithenglish") is True
assert f("") is False
def test_chinese_multichar_query_returns_results(self, db):
"""The headline bug: multi-char Chinese query must not return []."""
db.create_session(session_id="s1", source="cli")
db.append_message(
"s1", role="user",
content="昨天和其他Agent的聊天记录记忆断裂问题复现了",
)
results = db.search_messages("记忆断裂")
assert len(results) == 1
assert results[0]["session_id"] == "s1"
def test_chinese_bigram_query(self, db):
db.create_session(session_id="s1", source="telegram")
db.append_message("s1", role="user", content="今天讨论A2A通信协议的实现")
results = db.search_messages("通信")
assert len(results) == 1
def test_korean_query_returns_results(self, db):
"""Guards against Hangul range typos (\\uac00-\\ud7af, not \\ud7a0-)."""
db.create_session(session_id="s1", source="cli")
db.append_message("s1", role="user", content="안녕하세요 반갑습니다")
results = db.search_messages("안녕")
assert len(results) == 1
def test_japanese_query_returns_results(self, db):
db.create_session(session_id="s1", source="cli")
db.append_message("s1", role="user", content="こんにちは世界")
assert len(db.search_messages("こんにちは")) == 1
assert len(db.search_messages("世界")) == 1
def test_cjk_fallback_preserves_source_filter(self, db):
"""Guards against the SQL-builder bug where filter clauses land
after LIMIT/OFFSET (seen in one of the duplicate PRs)."""
db.create_session(session_id="s1", source="cli")
db.create_session(session_id="s2", source="telegram")
db.append_message("s1", role="user", content="记忆断裂在CLI")
db.append_message("s2", role="user", content="记忆断裂在Telegram")
results = db.search_messages("记忆断裂", source_filter=["telegram"])
assert len(results) == 1
assert results[0]["source"] == "telegram"
def test_cjk_fallback_preserves_exclude_sources(self, db):
db.create_session(session_id="s1", source="cli")
db.create_session(session_id="s2", source="tool")
db.append_message("s1", role="user", content="记忆断裂在CLI")
db.append_message("s2", role="assistant", content="记忆断裂在tool")
results = db.search_messages("记忆断裂", exclude_sources=["tool"])
sources = {r["source"] for r in results}
assert "tool" not in sources
assert "cli" in sources
def test_cjk_fallback_preserves_role_filter(self, db):
db.create_session(session_id="s1", source="cli")
db.append_message("s1", role="user", content="用户说的记忆断裂")
db.append_message("s1", role="assistant", content="助手说的记忆断裂")
results = db.search_messages("记忆断裂", role_filter=["assistant"])
assert len(results) == 1
assert results[0]["role"] == "assistant"
def test_cjk_snippet_is_centered_on_match(self, db):
"""Snippet should contain the search term, not just the first N chars."""
db.create_session(session_id="s1", source="cli")
long_prefix = "这是一段很长的前缀用来把匹配位置推到文档中间" * 3
long_suffix = "这是一段很长的后缀内容填充剩余空间" * 3
db.append_message(
"s1", role="user",
content=f"{long_prefix}记忆断裂{long_suffix}",
)
results = db.search_messages("记忆断裂")
assert len(results) == 1
# The centered substr() snippet must include the matched term.
assert "记忆断裂" in results[0]["snippet"]
def test_english_query_still_uses_fts5_fast_path(self, db):
"""English queries must not trigger the LIKE fallback (fast path regression)."""
db.create_session(session_id="s1", source="cli")
db.append_message("s1", role="user", content="Deploy docker containers")
results = db.search_messages("docker")
assert len(results) == 1
# No CJK in query → LIKE fallback must not run. We don't assert this
# directly (no instrumentation), but the FTS5 path produces an
# FTS5-style snippet with highlight markers when the term is short.
# At minimum: english queries must still match.
def test_cjk_query_with_no_matches_returns_empty(self, db):
db.create_session(session_id="s1", source="cli")
db.append_message("s1", role="user", content="unrelated English content")
results = db.search_messages("记忆断裂")
assert results == []
def test_mixed_cjk_english_query(self, db):
"""Mixed queries should still fall back to LIKE when FTS5 misses."""
db.create_session(session_id="s1", source="cli")
db.append_message("s1", role="user", content="讨论Agent通信协议")
# "Agent通信" is CJK+English — FTS5 default tokenizer indexes the
# whole CJK run with embedded "agent" as separate tokens; the LIKE
# fallback handles the substring correctly.
results = db.search_messages("Agent通信")
assert len(results) == 1
# =========================================================================
# Session search and listing
# =========================================================================

View file

@ -363,6 +363,28 @@ def test_image_attach_appends_local_image(monkeypatch):
assert len(server._sessions["sid"]["attached_images"]) == 1
def test_commands_catalog_surfaces_quick_commands(monkeypatch):
monkeypatch.setattr(server, "_load_cfg", lambda: {"quick_commands": {
"build": {"type": "exec", "command": "npm run build"},
"git": {"type": "alias", "target": "/shell git"},
"notes": {"type": "exec", "command": "cat NOTES.md", "description": "Open design notes"},
}})
resp = server.handle_request({"id": "1", "method": "commands.catalog", "params": {}})
pairs = dict(resp["result"]["pairs"])
assert "npm run build" in pairs["/build"]
assert pairs["/git"].startswith("alias →")
assert pairs["/notes"] == "Open design notes"
user_cat = next(c for c in resp["result"]["categories"] if c["name"] == "User commands")
user_pairs = dict(user_cat["pairs"])
assert set(user_pairs) == {"/build", "/git", "/notes"}
assert resp["result"]["canon"]["/build"] == "/build"
assert resp["result"]["canon"]["/notes"] == "/notes"
def test_command_dispatch_exec_nonzero_surfaces_error(monkeypatch):
monkeypatch.setattr(server, "_load_cfg", lambda: {"quick_commands": {"boom": {"type": "exec", "command": "boom"}}})
monkeypatch.setattr(
@ -438,3 +460,651 @@ def test_rollback_restore_resolves_number_and_file_path():
assert resp["result"]["success"] is True
assert calls["args"][1] == "bbb222"
assert calls["args"][2] == "src/app.tsx"
# ── session.steer ────────────────────────────────────────────────────
def test_session_steer_calls_agent_steer_when_agent_supports_it():
"""The TUI RPC method must call agent.steer(text) and return a
queued status without touching interrupt state.
"""
calls = {}
class _Agent:
def steer(self, text):
calls["steer_text"] = text
return True
def interrupt(self, *args, **kwargs):
calls["interrupt_called"] = True
server._sessions["sid"] = _session(agent=_Agent())
try:
resp = server.handle_request(
{
"id": "1",
"method": "session.steer",
"params": {"session_id": "sid", "text": "also check auth.log"},
}
)
finally:
server._sessions.pop("sid", None)
assert "result" in resp, resp
assert resp["result"]["status"] == "queued"
assert resp["result"]["text"] == "also check auth.log"
assert calls["steer_text"] == "also check auth.log"
assert "interrupt_called" not in calls # must NOT interrupt
def test_session_steer_rejects_empty_text():
server._sessions["sid"] = _session(agent=types.SimpleNamespace(steer=lambda t: True))
try:
resp = server.handle_request(
{
"id": "1",
"method": "session.steer",
"params": {"session_id": "sid", "text": " "},
}
)
finally:
server._sessions.pop("sid", None)
assert "error" in resp, resp
assert resp["error"]["code"] == 4002
def test_session_steer_errors_when_agent_has_no_steer_method():
server._sessions["sid"] = _session(agent=types.SimpleNamespace()) # no steer()
try:
resp = server.handle_request(
{
"id": "1",
"method": "session.steer",
"params": {"session_id": "sid", "text": "hi"},
}
)
finally:
server._sessions.pop("sid", None)
assert "error" in resp, resp
assert resp["error"]["code"] == 4010
def test_session_info_includes_mcp_servers(monkeypatch):
fake_status = [
{"name": "github", "transport": "http", "tools": 12, "connected": True},
{"name": "filesystem", "transport": "stdio", "tools": 4, "connected": True},
{"name": "broken", "transport": "stdio", "tools": 0, "connected": False},
]
fake_mod = types.ModuleType("tools.mcp_tool")
fake_mod.get_mcp_status = lambda: fake_status
monkeypatch.setitem(sys.modules, "tools.mcp_tool", fake_mod)
info = server._session_info(types.SimpleNamespace(tools=[], model=""))
assert info["mcp_servers"] == fake_status
# ---------------------------------------------------------------------------
# History-mutating commands must reject while session.running is True.
# Without these guards, prompt.submit's post-run history write either
# clobbers the mutation (version matches) or silently drops the agent's
# output (version mismatch) — both produce UI<->backend state desync.
# ---------------------------------------------------------------------------
def test_session_undo_rejects_while_running():
"""Fix for TUI silent-drop #1: /undo must not mutate history
while the agent is mid-turn would either clobber the undo or
cause prompt.submit to silently drop the agent's response."""
server._sessions["sid"] = _session(running=True, history=[
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "hello"},
])
try:
resp = server.handle_request(
{"id": "1", "method": "session.undo", "params": {"session_id": "sid"}}
)
assert resp.get("error"), "session.undo should reject while running"
assert resp["error"]["code"] == 4009
assert "session busy" in resp["error"]["message"]
# History must be unchanged
assert len(server._sessions["sid"]["history"]) == 2
finally:
server._sessions.pop("sid", None)
def test_session_undo_allowed_when_idle():
"""Regression guard: when not running, /undo still works."""
server._sessions["sid"] = _session(running=False, history=[
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "hello"},
])
try:
resp = server.handle_request(
{"id": "1", "method": "session.undo", "params": {"session_id": "sid"}}
)
assert resp.get("result"), f"got error: {resp.get('error')}"
assert resp["result"]["removed"] == 2
assert server._sessions["sid"]["history"] == []
finally:
server._sessions.pop("sid", None)
def test_session_compress_rejects_while_running(monkeypatch):
server._sessions["sid"] = _session(running=True)
try:
resp = server.handle_request(
{"id": "1", "method": "session.compress", "params": {"session_id": "sid"}}
)
assert resp.get("error")
assert resp["error"]["code"] == 4009
finally:
server._sessions.pop("sid", None)
def test_rollback_restore_rejects_full_history_while_running(monkeypatch):
"""Full-history rollback must reject; file-scoped rollback still allowed."""
server._sessions["sid"] = _session(running=True)
try:
resp = server.handle_request(
{"id": "1", "method": "rollback.restore", "params": {"session_id": "sid", "hash": "abc"}}
)
assert resp.get("error"), "full-history rollback should reject while running"
assert resp["error"]["code"] == 4009
finally:
server._sessions.pop("sid", None)
def test_prompt_submit_history_version_mismatch_surfaces_warning(monkeypatch):
"""Fix for TUI silent-drop #2: the defensive backstop at prompt.submit
must attach a 'warning' to message.complete when history was
mutated externally during the turn (instead of silently dropping
the agent's output)."""
# Agent bumps history_version itself mid-run to simulate an external
# mutation slipping past the guards.
session_ref = {"s": None}
class _RacyAgent:
def run_conversation(self, prompt, conversation_history=None, stream_callback=None):
# Simulate: something external bumped history_version
# while we were running.
with session_ref["s"]["history_lock"]:
session_ref["s"]["history_version"] += 1
return {"final_response": "agent reply", "messages": [{"role": "assistant", "content": "agent reply"}]}
class _ImmediateThread:
def __init__(self, target=None, daemon=None):
self._target = target
def start(self):
self._target()
server._sessions["sid"] = _session(agent=_RacyAgent())
session_ref["s"] = server._sessions["sid"]
emits: list[tuple] = []
try:
monkeypatch.setattr(server.threading, "Thread", _ImmediateThread)
monkeypatch.setattr(server, "_get_usage", lambda _a: {})
monkeypatch.setattr(server, "render_message", lambda _t, _c: "")
monkeypatch.setattr(server, "_emit", lambda *a: emits.append(a))
resp = server.handle_request(
{"id": "1", "method": "prompt.submit", "params": {"session_id": "sid", "text": "hi"}}
)
assert resp.get("result"), f"got error: {resp.get('error')}"
# History should NOT contain the agent's output (version mismatch)
assert server._sessions["sid"]["history"] == []
# message.complete must carry a 'warning' so the UI / operator
# knows the output was not persisted.
complete_calls = [a for a in emits if a[0] == "message.complete"]
assert len(complete_calls) == 1
_, _, payload = complete_calls[0]
assert "warning" in payload, (
"message.complete must include a 'warning' field on "
"history_version mismatch — otherwise the UI silently "
"shows output that was never persisted"
)
assert "not saved" in payload["warning"].lower() or "changed" in payload["warning"].lower()
finally:
server._sessions.pop("sid", None)
def test_prompt_submit_history_version_match_persists_normally(monkeypatch):
"""Regression guard: the backstop does not affect the happy path."""
class _Agent:
def run_conversation(self, prompt, conversation_history=None, stream_callback=None):
return {"final_response": "reply", "messages": [{"role": "assistant", "content": "reply"}]}
class _ImmediateThread:
def __init__(self, target=None, daemon=None):
self._target = target
def start(self):
self._target()
server._sessions["sid"] = _session(agent=_Agent())
emits: list[tuple] = []
try:
monkeypatch.setattr(server.threading, "Thread", _ImmediateThread)
monkeypatch.setattr(server, "_get_usage", lambda _a: {})
monkeypatch.setattr(server, "render_message", lambda _t, _c: "")
monkeypatch.setattr(server, "_emit", lambda *a: emits.append(a))
resp = server.handle_request(
{"id": "1", "method": "prompt.submit", "params": {"session_id": "sid", "text": "hi"}}
)
assert resp.get("result")
# History was written
assert server._sessions["sid"]["history"] == [{"role": "assistant", "content": "reply"}]
assert server._sessions["sid"]["history_version"] == 1
# No warning should be attached
complete_calls = [a for a in emits if a[0] == "message.complete"]
assert len(complete_calls) == 1
_, _, payload = complete_calls[0]
assert "warning" not in payload
finally:
server._sessions.pop("sid", None)
# ---------------------------------------------------------------------------
# session.interrupt must only cancel pending prompts owned by the calling
# session — it must not blast-resolve clarify/sudo/secret prompts on
# unrelated sessions sharing the same tui_gateway process. Without
# session scoping the other sessions' prompts silently resolve to empty
# strings, unblocking their agent threads as if the user cancelled.
# ---------------------------------------------------------------------------
def test_interrupt_only_clears_own_session_pending():
"""session.interrupt on session A must NOT release pending prompts
that belong to session B."""
import types
session_a = _session()
session_a["agent"] = types.SimpleNamespace(interrupt=lambda: None)
session_b = _session()
session_b["agent"] = types.SimpleNamespace(interrupt=lambda: None)
server._sessions["sid_a"] = session_a
server._sessions["sid_b"] = session_b
try:
# Simulate pending prompts on both sessions (what _block creates
# while a clarify/sudo/secret request is outstanding).
ev_a = threading.Event()
ev_b = threading.Event()
server._pending["rid-a"] = ("sid_a", ev_a)
server._pending["rid-b"] = ("sid_b", ev_b)
server._answers.clear()
# Interrupt session A.
resp = server.handle_request(
{"id": "1", "method": "session.interrupt", "params": {"session_id": "sid_a"}}
)
assert resp.get("result"), f"got error: {resp.get('error')}"
# Session A's pending must be released to empty.
assert ev_a.is_set(), "sid_a pending Event should be set after interrupt"
assert server._answers.get("rid-a") == ""
# Session B's pending MUST remain untouched — no cross-session blast.
assert not ev_b.is_set(), (
"CRITICAL: session.interrupt on sid_a released a pending prompt "
"belonging to sid_b — other sessions' clarify/sudo/secret "
"prompts are being silently cancelled"
)
assert "rid-b" not in server._answers
finally:
server._sessions.pop("sid_a", None)
server._sessions.pop("sid_b", None)
server._pending.pop("rid-a", None)
server._pending.pop("rid-b", None)
server._answers.pop("rid-a", None)
server._answers.pop("rid-b", None)
def test_interrupt_clears_multiple_own_pending():
"""When a single session has multiple pending prompts (uncommon but
possible via nested tool calls), interrupt must release all of them."""
import types
sess = _session()
sess["agent"] = types.SimpleNamespace(interrupt=lambda: None)
server._sessions["sid"] = sess
try:
ev1, ev2 = threading.Event(), threading.Event()
server._pending["r1"] = ("sid", ev1)
server._pending["r2"] = ("sid", ev2)
resp = server.handle_request(
{"id": "1", "method": "session.interrupt", "params": {"session_id": "sid"}}
)
assert resp.get("result")
assert ev1.is_set() and ev2.is_set()
assert server._answers.get("r1") == "" and server._answers.get("r2") == ""
finally:
server._sessions.pop("sid", None)
for key in ("r1", "r2"):
server._pending.pop(key, None)
server._answers.pop(key, None)
def test_clear_pending_without_sid_clears_all():
"""_clear_pending(None) is the shutdown path — must still release
every pending prompt regardless of owning session."""
ev1, ev2, ev3 = threading.Event(), threading.Event(), threading.Event()
server._pending["a"] = ("sid_x", ev1)
server._pending["b"] = ("sid_y", ev2)
server._pending["c"] = ("sid_z", ev3)
try:
server._clear_pending(None)
assert ev1.is_set() and ev2.is_set() and ev3.is_set()
finally:
for key in ("a", "b", "c"):
server._pending.pop(key, None)
server._answers.pop(key, None)
def test_respond_unpacks_sid_tuple_correctly():
"""After the (sid, Event) tuple change, _respond must still work."""
ev = threading.Event()
server._pending["rid-x"] = ("sid_x", ev)
try:
resp = server.handle_request(
{"id": "1", "method": "clarify.respond",
"params": {"request_id": "rid-x", "answer": "the answer"}}
)
assert resp.get("result")
assert ev.is_set()
assert server._answers.get("rid-x") == "the answer"
finally:
server._pending.pop("rid-x", None)
server._answers.pop("rid-x", None)
# ---------------------------------------------------------------------------
# /model switch and other agent-mutating commands must reject while the
# session is running. agent.switch_model() mutates self.model, self.provider,
# self.base_url, self.client etc. in place — the worker thread running
# agent.run_conversation is reading those on every iteration. Same class of
# bug as the session.undo / session.compress mid-run silent-drop; same fix
# pattern: reject with 4009 while running.
# ---------------------------------------------------------------------------
def test_config_set_model_rejects_while_running(monkeypatch):
"""/model via config.set must reject during an in-flight turn."""
seen = {"called": False}
def _fake_apply(sid, session, raw):
seen["called"] = True
return {"value": raw, "warning": ""}
monkeypatch.setattr(server, "_apply_model_switch", _fake_apply)
server._sessions["sid"] = _session(running=True)
try:
resp = server.handle_request({
"id": "1", "method": "config.set",
"params": {"session_id": "sid", "key": "model", "value": "anthropic/claude-sonnet-4.6"},
})
assert resp.get("error")
assert resp["error"]["code"] == 4009
assert "session busy" in resp["error"]["message"]
assert not seen["called"], (
"_apply_model_switch was called mid-turn — would race with "
"the worker thread reading agent.model / agent.client"
)
finally:
server._sessions.pop("sid", None)
def test_config_set_model_allowed_when_idle(monkeypatch):
"""Regression guard: idle sessions can still switch models."""
seen = {"called": False}
def _fake_apply(sid, session, raw):
seen["called"] = True
return {"value": "newmodel", "warning": ""}
monkeypatch.setattr(server, "_apply_model_switch", _fake_apply)
server._sessions["sid"] = _session(running=False)
try:
resp = server.handle_request({
"id": "1", "method": "config.set",
"params": {"session_id": "sid", "key": "model", "value": "newmodel"},
})
assert resp.get("result")
assert resp["result"]["value"] == "newmodel"
assert seen["called"]
finally:
server._sessions.pop("sid", None)
def test_mirror_slash_side_effects_rejects_mutating_commands_while_running(monkeypatch):
"""Slash worker passthrough (e.g. /model, /personality, /prompt,
/compress) must reject during an in-flight turn. Same race as
config.set mutates live agent state while run_conversation is
reading it."""
import types
applied = {"model": False, "compress": False}
def _fake_apply_model(sid, session, arg):
applied["model"] = True
return {"value": arg, "warning": ""}
def _fake_compress(session, focus):
applied["compress"] = True
return (0, {})
monkeypatch.setattr(server, "_apply_model_switch", _fake_apply_model)
monkeypatch.setattr(server, "_compress_session_history", _fake_compress)
session = _session(running=True)
session["agent"] = types.SimpleNamespace(model="x")
for cmd, expected_name in [
("/model new/model", "model"),
("/personality default", "personality"),
("/prompt", "prompt"),
("/compress", "compress"),
]:
warning = server._mirror_slash_side_effects("sid", session, cmd)
assert "session busy" in warning, (
f"{cmd} should have returned busy warning, got: {warning!r}"
)
assert f"/{expected_name}" in warning
# None of the mutating side-effect helpers should have fired.
assert not applied["model"], "model switch fired despite running session"
assert not applied["compress"], "compress fired despite running session"
def test_mirror_slash_side_effects_allowed_when_idle(monkeypatch):
"""Regression guard: idle session still runs the side effects."""
import types
applied = {"model": False}
def _fake_apply_model(sid, session, arg):
applied["model"] = True
return {"value": arg, "warning": ""}
monkeypatch.setattr(server, "_apply_model_switch", _fake_apply_model)
session = _session(running=False)
session["agent"] = types.SimpleNamespace(model="x")
warning = server._mirror_slash_side_effects("sid", session, "/model foo")
# Should NOT contain "session busy" — the switch went through.
assert "session busy" not in warning
assert applied["model"]
# ---------------------------------------------------------------------------
# session.create / session.close race: fast /new churn must not orphan the
# slash_worker subprocess or the global approval-notify registration.
# ---------------------------------------------------------------------------
def test_session_create_close_race_does_not_orphan_worker(monkeypatch):
"""Regression guard: if session.close runs while session.create's
_build thread is still constructing the agent, the build thread
must detect the orphan and clean up the slash_worker + notify
registration it's about to install. Without the cleanup those
resources leak the subprocess stays alive until atexit and the
notify callback lingers in the global registry."""
import threading
closed_workers: list[str] = []
unregistered_keys: list[str] = []
class _FakeWorker:
def __init__(self, key, model):
self.key = key
self._closed = False
def close(self):
self._closed = True
closed_workers.append(self.key)
class _FakeAgent:
def __init__(self):
self.model = "x"
self.provider = "openrouter"
self.base_url = ""
self.api_key = ""
# Make _build block until we release it — simulates slow agent init
release_build = threading.Event()
def _slow_make_agent(sid, key):
release_build.wait(timeout=3.0)
return _FakeAgent()
# Stub everything _build touches
monkeypatch.setattr(server, "_make_agent", _slow_make_agent)
monkeypatch.setattr(server, "_SlashWorker", _FakeWorker)
monkeypatch.setattr(server, "_get_db", lambda: types.SimpleNamespace(create_session=lambda *a, **kw: None))
monkeypatch.setattr(server, "_session_info", lambda _a: {"model": "x"})
monkeypatch.setattr(server, "_probe_credentials", lambda _a: None)
monkeypatch.setattr(server, "_wire_callbacks", lambda _sid: None)
monkeypatch.setattr(server, "_emit", lambda *a, **kw: None)
# Shim register/unregister to observe leaks
import tools.approval as _approval
monkeypatch.setattr(_approval, "register_gateway_notify",
lambda key, cb: None)
monkeypatch.setattr(_approval, "unregister_gateway_notify",
lambda key: unregistered_keys.append(key))
monkeypatch.setattr(_approval, "load_permanent_allowlist", lambda: None)
# Start: session.create spawns _build thread, returns synchronously
resp = server.handle_request({
"id": "1", "method": "session.create", "params": {"cols": 80},
})
assert resp.get("result"), f"got error: {resp.get('error')}"
sid = resp["result"]["session_id"]
# Build thread is blocked in _slow_make_agent. Close the session
# NOW — this pops _sessions[sid] before _build can install the
# worker/notify.
close_resp = server.handle_request({
"id": "2", "method": "session.close", "params": {"session_id": sid},
})
assert close_resp.get("result", {}).get("closed") is True
# At this point session.close saw slash_worker=None (not yet
# installed) so it didn't close anything. Release the build thread
# and let it finish — it should detect the orphan and clean up the
# worker it just allocated + unregister the notify.
release_build.set()
# Give the build thread a moment to run through its finally.
for _ in range(100):
if closed_workers:
break
import time
time.sleep(0.02)
assert len(closed_workers) == 1, (
f"orphan worker was not cleaned up — closed_workers={closed_workers}"
)
# Notify may be unregistered by both session.close (unconditional)
# and the orphan-cleanup path; the key guarantee is that the build
# thread does at least one unregister call (any prior close
# already popped the callback; the duplicate is a no-op).
assert len(unregistered_keys) >= 1, (
f"orphan notify registration was not unregistered — "
f"unregistered_keys={unregistered_keys}"
)
def test_session_create_no_race_keeps_worker_alive(monkeypatch):
"""Regression guard: when session.close does NOT race, the build
thread must install the worker + notify normally and leave them
alone (no over-eager cleanup)."""
closed_workers: list[str] = []
unregistered_keys: list[str] = []
class _FakeWorker:
def __init__(self, key, model):
self.key = key
def close(self):
closed_workers.append(self.key)
class _FakeAgent:
def __init__(self):
self.model = "x"
self.provider = "openrouter"
self.base_url = ""
self.api_key = ""
monkeypatch.setattr(server, "_make_agent", lambda sid, key: _FakeAgent())
monkeypatch.setattr(server, "_SlashWorker", _FakeWorker)
monkeypatch.setattr(server, "_get_db", lambda: types.SimpleNamespace(create_session=lambda *a, **kw: None))
monkeypatch.setattr(server, "_session_info", lambda _a: {"model": "x"})
monkeypatch.setattr(server, "_probe_credentials", lambda _a: None)
monkeypatch.setattr(server, "_wire_callbacks", lambda _sid: None)
monkeypatch.setattr(server, "_emit", lambda *a, **kw: None)
import tools.approval as _approval
monkeypatch.setattr(_approval, "register_gateway_notify", lambda key, cb: None)
monkeypatch.setattr(_approval, "unregister_gateway_notify",
lambda key: unregistered_keys.append(key))
monkeypatch.setattr(_approval, "load_permanent_allowlist", lambda: None)
resp = server.handle_request({
"id": "1", "method": "session.create", "params": {"cols": 80},
})
sid = resp["result"]["session_id"]
# Wait for the build to finish (ready event inside session dict).
session = server._sessions[sid]
session["agent_ready"].wait(timeout=2.0)
# Build finished without a close race — nothing should have been
# cleaned up by the orphan check.
assert closed_workers == [], (
f"build thread closed its own worker despite no race: {closed_workers}"
)
assert unregistered_keys == [], (
f"build thread unregistered its own notify despite no race: {unregistered_keys}"
)
# Session should have the live worker installed.
assert session.get("slash_worker") is not None
# Cleanup
server._sessions.pop(sid, None)

View file

@ -0,0 +1,408 @@
"""Unit tests for browser_cdp tool.
Uses a tiny in-process ``websockets`` server to simulate a CDP endpoint
gives real protocol coverage (connect, send, recv, close) without needing
a real Chrome instance.
"""
from __future__ import annotations
import asyncio
import json
import threading
import time
from typing import Any, Dict, List
import pytest
import websockets
from websockets.asyncio.server import serve
from tools import browser_cdp_tool
# ---------------------------------------------------------------------------
# In-process CDP mock server
# ---------------------------------------------------------------------------
class _CDPServer:
"""A tiny CDP-over-WebSocket mock.
Each client gets a greeting-free stream. The server replies to each
inbound request whose ``id`` is set, using the registered handler for
that method. If no handler is registered, returns a generic CDP error.
"""
def __init__(self) -> None:
self._handlers: Dict[str, Any] = {}
self._responses: List[Dict[str, Any]] = []
self._loop: asyncio.AbstractEventLoop | None = None
self._server: Any = None
self._thread: threading.Thread | None = None
self._host = "127.0.0.1"
self._port = 0
# --- handler registration --------------------------------------------
def on(self, method: str, handler):
"""Register a handler ``handler(params, session_id) -> dict or Exception``."""
self._handlers[method] = handler
# --- lifecycle -------------------------------------------------------
def start(self) -> str:
ready = threading.Event()
def _run() -> None:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
async def _handler(ws):
try:
async for raw in ws:
msg = json.loads(raw)
call_id = msg.get("id")
method = msg.get("method", "")
params = msg.get("params", {}) or {}
session_id = msg.get("sessionId")
self._responses.append(msg)
fn = self._handlers.get(method)
if fn is None:
reply = {
"id": call_id,
"error": {
"code": -32601,
"message": f"No handler for {method}",
},
}
else:
try:
result = fn(params, session_id)
if isinstance(result, Exception):
raise result
reply = {"id": call_id, "result": result}
except Exception as exc:
reply = {
"id": call_id,
"error": {"code": -1, "message": str(exc)},
}
if session_id:
reply["sessionId"] = session_id
await ws.send(json.dumps(reply))
except websockets.exceptions.ConnectionClosed:
pass
async def _serve() -> None:
self._server = await serve(_handler, self._host, 0)
sock = next(iter(self._server.sockets))
self._port = sock.getsockname()[1]
ready.set()
await self._server.wait_closed()
try:
self._loop.run_until_complete(_serve())
finally:
self._loop.close()
self._thread = threading.Thread(target=_run, daemon=True)
self._thread.start()
if not ready.wait(timeout=5.0):
raise RuntimeError("CDP mock server failed to start within 5s")
return f"ws://{self._host}:{self._port}/devtools/browser/mock"
def stop(self) -> None:
if self._loop and self._server:
def _close() -> None:
self._server.close()
self._loop.call_soon_threadsafe(_close)
if self._thread:
self._thread.join(timeout=3.0)
def received(self) -> List[Dict[str, Any]]:
return list(self._responses)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def cdp_server(monkeypatch):
"""Start a CDP mock and route tool resolution to it."""
server = _CDPServer()
ws_url = server.start()
monkeypatch.setattr(
browser_cdp_tool, "_resolve_cdp_endpoint", lambda: ws_url
)
try:
yield server
finally:
server.stop()
# ---------------------------------------------------------------------------
# Input validation
# ---------------------------------------------------------------------------
def test_missing_method_returns_error():
result = json.loads(browser_cdp_tool.browser_cdp(method=""))
assert "error" in result
assert "method" in result["error"].lower()
assert result.get("cdp_docs") == browser_cdp_tool.CDP_DOCS_URL
def test_non_string_method_returns_error():
result = json.loads(browser_cdp_tool.browser_cdp(method=123)) # type: ignore[arg-type]
assert "error" in result
assert "method" in result["error"].lower()
def test_non_dict_params_returns_error(monkeypatch):
monkeypatch.setattr(
browser_cdp_tool, "_resolve_cdp_endpoint", lambda: "ws://localhost:9999"
)
result = json.loads(
browser_cdp_tool.browser_cdp(method="Target.getTargets", params="not-a-dict") # type: ignore[arg-type]
)
assert "error" in result
assert "object" in result["error"].lower() or "dict" in result["error"].lower()
# ---------------------------------------------------------------------------
# Endpoint resolution
# ---------------------------------------------------------------------------
def test_no_endpoint_returns_helpful_error(monkeypatch):
monkeypatch.setattr(browser_cdp_tool, "_resolve_cdp_endpoint", lambda: "")
result = json.loads(browser_cdp_tool.browser_cdp(method="Target.getTargets"))
assert "error" in result
assert "/browser connect" in result["error"]
assert result.get("cdp_docs") == browser_cdp_tool.CDP_DOCS_URL
def test_non_ws_endpoint_returns_error(monkeypatch):
monkeypatch.setattr(
browser_cdp_tool, "_resolve_cdp_endpoint", lambda: "http://localhost:9222"
)
result = json.loads(browser_cdp_tool.browser_cdp(method="Target.getTargets"))
assert "error" in result
assert "WebSocket" in result["error"]
def test_websockets_missing_returns_error(monkeypatch):
monkeypatch.setattr(browser_cdp_tool, "_WS_AVAILABLE", False)
result = json.loads(browser_cdp_tool.browser_cdp(method="Target.getTargets"))
assert "error" in result
assert "websockets" in result["error"].lower()
# ---------------------------------------------------------------------------
# Happy-path: browser-level call
# ---------------------------------------------------------------------------
def test_browser_level_success(cdp_server):
cdp_server.on(
"Target.getTargets",
lambda params, sid: {
"targetInfos": [
{"targetId": "A", "type": "page", "title": "Tab 1", "url": "about:blank"},
{"targetId": "B", "type": "page", "title": "Tab 2", "url": "https://a.test"},
]
},
)
result = json.loads(browser_cdp_tool.browser_cdp(method="Target.getTargets"))
assert result["success"] is True
assert result["method"] == "Target.getTargets"
assert "target_id" not in result
assert len(result["result"]["targetInfos"]) == 2
# Verify the server actually received exactly one call (no extra traffic)
calls = cdp_server.received()
assert len(calls) == 1
assert calls[0]["method"] == "Target.getTargets"
assert "sessionId" not in calls[0]
def test_empty_params_sends_empty_object(cdp_server):
cdp_server.on("Browser.getVersion", lambda params, sid: {"product": "Mock/1.0"})
json.loads(browser_cdp_tool.browser_cdp(method="Browser.getVersion"))
assert cdp_server.received()[0]["params"] == {}
# ---------------------------------------------------------------------------
# Happy-path: target-attached call
# ---------------------------------------------------------------------------
def test_target_attach_then_call(cdp_server):
cdp_server.on(
"Target.attachToTarget",
lambda params, sid: {"sessionId": f"sess-{params['targetId']}"},
)
cdp_server.on(
"Runtime.evaluate",
lambda params, sid: {
"result": {"type": "string", "value": f"evaluated[{sid}]"},
},
)
result = json.loads(
browser_cdp_tool.browser_cdp(
method="Runtime.evaluate",
params={"expression": "document.title", "returnByValue": True},
target_id="tab-A",
)
)
assert result["success"] is True
assert result["target_id"] == "tab-A"
assert result["result"]["result"]["value"] == "evaluated[sess-tab-A]"
calls = cdp_server.received()
# First call: attach
assert calls[0]["method"] == "Target.attachToTarget"
assert calls[0]["params"] == {"targetId": "tab-A", "flatten": True}
# Second call: dispatched method on the session
assert calls[1]["method"] == "Runtime.evaluate"
assert calls[1]["sessionId"] == "sess-tab-A"
# ---------------------------------------------------------------------------
# CDP error responses
# ---------------------------------------------------------------------------
def test_cdp_method_error_returns_tool_error(cdp_server):
# No handler registered -> server returns CDP error
result = json.loads(
browser_cdp_tool.browser_cdp(method="NonExistent.method")
)
assert "error" in result
assert "CDP error" in result["error"]
assert result.get("method") == "NonExistent.method"
def test_attach_failure_returns_tool_error(cdp_server):
# Target.attachToTarget has no handler -> server errors on attach
result = json.loads(
browser_cdp_tool.browser_cdp(
method="Runtime.evaluate",
params={"expression": "1+1"},
target_id="missing",
)
)
assert "error" in result
assert "Target.attachToTarget" in result["error"]
# ---------------------------------------------------------------------------
# Timeouts
# ---------------------------------------------------------------------------
def test_timeout_when_server_never_replies(cdp_server):
# Register a handler that blocks forever
def slow(params, sid):
time.sleep(10)
return {}
cdp_server.on("Page.slowMethod", slow)
result = json.loads(
browser_cdp_tool.browser_cdp(
method="Page.slowMethod", timeout=0.5
)
)
assert "error" in result
assert "tim" in result["error"].lower()
# ---------------------------------------------------------------------------
# Timeout clamping
# ---------------------------------------------------------------------------
def test_timeout_clamped_above_max(cdp_server):
cdp_server.on("Browser.getVersion", lambda p, s: {"product": "ok"})
# timeout=10_000 should be clamped to 300 but still succeed
result = json.loads(
browser_cdp_tool.browser_cdp(method="Browser.getVersion", timeout=10_000)
)
assert result["success"] is True
def test_invalid_timeout_falls_back_to_default(cdp_server):
cdp_server.on("Browser.getVersion", lambda p, s: {"product": "ok"})
result = json.loads(
browser_cdp_tool.browser_cdp(method="Browser.getVersion", timeout="nope") # type: ignore[arg-type]
)
assert result["success"] is True
# ---------------------------------------------------------------------------
# Registry integration
# ---------------------------------------------------------------------------
def test_registered_in_browser_toolset():
from tools.registry import registry
entry = registry.get_entry("browser_cdp")
assert entry is not None
assert entry.toolset == "browser"
assert entry.schema["name"] == "browser_cdp"
assert entry.schema["parameters"]["required"] == ["method"]
assert "Chrome DevTools Protocol" in entry.schema["description"]
assert browser_cdp_tool.CDP_DOCS_URL in entry.schema["description"]
def test_dispatch_through_registry(cdp_server):
from tools.registry import registry
cdp_server.on("Target.getTargets", lambda p, s: {"targetInfos": []})
raw = registry.dispatch(
"browser_cdp", {"method": "Target.getTargets"}, task_id="t1"
)
result = json.loads(raw)
assert result["success"] is True
assert result["method"] == "Target.getTargets"
# ---------------------------------------------------------------------------
# check_fn gating
# ---------------------------------------------------------------------------
def test_check_fn_false_when_no_cdp_url(monkeypatch):
"""Gate closes when no CDP URL is set — even if the browser toolset is
otherwise configured."""
import tools.browser_tool as bt
monkeypatch.setattr(bt, "check_browser_requirements", lambda: True)
monkeypatch.setattr(bt, "_get_cdp_override", lambda: "")
assert browser_cdp_tool._browser_cdp_check() is False
def test_check_fn_true_when_cdp_url_set(monkeypatch):
"""Gate opens as soon as a CDP URL is resolvable."""
import tools.browser_tool as bt
monkeypatch.setattr(bt, "check_browser_requirements", lambda: True)
monkeypatch.setattr(
bt, "_get_cdp_override", lambda: "ws://localhost:9222/devtools/browser/x"
)
assert browser_cdp_tool._browser_cdp_check() is True
def test_check_fn_false_when_browser_requirements_fail(monkeypatch):
"""Even with a CDP URL, gate closes if the overall browser toolset is
unavailable (e.g. agent-browser not installed)."""
import tools.browser_tool as bt
monkeypatch.setattr(bt, "check_browser_requirements", lambda: False)
monkeypatch.setattr(
bt, "_get_cdp_override", lambda: "ws://localhost:9222/devtools/browser/x"
)
assert browser_cdp_tool._browser_cdp_check() is False

View file

@ -0,0 +1,455 @@
#!/usr/bin/env python3
"""Tests for execute_code's strict / project execution modes.
The mode switch controls two things:
- working directory: staging tmpdir (strict) vs session CWD (project)
- interpreter: sys.executable (strict) vs active venv's python (project)
Security-critical invariants env scrubbing, tool whitelist, resource caps
must apply identically in both modes. These tests guard all three layers.
Mode is sourced exclusively from ``code_execution.mode`` in config.yaml
there is no env-var override. Tests patch ``_load_config`` directly.
"""
import json
import os
import sys
import unittest
from contextlib import contextmanager
from unittest.mock import patch
import pytest
os.environ["TERMINAL_ENV"] = "local"
@pytest.fixture(autouse=True)
def _force_local_terminal(monkeypatch):
"""Mirror test_code_execution.py — guarantee local backend under xdist."""
monkeypatch.setenv("TERMINAL_ENV", "local")
from tools.code_execution_tool import (
SANDBOX_ALLOWED_TOOLS,
DEFAULT_EXECUTION_MODE,
EXECUTION_MODES,
_get_execution_mode,
_is_usable_python,
_resolve_child_cwd,
_resolve_child_python,
build_execute_code_schema,
execute_code,
)
@contextmanager
def _mock_mode(mode):
"""Context manager that pins code_execution.mode to the given value."""
with patch("tools.code_execution_tool._load_config",
return_value={"mode": mode}):
yield
def _mock_handle_function_call(function_name, function_args, task_id=None, user_task=None):
"""Minimal mock dispatcher reused across tests."""
if function_name == "terminal":
return json.dumps({"output": "mock", "exit_code": 0})
if function_name == "read_file":
return json.dumps({"content": "line1\n", "total_lines": 1})
return json.dumps({"error": f"Unknown tool: {function_name}"})
# ---------------------------------------------------------------------------
# Mode resolution
# ---------------------------------------------------------------------------
class TestGetExecutionMode(unittest.TestCase):
"""_get_execution_mode reads config.yaml only (no env var surface)."""
def test_default_is_project(self):
self.assertEqual(DEFAULT_EXECUTION_MODE, "project")
def test_config_project(self):
with patch("tools.code_execution_tool._load_config",
return_value={"mode": "project"}):
self.assertEqual(_get_execution_mode(), "project")
def test_config_strict(self):
with patch("tools.code_execution_tool._load_config",
return_value={"mode": "strict"}):
self.assertEqual(_get_execution_mode(), "strict")
def test_config_case_insensitive(self):
with patch("tools.code_execution_tool._load_config",
return_value={"mode": "STRICT"}):
self.assertEqual(_get_execution_mode(), "strict")
def test_config_strips_whitespace(self):
with patch("tools.code_execution_tool._load_config",
return_value={"mode": " project "}):
self.assertEqual(_get_execution_mode(), "project")
def test_empty_config_falls_back_to_default(self):
with patch("tools.code_execution_tool._load_config", return_value={}):
self.assertEqual(_get_execution_mode(), DEFAULT_EXECUTION_MODE)
def test_bogus_config_falls_back_to_default(self):
with patch("tools.code_execution_tool._load_config",
return_value={"mode": "banana"}):
self.assertEqual(_get_execution_mode(), DEFAULT_EXECUTION_MODE)
def test_none_config_falls_back_to_default(self):
with patch("tools.code_execution_tool._load_config",
return_value={"mode": None}):
# str(None).lower() = "none" → not in EXECUTION_MODES → default
self.assertEqual(_get_execution_mode(), DEFAULT_EXECUTION_MODE)
def test_execution_modes_tuple(self):
"""Canonical set of modes — tests + config layer rely on this shape."""
self.assertEqual(set(EXECUTION_MODES), {"project", "strict"})
# ---------------------------------------------------------------------------
# Interpreter resolver
# ---------------------------------------------------------------------------
class TestResolveChildPython(unittest.TestCase):
"""_resolve_child_python — picks the right interpreter per mode."""
def test_strict_always_sys_executable(self):
"""Strict mode never leaves sys.executable, even if venv is set."""
with patch.dict(os.environ, {"VIRTUAL_ENV": "/some/venv"}):
self.assertEqual(_resolve_child_python("strict"), sys.executable)
def test_project_with_no_venv_falls_back(self):
"""Project mode without VIRTUAL_ENV or CONDA_PREFIX → sys.executable."""
env = {k: v for k, v in os.environ.items()
if k not in ("VIRTUAL_ENV", "CONDA_PREFIX")}
with patch.dict(os.environ, env, clear=True):
self.assertEqual(_resolve_child_python("project"), sys.executable)
def test_project_with_virtualenv_picks_venv_python(self):
"""Project mode + VIRTUAL_ENV pointing at a real venv → that python."""
import tempfile, pathlib
with tempfile.TemporaryDirectory() as td:
fake_venv = pathlib.Path(td)
(fake_venv / "bin").mkdir()
# Symlink to real python so the version check actually passes
(fake_venv / "bin" / "python").symlink_to(sys.executable)
with patch.dict(os.environ, {"VIRTUAL_ENV": str(fake_venv)}):
# Clear cache — _is_usable_python memoizes on path
_is_usable_python.cache_clear()
result = _resolve_child_python("project")
self.assertEqual(result, str(fake_venv / "bin" / "python"))
def test_project_with_broken_venv_falls_back(self):
"""VIRTUAL_ENV set but bin/python missing → sys.executable."""
import tempfile
with tempfile.TemporaryDirectory() as td:
# No bin/python inside — broken venv
with patch.dict(os.environ, {"VIRTUAL_ENV": td}):
_is_usable_python.cache_clear()
self.assertEqual(_resolve_child_python("project"), sys.executable)
def test_project_prefers_virtualenv_over_conda(self):
"""If both VIRTUAL_ENV and CONDA_PREFIX are set, VIRTUAL_ENV wins."""
import tempfile, pathlib
with tempfile.TemporaryDirectory() as ve_td, tempfile.TemporaryDirectory() as conda_td:
ve = pathlib.Path(ve_td)
(ve / "bin").mkdir()
(ve / "bin" / "python").symlink_to(sys.executable)
conda = pathlib.Path(conda_td)
(conda / "bin").mkdir()
(conda / "bin" / "python").symlink_to(sys.executable)
with patch.dict(os.environ, {"VIRTUAL_ENV": str(ve), "CONDA_PREFIX": str(conda)}):
_is_usable_python.cache_clear()
result = _resolve_child_python("project")
self.assertEqual(result, str(ve / "bin" / "python"))
def test_is_usable_python_rejects_nonexistent(self):
_is_usable_python.cache_clear()
self.assertFalse(_is_usable_python("/does/not/exist/python"))
def test_is_usable_python_accepts_real_python(self):
_is_usable_python.cache_clear()
self.assertTrue(_is_usable_python(sys.executable))
# ---------------------------------------------------------------------------
# CWD resolver
# ---------------------------------------------------------------------------
class TestResolveChildCwd(unittest.TestCase):
def test_strict_uses_staging_dir(self):
self.assertEqual(_resolve_child_cwd("strict", "/tmp/staging"), "/tmp/staging")
def test_project_without_terminal_cwd_uses_getcwd(self):
env = {k: v for k, v in os.environ.items() if k != "TERMINAL_CWD"}
with patch.dict(os.environ, env, clear=True):
self.assertEqual(_resolve_child_cwd("project", "/tmp/staging"), os.getcwd())
def test_project_uses_terminal_cwd_when_set(self):
import tempfile
with tempfile.TemporaryDirectory() as td:
with patch.dict(os.environ, {"TERMINAL_CWD": td}):
self.assertEqual(_resolve_child_cwd("project", "/tmp/staging"), td)
def test_project_bogus_terminal_cwd_falls_back_to_getcwd(self):
with patch.dict(os.environ, {"TERMINAL_CWD": "/does/not/exist/anywhere"}):
self.assertEqual(_resolve_child_cwd("project", "/tmp/staging"), os.getcwd())
def test_project_expands_tilde(self):
import pathlib
home = str(pathlib.Path.home())
with patch.dict(os.environ, {"TERMINAL_CWD": "~"}):
self.assertEqual(_resolve_child_cwd("project", "/tmp/staging"), home)
# ---------------------------------------------------------------------------
# Schema description
# ---------------------------------------------------------------------------
class TestModeAwareSchema(unittest.TestCase):
def test_strict_description_mentions_temp_dir(self):
desc = build_execute_code_schema(mode="strict")["description"]
self.assertIn("temp dir", desc)
def test_project_description_mentions_session_and_venv(self):
desc = build_execute_code_schema(mode="project")["description"]
self.assertIn("session", desc)
self.assertIn("venv", desc)
def test_neither_description_uses_sandbox_language(self):
"""REGRESSION GUARD for commit 39b83f34.
Agents on local backends falsely believed they were sandboxed and
refused networking tasks. Do not reintroduce any 'sandbox' /
'isolated' / 'cloud' language in the tool description.
"""
for mode in EXECUTION_MODES:
desc = build_execute_code_schema(mode=mode)["description"].lower()
for forbidden in ("sandbox", "isolated", "cloud"):
self.assertNotIn(forbidden, desc,
f"mode={mode}: '{forbidden}' leaked into description")
def test_descriptions_are_similar_length(self):
"""Both modes should have roughly the same-size description."""
strict = len(build_execute_code_schema(mode="strict")["description"])
project = len(build_execute_code_schema(mode="project")["description"])
self.assertLess(abs(strict - project), 200)
def test_default_mode_reads_config(self):
"""build_execute_code_schema() with mode=None reads config.yaml."""
with _mock_mode("strict"):
desc = build_execute_code_schema()["description"]
self.assertIn("temp dir", desc)
with _mock_mode("project"):
desc = build_execute_code_schema()["description"]
self.assertIn("session", desc)
# ---------------------------------------------------------------------------
# Integration: what actually happens when execute_code runs per mode
# ---------------------------------------------------------------------------
@pytest.mark.skipif(sys.platform == "win32", reason="execute_code is POSIX-only")
class TestExecuteCodeModeIntegration(unittest.TestCase):
"""End-to-end: verify the subprocess actually runs where we expect."""
def _run(self, code, mode, enabled_tools=None, extra_env=None):
env_overrides = extra_env or {}
with _mock_mode(mode):
with patch.dict(os.environ, env_overrides):
with patch("model_tools.handle_function_call",
side_effect=_mock_handle_function_call):
raw = execute_code(
code=code,
task_id=f"test-{mode}",
enabled_tools=enabled_tools or list(SANDBOX_ALLOWED_TOOLS),
)
return json.loads(raw)
def test_strict_mode_runs_in_tmpdir(self):
"""Strict mode: script's os.getcwd() is the staging tmpdir."""
result = self._run("import os; print(os.getcwd())", mode="strict")
self.assertEqual(result["status"], "success")
self.assertIn("hermes_sandbox_", result["output"])
def test_project_mode_runs_in_session_cwd(self):
"""Project mode: script's os.getcwd() is the session's working dir."""
import tempfile
with tempfile.TemporaryDirectory() as td:
result = self._run(
"import os; print(os.getcwd())",
mode="project",
extra_env={"TERMINAL_CWD": td},
)
self.assertEqual(result["status"], "success")
# Resolve symlinks (macOS /tmp → /private/tmp) on both sides
self.assertEqual(
os.path.realpath(result["output"].strip()),
os.path.realpath(td),
)
def test_project_mode_interpreter_is_venv_python(self):
"""Project mode: sys.executable inside the child is the venv's python
when VIRTUAL_ENV is set to a real venv."""
# The hermes-agent venv is always active during tests, so this also
# happens to equal sys.executable of the parent. What we're asserting
# is: resolver picked a venv-bin/python path, not that it differs
# from sys.executable.
result = self._run("import sys; print(sys.executable)", mode="project")
self.assertEqual(result["status"], "success")
# Either VIRTUAL_ENV-bin/python or sys.executable fallback, both OK.
output = result["output"].strip()
ve = os.environ.get("VIRTUAL_ENV", "").strip()
if ve:
self.assertTrue(
output.startswith(ve) or output == sys.executable,
f"project-mode python should be under VIRTUAL_ENV={ve} or sys.executable={sys.executable}, got {output}",
)
def test_project_mode_can_still_import_hermes_tools(self):
"""Regression: hermes_tools still importable from non-tmpdir CWD.
This is the PYTHONPATH fix without it, switching to session CWD
breaks `from hermes_tools import terminal`.
"""
import tempfile
with tempfile.TemporaryDirectory() as td:
code = (
"from hermes_tools import terminal\n"
"r = terminal('echo x')\n"
"print(r.get('output', 'MISSING'))\n"
)
result = self._run(code, mode="project", extra_env={"TERMINAL_CWD": td})
self.assertEqual(result["status"], "success")
self.assertIn("mock", result["output"])
def test_strict_mode_can_still_import_hermes_tools(self):
"""Regression: strict mode's tmpdir CWD still works for imports."""
code = (
"from hermes_tools import terminal\n"
"r = terminal('echo x')\n"
"print(r.get('output', 'MISSING'))\n"
)
result = self._run(code, mode="strict")
self.assertEqual(result["status"], "success")
self.assertIn("mock", result["output"])
# ---------------------------------------------------------------------------
# SECURITY-CRITICAL regression guards
#
# These MUST pass in both strict and project mode. The whole tiered-mode
# proposition rests on the claim that switching from strict to project only
# changes CWD + interpreter, not the security posture.
# ---------------------------------------------------------------------------
@pytest.mark.skipif(sys.platform == "win32", reason="execute_code is POSIX-only")
class TestSecurityInvariantsAcrossModes(unittest.TestCase):
def _run(self, code, mode):
with _mock_mode(mode):
with patch("model_tools.handle_function_call",
side_effect=_mock_handle_function_call):
raw = execute_code(
code=code,
task_id=f"test-sec-{mode}",
enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
)
return json.loads(raw)
def test_api_keys_scrubbed_in_strict_mode(self):
code = (
"import os\n"
"print('KEY=' + os.environ.get('OPENAI_API_KEY', 'MISSING'))\n"
"print('TOK=' + os.environ.get('ANTHROPIC_API_KEY', 'MISSING'))\n"
)
with patch.dict(os.environ, {
"OPENAI_API_KEY": "sk-should-not-leak",
"ANTHROPIC_API_KEY": "ant-should-not-leak",
}):
result = self._run(code, mode="strict")
self.assertEqual(result["status"], "success")
self.assertIn("KEY=MISSING", result["output"])
self.assertIn("TOK=MISSING", result["output"])
self.assertNotIn("sk-should-not-leak", result["output"])
self.assertNotIn("ant-should-not-leak", result["output"])
def test_api_keys_scrubbed_in_project_mode(self):
"""CRITICAL: the project-mode default does NOT leak user credentials."""
code = (
"import os\n"
"print('KEY=' + os.environ.get('OPENAI_API_KEY', 'MISSING'))\n"
"print('TOK=' + os.environ.get('ANTHROPIC_API_KEY', 'MISSING'))\n"
"print('SEC=' + os.environ.get('GITHUB_TOKEN', 'MISSING'))\n"
)
with patch.dict(os.environ, {
"OPENAI_API_KEY": "sk-should-not-leak",
"ANTHROPIC_API_KEY": "ant-should-not-leak",
"GITHUB_TOKEN": "ghp-should-not-leak",
}):
result = self._run(code, mode="project")
self.assertEqual(result["status"], "success")
for needle in ("KEY=MISSING", "TOK=MISSING", "SEC=MISSING"):
self.assertIn(needle, result["output"])
for leaked in ("sk-should-not-leak", "ant-should-not-leak", "ghp-should-not-leak"):
self.assertNotIn(leaked, result["output"])
def test_secret_substrings_scrubbed_in_project_mode(self):
"""SECRET/PASSWORD/CREDENTIAL/PASSWD/AUTH filters still apply."""
code = (
"import os\n"
"for k in ('MY_SECRET', 'DB_PASSWORD', 'VAULT_CREDENTIAL', "
"'LDAP_PASSWD', 'AUTH_TOKEN'):\n"
" print(f'{k}=' + os.environ.get(k, 'MISSING'))\n"
)
with patch.dict(os.environ, {
"MY_SECRET": "secret-should-not-leak",
"DB_PASSWORD": "password-should-not-leak",
"VAULT_CREDENTIAL": "cred-should-not-leak",
"LDAP_PASSWD": "passwd-should-not-leak",
"AUTH_TOKEN": "auth-should-not-leak",
}):
result = self._run(code, mode="project")
self.assertEqual(result["status"], "success")
for leaked in ("secret-should-not-leak", "password-should-not-leak",
"cred-should-not-leak", "passwd-should-not-leak",
"auth-should-not-leak"):
self.assertNotIn(leaked, result["output"])
def test_tool_whitelist_enforced_in_strict_mode(self):
"""A script cannot RPC-call tools outside SANDBOX_ALLOWED_TOOLS."""
# execute_code is NOT in SANDBOX_ALLOWED_TOOLS (no recursion)
self.assertNotIn("execute_code", SANDBOX_ALLOWED_TOOLS)
code = (
"import hermes_tools as ht\n"
"print('execute_code_available:', hasattr(ht, 'execute_code'))\n"
"print('delegate_task_available:', hasattr(ht, 'delegate_task'))\n"
)
result = self._run(code, mode="strict")
self.assertEqual(result["status"], "success")
self.assertIn("execute_code_available: False", result["output"])
self.assertIn("delegate_task_available: False", result["output"])
def test_tool_whitelist_enforced_in_project_mode(self):
"""CRITICAL: project mode does NOT widen the tool whitelist."""
code = (
"import hermes_tools as ht\n"
"print('execute_code_available:', hasattr(ht, 'execute_code'))\n"
"print('delegate_task_available:', hasattr(ht, 'delegate_task'))\n"
)
result = self._run(code, mode="project")
self.assertEqual(result["status"], "success")
self.assertIn("execute_code_available: False", result["output"])
self.assertIn("delegate_task_available: False", result["output"])
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,256 @@
"""Tests for approvals.cron_mode — configurable approval behavior for cron jobs."""
import os
import pytest
import tools.approval as approval_module
from tools.approval import (
_get_cron_approval_mode,
check_all_command_guards,
check_dangerous_command,
detect_dangerous_command,
)
@pytest.fixture(autouse=True)
def _clear_approval_state():
approval_module._permanent_approved.clear()
approval_module.clear_session("default")
approval_module.clear_session("test-session")
yield
approval_module._permanent_approved.clear()
approval_module.clear_session("default")
approval_module.clear_session("test-session")
# ---------------------------------------------------------------------------
# _get_cron_approval_mode() config parsing
# ---------------------------------------------------------------------------
class TestCronApprovalModeParsing:
def test_default_is_deny(self):
"""When no config is set, cron_mode defaults to 'deny'."""
from unittest.mock import patch as mock_patch
with mock_patch("hermes_cli.config.load_config", return_value={"approvals": {}}):
assert _get_cron_approval_mode() == "deny"
def test_explicit_deny(self):
from unittest.mock import patch as mock_patch
with mock_patch("hermes_cli.config.load_config", return_value={"approvals": {"cron_mode": "deny"}}):
assert _get_cron_approval_mode() == "deny"
def test_explicit_approve(self):
from unittest.mock import patch as mock_patch
with mock_patch("hermes_cli.config.load_config", return_value={"approvals": {"cron_mode": "approve"}}):
assert _get_cron_approval_mode() == "approve"
def test_off_maps_to_approve(self):
"""'off' is an alias for 'approve' (matches --yolo semantics)."""
from unittest.mock import patch as mock_patch
with mock_patch("hermes_cli.config.load_config", return_value={"approvals": {"cron_mode": "off"}}):
assert _get_cron_approval_mode() == "approve"
def test_allow_maps_to_approve(self):
from unittest.mock import patch as mock_patch
with mock_patch("hermes_cli.config.load_config", return_value={"approvals": {"cron_mode": "allow"}}):
assert _get_cron_approval_mode() == "approve"
def test_yes_maps_to_approve(self):
from unittest.mock import patch as mock_patch
with mock_patch("hermes_cli.config.load_config", return_value={"approvals": {"cron_mode": "yes"}}):
assert _get_cron_approval_mode() == "approve"
def test_case_insensitive(self):
from unittest.mock import patch as mock_patch
with mock_patch("hermes_cli.config.load_config", return_value={"approvals": {"cron_mode": "APPROVE"}}):
assert _get_cron_approval_mode() == "approve"
def test_unknown_value_defaults_to_deny(self):
from unittest.mock import patch as mock_patch
with mock_patch("hermes_cli.config.load_config", return_value={"approvals": {"cron_mode": "maybe"}}):
assert _get_cron_approval_mode() == "deny"
def test_config_load_failure_defaults_to_deny(self):
"""If config loading fails entirely, default to deny (safe)."""
from unittest.mock import patch as mock_patch
with mock_patch("hermes_cli.config.load_config", side_effect=RuntimeError("config broken")):
assert _get_cron_approval_mode() == "deny"
def test_yaml_boolean_false_maps_to_deny(self):
"""YAML 1.1 parses bare 'off' as False. Ensure it maps to deny."""
from unittest.mock import patch as mock_patch
with mock_patch("hermes_cli.config.load_config", return_value={"approvals": {"cron_mode": False}}):
# str(False) = "False", which is not in the approve set, so deny
assert _get_cron_approval_mode() == "deny"
# ---------------------------------------------------------------------------
# check_dangerous_command() with cron session
# ---------------------------------------------------------------------------
class TestCronDenyMode:
"""When HERMES_CRON_SESSION is set and cron_mode=deny, dangerous commands are blocked."""
def test_dangerous_command_blocked_in_cron_deny_mode(self, monkeypatch):
monkeypatch.setenv("HERMES_CRON_SESSION", "1")
monkeypatch.delenv("HERMES_INTERACTIVE", raising=False)
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
monkeypatch.delenv("HERMES_YOLO_MODE", raising=False)
from unittest.mock import patch as mock_patch
with mock_patch("tools.approval._get_cron_approval_mode", return_value="deny"):
result = check_dangerous_command("rm -rf /tmp/stuff", "local")
assert not result["approved"]
assert "BLOCKED" in result["message"]
assert "cron_mode" in result["message"]
def test_safe_command_allowed_in_cron_deny_mode(self, monkeypatch):
"""Non-dangerous commands still work even with cron_mode=deny."""
monkeypatch.setenv("HERMES_CRON_SESSION", "1")
monkeypatch.delenv("HERMES_INTERACTIVE", raising=False)
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
monkeypatch.delenv("HERMES_YOLO_MODE", raising=False)
from unittest.mock import patch as mock_patch
with mock_patch("tools.approval._get_cron_approval_mode", return_value="deny"):
result = check_dangerous_command("ls -la", "local")
assert result["approved"]
def test_multiple_dangerous_patterns_blocked(self, monkeypatch):
"""All dangerous patterns are blocked, not just rm."""
monkeypatch.setenv("HERMES_CRON_SESSION", "1")
monkeypatch.delenv("HERMES_INTERACTIVE", raising=False)
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
monkeypatch.delenv("HERMES_YOLO_MODE", raising=False)
dangerous_commands = [
"rm -rf /",
"chmod 777 /etc/passwd",
"mkfs.ext4 /dev/sda1",
"dd if=/dev/zero of=/dev/sda",
]
from unittest.mock import patch as mock_patch
with mock_patch("tools.approval._get_cron_approval_mode", return_value="deny"):
for cmd in dangerous_commands:
is_dangerous, _, _ = detect_dangerous_command(cmd)
if is_dangerous:
result = check_dangerous_command(cmd, "local")
assert not result["approved"], f"Should be blocked: {cmd}"
assert "BLOCKED" in result["message"]
def test_block_message_includes_description(self, monkeypatch):
"""The block message should mention what pattern was matched."""
monkeypatch.setenv("HERMES_CRON_SESSION", "1")
monkeypatch.delenv("HERMES_INTERACTIVE", raising=False)
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
monkeypatch.delenv("HERMES_YOLO_MODE", raising=False)
from unittest.mock import patch as mock_patch
with mock_patch("tools.approval._get_cron_approval_mode", return_value="deny"):
result = check_dangerous_command("rm -rf /tmp/stuff", "local")
assert not result["approved"]
# Should contain the description of what was flagged
assert "dangerous" in result["message"].lower() or "delete" in result["message"].lower()
class TestCronApproveMode:
"""When HERMES_CRON_SESSION is set and cron_mode=approve, dangerous commands pass through."""
def test_dangerous_command_allowed_in_cron_approve_mode(self, monkeypatch):
monkeypatch.setenv("HERMES_CRON_SESSION", "1")
monkeypatch.delenv("HERMES_INTERACTIVE", raising=False)
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
monkeypatch.delenv("HERMES_YOLO_MODE", raising=False)
from unittest.mock import patch as mock_patch
with mock_patch("tools.approval._get_cron_approval_mode", return_value="approve"):
result = check_dangerous_command("rm -rf /tmp/stuff", "local")
assert result["approved"]
# ---------------------------------------------------------------------------
# check_all_command_guards() with cron session
# ---------------------------------------------------------------------------
class TestCronDenyModeAllGuards:
"""The combined guard function also respects cron_mode."""
def test_dangerous_command_blocked_in_combined_guard(self, monkeypatch):
monkeypatch.setenv("HERMES_CRON_SESSION", "1")
monkeypatch.delenv("HERMES_INTERACTIVE", raising=False)
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
monkeypatch.delenv("HERMES_EXEC_ASK", raising=False)
monkeypatch.delenv("HERMES_YOLO_MODE", raising=False)
from unittest.mock import patch as mock_patch
with mock_patch("tools.approval._get_cron_approval_mode", return_value="deny"):
result = check_all_command_guards("rm -rf /tmp/stuff", "local")
assert not result["approved"]
assert "BLOCKED" in result["message"]
def test_safe_command_allowed_in_combined_guard(self, monkeypatch):
monkeypatch.setenv("HERMES_CRON_SESSION", "1")
monkeypatch.delenv("HERMES_INTERACTIVE", raising=False)
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
monkeypatch.delenv("HERMES_EXEC_ASK", raising=False)
monkeypatch.delenv("HERMES_YOLO_MODE", raising=False)
from unittest.mock import patch as mock_patch
with mock_patch("tools.approval._get_cron_approval_mode", return_value="deny"):
result = check_all_command_guards("echo hello", "local")
assert result["approved"]
def test_combined_guard_approve_mode(self, monkeypatch):
monkeypatch.setenv("HERMES_CRON_SESSION", "1")
monkeypatch.delenv("HERMES_INTERACTIVE", raising=False)
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
monkeypatch.delenv("HERMES_EXEC_ASK", raising=False)
monkeypatch.delenv("HERMES_YOLO_MODE", raising=False)
from unittest.mock import patch as mock_patch
with mock_patch("tools.approval._get_cron_approval_mode", return_value="approve"):
result = check_all_command_guards("rm -rf /tmp/stuff", "local")
assert result["approved"]
# ---------------------------------------------------------------------------
# Edge cases: cron mode interaction with other approval mechanisms
# ---------------------------------------------------------------------------
class TestCronModeInteractions:
"""Cron mode should NOT interfere with other approval bypass mechanisms."""
def test_container_env_still_auto_approves(self, monkeypatch):
"""Docker/sandbox environments bypass approvals regardless of cron_mode."""
monkeypatch.setenv("HERMES_CRON_SESSION", "1")
monkeypatch.delenv("HERMES_INTERACTIVE", raising=False)
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
monkeypatch.delenv("HERMES_YOLO_MODE", raising=False)
from unittest.mock import patch as mock_patch
with mock_patch("tools.approval._get_cron_approval_mode", return_value="deny"):
result = check_dangerous_command("rm -rf /", "docker")
assert result["approved"]
def test_yolo_overrides_cron_deny(self, monkeypatch):
"""--yolo still works even if cron_mode=deny."""
monkeypatch.setenv("HERMES_CRON_SESSION", "1")
monkeypatch.setenv("HERMES_YOLO_MODE", "1")
monkeypatch.delenv("HERMES_INTERACTIVE", raising=False)
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
from unittest.mock import patch as mock_patch
with mock_patch("tools.approval._get_cron_approval_mode", return_value="deny"):
result = check_dangerous_command("rm -rf /", "local")
assert result["approved"]
def test_non_cron_non_interactive_still_auto_approves(self, monkeypatch):
"""Non-cron, non-interactive sessions (e.g. scripted usage) still auto-approve."""
monkeypatch.delenv("HERMES_CRON_SESSION", raising=False)
monkeypatch.delenv("HERMES_INTERACTIVE", raising=False)
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
monkeypatch.delenv("HERMES_YOLO_MODE", raising=False)
result = check_dangerous_command("rm -rf /tmp/stuff", "local")
assert result["approved"]

View file

@ -192,23 +192,23 @@ class TestUnifiedCronjobTool:
result = json.loads(
cronjob(
action="create",
skills=["blogwatcher", "find-nearby"],
skills=["blogwatcher", "maps"],
prompt="Use both skills and combine the result.",
schedule="every 1h",
name="Combo job",
)
)
assert result["success"] is True
assert result["skills"] == ["blogwatcher", "find-nearby"]
assert result["skills"] == ["blogwatcher", "maps"]
listing = json.loads(cronjob(action="list"))
assert listing["jobs"][0]["skills"] == ["blogwatcher", "find-nearby"]
assert listing["jobs"][0]["skills"] == ["blogwatcher", "maps"]
def test_multi_skill_default_name_prefers_prompt_when_present(self):
result = json.loads(
cronjob(
action="create",
skills=["blogwatcher", "find-nearby"],
skills=["blogwatcher", "maps"],
prompt="Use both skills and combine the result.",
schedule="every 1h",
)
@ -220,7 +220,7 @@ class TestUnifiedCronjobTool:
created = json.loads(
cronjob(
action="create",
skills=["blogwatcher", "find-nearby"],
skills=["blogwatcher", "maps"],
prompt="Use both skills and combine the result.",
schedule="every 1h",
)

View file

@ -274,6 +274,7 @@ class TestDelegateTask(unittest.TestCase):
model=None,
max_iterations=10,
parent_agent=parent,
task_count=1,
)
self.assertIs(mock_child._print_fn, sink)
@ -294,6 +295,7 @@ class TestDelegateTask(unittest.TestCase):
model=None,
max_iterations=10,
parent_agent=parent,
task_count=1,
)
self.assertTrue(callable(mock_child.thinking_callback))
@ -363,6 +365,7 @@ class TestToolNamePreservation(unittest.TestCase):
model=None,
max_iterations=10,
parent_agent=parent,
task_count=1,
)
except NameError as exc:
self.fail(
@ -1000,6 +1003,7 @@ class TestChildCredentialPoolResolution(unittest.TestCase):
model=None,
max_iterations=10,
parent_agent=parent,
task_count=1,
)
self.assertEqual(mock_child._credential_pool, mock_pool)
@ -1225,6 +1229,7 @@ class TestDelegationReasoningEffort(unittest.TestCase):
_build_child_agent(
task_index=0, goal="test", context=None, toolsets=None,
model=None, max_iterations=50, parent_agent=parent,
task_count=1,
)
call_kwargs = MockAgent.call_args[1]
self.assertEqual(call_kwargs["reasoning_config"], {"enabled": True, "effort": "xhigh"})
@ -1241,6 +1246,7 @@ class TestDelegationReasoningEffort(unittest.TestCase):
_build_child_agent(
task_index=0, goal="test", context=None, toolsets=None,
model=None, max_iterations=50, parent_agent=parent,
task_count=1,
)
call_kwargs = MockAgent.call_args[1]
self.assertEqual(call_kwargs["reasoning_config"], {"enabled": True, "effort": "low"})
@ -1257,6 +1263,7 @@ class TestDelegationReasoningEffort(unittest.TestCase):
_build_child_agent(
task_index=0, goal="test", context=None, toolsets=None,
model=None, max_iterations=50, parent_agent=parent,
task_count=1,
)
call_kwargs = MockAgent.call_args[1]
self.assertEqual(call_kwargs["reasoning_config"], {"enabled": False})
@ -1273,6 +1280,7 @@ class TestDelegationReasoningEffort(unittest.TestCase):
_build_child_agent(
task_index=0, goal="test", context=None, toolsets=None,
model=None, max_iterations=50, parent_agent=parent,
task_count=1,
)
call_kwargs = MockAgent.call_args[1]
self.assertEqual(call_kwargs["reasoning_config"], {"enabled": True, "effort": "medium"})

View file

@ -0,0 +1,145 @@
"""Regression tests for _wait_for_process subprocess cleanup on exception exit.
When the poll loop exits via KeyboardInterrupt or SystemExit (SIGTERM via
cli.py signal handler, SIGINT on the main thread in non-interactive -q mode,
or explicit sys.exit from some caller), the child subprocess must be killed
before the exception propagates otherwise the local backend's use of
os.setsid leaves an orphan with PPID=1.
The live repro that motivated this: hermes chat -q ... 'sleep 300', SIGTERM
to the python process, sleep 300 survived with PPID=1 for the full 300 s
because _wait_for_process never got to call _kill_process before python
died. See commit message for full context.
"""
import os
import signal
import subprocess
import threading
import time
import pytest
from tools.environments.local import LocalEnvironment
@pytest.fixture(autouse=True)
def _isolate_hermes_home(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
(tmp_path / "logs").mkdir(exist_ok=True)
def _pgid_still_alive(pgid: int) -> bool:
"""Return True if any process in the given process group is still alive."""
try:
os.killpg(pgid, 0) # signal 0 = existence check
return True
except ProcessLookupError:
return False
def test_wait_for_process_kills_subprocess_on_keyboardinterrupt():
"""When KeyboardInterrupt arrives mid-poll, the subprocess group must be
killed before the exception is re-raised."""
env = LocalEnvironment(cwd="/tmp")
try:
result_holder = {}
proc_holder = {}
started = threading.Event()
raise_at = [None] # set by the main thread to tell worker when
# Drive execute() on a separate thread so we can SIGNAL-interrupt it
# via a thread-targeted exception without killing our test process.
def worker():
# Spawn a subprocess that will definitely be alive long enough
# to observe the cleanup, via env.execute(...) — the normal path
# that goes through _wait_for_process.
try:
result_holder["result"] = env.execute("sleep 30", timeout=60)
except BaseException as e: # noqa: BLE001 — we want to observe it
result_holder["exception"] = type(e).__name__
t = threading.Thread(target=worker, daemon=True)
t.start()
# Wait until the subprocess actually exists. LocalEnvironment.execute
# does init_session() (one spawn) before the real command, so we need
# to wait until a sleep 30 is visible. Use pgrep-style lookup via
# /proc to find the bash process running our sleep.
deadline = time.monotonic() + 5.0
target_pid = None
while time.monotonic() < deadline:
# Walk our children and grand-children to find one running 'sleep 30'
try:
import psutil # optional — fall back if absent
for p in psutil.Process(os.getpid()).children(recursive=True):
try:
if "sleep 30" in " ".join(p.cmdline()):
target_pid = p.pid
break
except (psutil.NoSuchProcess, psutil.AccessDenied):
continue
except ImportError:
# Fall back to ps
ps = subprocess.run(
["ps", "-eo", "pid,ppid,pgid,cmd"], capture_output=True, text=True,
)
for line in ps.stdout.splitlines():
if "sleep 30" in line and "grep" not in line:
parts = line.split()
if parts and parts[0].isdigit():
target_pid = int(parts[0])
break
if target_pid:
break
time.sleep(0.1)
assert target_pid is not None, (
"test setup: couldn't find 'sleep 30' subprocess after 5 s"
)
pgid = os.getpgid(target_pid)
assert _pgid_still_alive(pgid), "sanity: subprocess should be alive"
# Now inject a KeyboardInterrupt into the worker thread the same
# way CPython's signal machinery would. We use ctypes.PyThreadState_SetAsyncExc
# which is how signal delivery to non-main threads is simulated.
import ctypes
import sys as _sys
# py-thread-state exception targets need the ident, not the Thread
tid = t.ident
assert tid is not None
# Fire KeyboardInterrupt into the worker thread
ret = ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_ulong(tid), ctypes.py_object(KeyboardInterrupt),
)
assert ret == 1, f"SetAsyncExc returned {ret}, expected 1"
# Give the worker a moment to: hit the exception at the next poll,
# run the except-block cleanup (_kill_process), and exit.
t.join(timeout=5.0)
assert not t.is_alive(), "worker didn't exit within 5 s of the interrupt"
# The critical assertion: the subprocess GROUP must be dead. Not
# just the bash wrapper — the 'sleep 30' child too.
# Give the SIGTERM+1s wait+SIGKILL escalation a moment to complete.
deadline = time.monotonic() + 3.0
while time.monotonic() < deadline:
if not _pgid_still_alive(pgid):
break
time.sleep(0.1)
assert not _pgid_still_alive(pgid), (
f"subprocess group {pgid} is STILL ALIVE after worker received "
f"KeyboardInterrupt — orphan bug regressed. This is the "
f"sleep-300-survives-SIGTERM scenario from Physikal's Apr 2026 "
f"report. See tools/environments/base.py _wait_for_process "
f"except-block."
)
# And the worker should have observed the KeyboardInterrupt (i.e.
# it re-raised cleanly, not silently swallowed).
assert result_holder.get("exception") == "KeyboardInterrupt", (
f"worker result: {result_holder!r} — expected KeyboardInterrupt "
f"propagation after cleanup"
)
finally:
try:
env.cleanup()
except Exception:
pass

View file

@ -296,6 +296,8 @@ class TestBuiltinDiscovery:
"tools.code_execution_tool",
"tools.cronjob_tools",
"tools.delegate_tool",
"tools.feishu_doc_tool",
"tools.feishu_drive_tool",
"tools.file_tools",
"tools.homeassistant_tool",
"tools.image_generation_tool",

View file

@ -4,6 +4,7 @@ import io
import json
import sys
import threading
import time
from unittest.mock import MagicMock, patch
import pytest
@ -120,7 +121,9 @@ def test_block_and_respond(capture):
rid = next(iter(server._pending))
server._answers[rid] = "my_answer"
server._pending[rid].set()
# _pending values are (sid, Event) tuples — unpack to set the Event
_, ev = server._pending[rid]
ev.set()
threading.Event().wait(0.1)
assert result[0] == "my_answer"
@ -128,7 +131,8 @@ def test_block_and_respond(capture):
def test_clear_pending(server):
ev = threading.Event()
server._pending["r1"] = ev
# _pending values are (sid, Event) tuples
server._pending["r1"] = ("sid-x", ev)
server._clear_pending()
assert ev.is_set()
@ -231,3 +235,279 @@ def test_cli_exec_blocked(server, argv):
])
def test_cli_exec_allowed(server, argv):
assert server._cli_exec_blocked(argv) is None
# ── slash.exec skill command interception ────────────────────────────
def test_slash_exec_rejects_skill_commands(server):
"""slash.exec must reject skill commands so the TUI falls through to command.dispatch."""
# Register a mock session
sid = "test-session"
server._sessions[sid] = {"session_key": sid, "agent": None}
# Mock scan_skill_commands to return a known skill
fake_skills = {"/hermes-agent-dev": {"name": "hermes-agent-dev", "description": "Dev workflow"}}
with patch("agent.skill_commands.get_skill_commands", return_value=fake_skills):
resp = server.handle_request({
"id": "r1",
"method": "slash.exec",
"params": {"command": "hermes-agent-dev", "session_id": sid},
})
# Should return an error so the TUI's .catch() fires command.dispatch
assert "error" in resp
assert resp["error"]["code"] == 4018
assert "skill command" in resp["error"]["message"]
@pytest.mark.parametrize("cmd", ["retry", "queue hello", "q hello", "steer fix the test", "plan"])
def test_slash_exec_rejects_pending_input_commands(server, cmd):
"""slash.exec must reject commands that use _pending_input in the CLI."""
sid = "test-session"
server._sessions[sid] = {"session_key": sid, "agent": None}
resp = server.handle_request({
"id": "r1",
"method": "slash.exec",
"params": {"command": cmd, "session_id": sid},
})
assert "error" in resp
assert resp["error"]["code"] == 4018
assert "pending-input command" in resp["error"]["message"]
def test_command_dispatch_queue_sends_message(server):
"""command.dispatch /queue returns {type: 'send', message: ...} for the TUI."""
sid = "test-session"
server._sessions[sid] = {"session_key": sid}
resp = server.handle_request({
"id": "r1",
"method": "command.dispatch",
"params": {"name": "queue", "arg": "tell me about quantum computing", "session_id": sid},
})
assert "error" not in resp
result = resp["result"]
assert result["type"] == "send"
assert result["message"] == "tell me about quantum computing"
def test_command_dispatch_queue_requires_arg(server):
"""command.dispatch /queue without an argument returns an error."""
sid = "test-session"
server._sessions[sid] = {"session_key": sid}
resp = server.handle_request({
"id": "r2",
"method": "command.dispatch",
"params": {"name": "queue", "arg": "", "session_id": sid},
})
assert "error" in resp
assert resp["error"]["code"] == 4004
def test_command_dispatch_steer_fallback_sends_message(server):
"""command.dispatch /steer with no active agent falls back to send."""
sid = "test-session"
server._sessions[sid] = {"session_key": sid, "agent": None}
resp = server.handle_request({
"id": "r3",
"method": "command.dispatch",
"params": {"name": "steer", "arg": "focus on testing", "session_id": sid},
})
assert "error" not in resp
result = resp["result"]
assert result["type"] == "send"
assert result["message"] == "focus on testing"
def test_command_dispatch_retry_finds_last_user_message(server):
"""command.dispatch /retry walks session['history'] to find the last user message."""
sid = "test-session"
history = [
{"role": "user", "content": "first question"},
{"role": "assistant", "content": "first answer"},
{"role": "user", "content": "second question"},
{"role": "assistant", "content": "second answer"},
]
server._sessions[sid] = {
"session_key": sid,
"agent": None,
"history": history,
"history_lock": threading.Lock(),
"history_version": 0,
}
resp = server.handle_request({
"id": "r4",
"method": "command.dispatch",
"params": {"name": "retry", "session_id": sid},
})
assert "error" not in resp
result = resp["result"]
assert result["type"] == "send"
assert result["message"] == "second question"
# Verify history was truncated: everything from last user message onward removed
assert len(server._sessions[sid]["history"]) == 2
assert server._sessions[sid]["history"][-1]["role"] == "assistant"
assert server._sessions[sid]["history_version"] == 1
def test_command_dispatch_retry_empty_history(server):
"""command.dispatch /retry with empty history returns error."""
sid = "test-session"
server._sessions[sid] = {
"session_key": sid,
"agent": None,
"history": [],
"history_lock": threading.Lock(),
"history_version": 0,
}
resp = server.handle_request({
"id": "r5",
"method": "command.dispatch",
"params": {"name": "retry", "session_id": sid},
})
assert "error" in resp
assert resp["error"]["code"] == 4018
def test_command_dispatch_retry_handles_multipart_content(server):
"""command.dispatch /retry extracts text from multipart content lists."""
sid = "test-session"
history = [
{"role": "user", "content": [
{"type": "text", "text": "analyze this"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
]},
{"role": "assistant", "content": "I see the image."},
]
server._sessions[sid] = {
"session_key": sid,
"agent": None,
"history": history,
"history_lock": threading.Lock(),
"history_version": 0,
}
resp = server.handle_request({
"id": "r6",
"method": "command.dispatch",
"params": {"name": "retry", "session_id": sid},
})
assert "error" not in resp
result = resp["result"]
assert result["type"] == "send"
assert result["message"] == "analyze this"
def test_command_dispatch_returns_skill_payload(server):
"""command.dispatch returns structured skill payload for the TUI to send()."""
sid = "test-session"
server._sessions[sid] = {"session_key": sid}
fake_skills = {"/hermes-agent-dev": {"name": "hermes-agent-dev", "description": "Dev workflow"}}
fake_msg = "Loaded skill content here"
with patch("agent.skill_commands.scan_skill_commands", return_value=fake_skills), \
patch("agent.skill_commands.build_skill_invocation_message", return_value=fake_msg):
resp = server.handle_request({
"id": "r2",
"method": "command.dispatch",
"params": {"name": "hermes-agent-dev", "session_id": sid},
})
assert "error" not in resp
result = resp["result"]
assert result["type"] == "skill"
assert result["message"] == fake_msg
assert result["name"] == "hermes-agent-dev"
# ── dispatch(): pool routing for long handlers (#12546) ──────────────
def test_dispatch_runs_short_handlers_inline(server):
"""Non-long handlers return their response synchronously from dispatch()."""
server._methods["fast.ping"] = lambda rid, params: server._ok(rid, {"pong": True})
resp = server.dispatch({"id": "r1", "method": "fast.ping", "params": {}})
assert resp == {"jsonrpc": "2.0", "id": "r1", "result": {"pong": True}}
def test_dispatch_offloads_long_handlers_and_emits_via_stdout(capture):
"""Long handlers run on the pool and write their response via write_json."""
server, buf = capture
server._methods["slash.exec"] = lambda rid, params: server._ok(rid, {"output": "hi"})
resp = server.dispatch({"id": "r2", "method": "slash.exec", "params": {}})
assert resp is None
for _ in range(50):
if buf.getvalue():
break
time.sleep(0.01)
written = json.loads(buf.getvalue())
assert written == {"jsonrpc": "2.0", "id": "r2", "result": {"output": "hi"}}
def test_dispatch_long_handler_does_not_block_fast_handler(server):
"""A slow long handler must not prevent a concurrent fast handler from completing."""
released = threading.Event()
server._methods["slash.exec"] = lambda rid, params: (released.wait(timeout=5), server._ok(rid, {"done": True}))[1]
server._methods["fast.ping"] = lambda rid, params: server._ok(rid, {"pong": True})
t0 = time.monotonic()
assert server.dispatch({"id": "slow", "method": "slash.exec", "params": {}}) is None
fast_resp = server.dispatch({"id": "fast", "method": "fast.ping", "params": {}})
fast_elapsed = time.monotonic() - t0
assert fast_resp["result"] == {"pong": True}
assert fast_elapsed < 0.5, f"fast handler blocked for {fast_elapsed:.2f}s behind slow handler"
released.set()
def test_dispatch_long_handler_exception_produces_error_response(capture):
"""An exception inside a pool-dispatched handler still yields a JSON-RPC error."""
server, buf = capture
def boom(rid, params):
raise RuntimeError("kaboom")
server._methods["slash.exec"] = boom
server.dispatch({"id": "r3", "method": "slash.exec", "params": {}})
for _ in range(50):
if buf.getvalue():
break
time.sleep(0.01)
written = json.loads(buf.getvalue())
assert written["id"] == "r3"
assert written["error"]["code"] == -32000
assert "kaboom" in written["error"]["message"]
def test_dispatch_unknown_long_method_still_goes_inline(server):
"""Method name not in _LONG_HANDLERS takes the sync path even if handler is slow."""
server._methods["some.method"] = lambda rid, params: server._ok(rid, {"ok": True})
resp = server.dispatch({"id": "r4", "method": "some.method", "params": {}})
assert resp["result"] == {"ok": True}