mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-08 08:11:38 +00:00
Merge branch 'main' of github.com:NousResearch/hermes-agent into fix/show-reasoning-per-platform
This commit is contained in:
commit
26bd52c1ba
291 changed files with 23713 additions and 2985 deletions
|
|
@ -697,7 +697,12 @@ class TestIsConnectionError:
|
|||
|
||||
|
||||
class TestKimiForCodingTemperature:
|
||||
"""kimi-for-coding now requires temperature=0.6 exactly."""
|
||||
"""Moonshot kimi-for-coding models require fixed temperatures.
|
||||
|
||||
k2.5 / k2-turbo-preview / k2-0905-preview → 0.6 (non-thinking lock).
|
||||
k2-thinking / k2-thinking-turbo → 1.0 (thinking lock).
|
||||
kimi-k2-instruct* and every other model preserve the caller's temperature.
|
||||
"""
|
||||
|
||||
def test_build_call_kwargs_forces_fixed_temperature(self):
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
|
@ -772,12 +777,55 @@ class TestKimiForCodingTemperature:
|
|||
assert kwargs["model"] == "kimi-for-coding"
|
||||
assert kwargs["temperature"] == 0.6
|
||||
|
||||
def test_non_kimi_model_still_preserves_temperature(self):
|
||||
@pytest.mark.parametrize(
|
||||
"model,expected",
|
||||
[
|
||||
("kimi-k2.5", 0.6),
|
||||
("kimi-k2-turbo-preview", 0.6),
|
||||
("kimi-k2-0905-preview", 0.6),
|
||||
("kimi-k2-thinking", 1.0),
|
||||
("kimi-k2-thinking-turbo", 1.0),
|
||||
("moonshotai/kimi-k2.5", 0.6),
|
||||
("moonshotai/Kimi-K2-Thinking", 1.0),
|
||||
],
|
||||
)
|
||||
def test_kimi_k2_family_temperature_override(self, model, expected):
|
||||
"""Moonshot kimi-k2.* models only accept fixed temperatures.
|
||||
|
||||
Non-thinking models → 0.6, thinking-mode models → 1.0.
|
||||
"""
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
provider="kimi-coding",
|
||||
model="kimi-k2.5",
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
assert kwargs["temperature"] == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"anthropic/claude-sonnet-4-6",
|
||||
"gpt-5.4",
|
||||
# kimi-k2-instruct is the non-coding K2 family — temperature is
|
||||
# variable (recommended 0.6 but not enforced). Must not clamp.
|
||||
"kimi-k2-instruct",
|
||||
"moonshotai/Kimi-K2-Instruct",
|
||||
"moonshotai/Kimi-K2-Instruct-0905",
|
||||
"kimi-k2-instruct-0905",
|
||||
# Hypothetical future kimi name not in the whitelist.
|
||||
"kimi-k2-experimental",
|
||||
],
|
||||
)
|
||||
def test_non_restricted_model_preserves_temperature(self, model):
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
provider="openrouter",
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.3,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -781,3 +781,127 @@ class TestTokenBudgetTailProtection:
|
|||
# Tool at index 2 is outside the protected tail (last 3 = indices 2,3,4)
|
||||
# so it might or might not be pruned depending on boundary
|
||||
assert isinstance(pruned, int)
|
||||
|
||||
|
||||
class TestTruncateToolCallArgsJson:
    """Regression tests for #11762.

    Slicing ``function.arguments`` in the middle of a string used to yield
    invalid JSON.  Strict providers (observed on MiniMax) answered that with
    non-retryable 400s, which left long sessions stuck in a re-send loop.
    Every test here checks that the helper emits parseable JSON whose shape
    matches the original: shrunken, never corrupted.
    """

    def _helper(self):
        """Lazily import and return ``_truncate_tool_call_args_json``."""
        from agent.context_compressor import _truncate_tool_call_args_json

        return _truncate_tool_call_args_json

    def test_shrunken_args_remain_valid_json(self):
        """A long ``content`` value is shortened and the result still parses."""
        import json

        truncate = self._helper()
        raw_args = json.dumps({
            "path": "~/.hermes/skills/shopping/browser-setup-notes.md",
            "content": "# Shopping Browser Setup Notes\n\n" + "abc " * 400,
        })
        assert len(raw_args) > 500

        compacted = truncate(raw_args)

        # json.loads raises if the helper produced broken JSON.
        decoded = json.loads(compacted)
        assert decoded["path"] == "~/.hermes/skills/shopping/browser-setup-notes.md"
        assert decoded["content"].endswith("...[truncated]")
        assert len(compacted) < len(raw_args)

    def test_non_json_arguments_pass_through(self):
        """Input that is not JSON is returned unchanged."""
        truncate = self._helper()
        garbage = "this is not json at all, " * 50
        assert truncate(garbage) == garbage

    def test_short_string_leaves_unchanged(self):
        """Short string values round-trip to the same decoded object."""
        import json

        truncate = self._helper()
        encoded = json.dumps({"command": "ls -la", "cwd": "/tmp"})
        round_tripped = json.loads(truncate(encoded))
        assert round_tripped == {"command": "ls -la", "cwd": "/tmp"}

    def test_nested_structures_are_walked(self):
        """Long strings nested inside lists and dicts are truncated too."""
        import json

        truncate = self._helper()
        encoded = json.dumps({
            "messages": [
                {"role": "user", "content": "x" * 500},
                {"role": "assistant", "content": "ok"},
            ],
            "meta": {"note": "y" * 500},
        })

        decoded = json.loads(truncate(encoded))
        first_msg, second_msg = decoded["messages"][0], decoded["messages"][1]

        assert first_msg["content"].endswith("...[truncated]")
        assert second_msg["content"] == "ok"
        assert decoded["meta"]["note"].endswith("...[truncated]")

    def test_non_string_leaves_preserved(self):
        """Numbers, booleans, nulls and lists of numbers are kept verbatim."""
        import json

        truncate = self._helper()
        encoded = json.dumps({
            "retries": 3,
            "enabled": True,
            "timeout": None,
            "items": [1, 2, 3],
            "note": "z" * 500,
        })

        decoded = json.loads(truncate(encoded))

        assert decoded["retries"] == 3
        assert decoded["enabled"] is True
        assert decoded["timeout"] is None
        assert decoded["items"] == [1, 2, 3]
        assert decoded["note"].endswith("...[truncated]")

    def test_scalar_json_string_gets_shrunk(self):
        """A top-level JSON string (no container) is truncated as well."""
        import json

        truncate = self._helper()
        encoded = json.dumps("q" * 500)

        decoded = json.loads(truncate(encoded))

        assert isinstance(decoded, str)
        assert decoded.endswith("...[truncated]")

    def test_unicode_preserved(self):
        """CJK text appears literally in the output string."""
        import json

        truncate = self._helper()
        encoded = json.dumps({"content": "非德满" + ("a" * 500)})
        result = truncate(encoded)
        # ensure_ascii=False keeps CJK intact rather than emitting \uXXXX
        assert "非德满" in result

    def test_pass3_emits_valid_json_for_downstream_provider(self):
        """End-to-end: Pass 3 must never produce the exact failure payload
        that caused the 400 loop (unterminated string, missing brace)."""
        import json

        # NOTE(review): the flattened source does not show where the ``with``
        # body ends; the patch is kept active for the whole test, which is the
        # conservative reading.
        with patch("agent.context_compressor.get_model_context_length", return_value=100000):
            compressor = ContextCompressor(
                model="test/model",
                threshold_percent=0.85,
                protect_first_n=1,
                protect_last_n=1,
                quiet_mode=True,
            )
            big_body = "# Shopping Browser Setup Notes\n\n## Overview\n" + "x " * 400
            tool_args = json.dumps({
                "path": "~/.hermes/skills/shopping/browser-setup-notes.md",
                "content": big_body,
            })
            assert len(tool_args) > 500  # triggers the Pass-3 shrink

            tool_call = {
                "id": "call_1",
                "type": "function",
                "function": {"name": "write_file", "arguments": tool_args},
            }
            conversation = [
                {"role": "user", "content": "please write two files"},
                {"role": "assistant", "content": None, "tool_calls": [tool_call]},
                {"role": "tool", "tool_call_id": "call_1",
                 "content": '{"bytes_written": 727}'},
                {"role": "user", "content": "ok"},
                {"role": "assistant", "content": "done"},
            ]

            pruned_messages, _ = compressor._prune_old_tool_results(
                conversation, protect_tail_count=2
            )
            compacted = pruned_messages[1]["tool_calls"][0]["function"]["arguments"]

            # Must parse — otherwise downstream provider returns 400
            decoded = json.loads(compacted)
            assert decoded["path"] == "~/.hermes/skills/shopping/browser-setup-notes.md"
            assert decoded["content"].endswith("...[truncated]")
|
||||
|
|
|
|||
|
|
@ -971,8 +971,6 @@ class TestHonchoCadenceTracking:
|
|||
class FakeManager:
|
||||
def prefetch_context(self, key, query=None):
|
||||
pass
|
||||
def prefetch_dialectic(self, key, query):
|
||||
pass
|
||||
|
||||
p._manager = FakeManager()
|
||||
|
||||
|
|
|
|||
|
|
@ -208,34 +208,81 @@ class TestMem0UserIdScoping:
|
|||
|
||||
|
||||
class TestHonchoUserIdScoping:
|
||||
"""Verify Honcho plugin uses gateway user_id for peer_name when provided."""
|
||||
"""Verify Honcho plugin keeps runtime user scoping separate from config peer_name."""
|
||||
|
||||
def test_gateway_user_id_overrides_peer_name(self):
|
||||
"""When user_id is in kwargs and no explicit peer_name, user_id should be used."""
|
||||
def test_gateway_user_id_is_passed_as_runtime_peer(self):
|
||||
"""Gateway user_id should scope Honcho sessions without mutating config peer_name."""
|
||||
from plugins.memory.honcho import HonchoMemoryProvider
|
||||
|
||||
provider = HonchoMemoryProvider()
|
||||
|
||||
# Create a mock config with NO explicit peer_name
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.enabled = True
|
||||
mock_cfg.api_key = "test-key"
|
||||
mock_cfg.base_url = None
|
||||
mock_cfg.peer_name = "" # No explicit peer_name — user_id should fill it
|
||||
mock_cfg.recall_mode = "tools" # Use tools mode to defer session init
|
||||
mock_cfg.peer_name = "static-user"
|
||||
mock_cfg.recall_mode = "context"
|
||||
mock_cfg.context_tokens = None
|
||||
mock_cfg.raw = {}
|
||||
mock_cfg.dialectic_depth = 1
|
||||
mock_cfg.dialectic_depth_levels = None
|
||||
mock_cfg.init_on_session_start = False
|
||||
mock_cfg.ai_peer = "hermes"
|
||||
mock_cfg.resolve_session_name.return_value = "test-sess"
|
||||
mock_cfg.session_strategy = "shared"
|
||||
|
||||
with patch(
|
||||
"plugins.memory.honcho.client.HonchoClientConfig.from_global_config",
|
||||
return_value=mock_cfg,
|
||||
):
|
||||
), patch(
|
||||
"plugins.memory.honcho.client.get_honcho_client",
|
||||
return_value=MagicMock(),
|
||||
), patch(
|
||||
"plugins.memory.honcho.session.HonchoSessionManager",
|
||||
) as mock_manager_cls:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.get_or_create.return_value = MagicMock(messages=[])
|
||||
mock_manager_cls.return_value = mock_manager
|
||||
provider.initialize(
|
||||
session_id="test-sess",
|
||||
user_id="discord_user_789",
|
||||
platform="discord",
|
||||
)
|
||||
|
||||
# The config's peer_name should have been overridden with the user_id
|
||||
assert mock_cfg.peer_name == "discord_user_789"
|
||||
assert mock_cfg.peer_name == "static-user"
|
||||
assert mock_manager_cls.call_args.kwargs["runtime_user_peer_name"] == "discord_user_789"
|
||||
|
||||
def test_session_manager_prefers_runtime_user_id_over_config_peer_name(self):
    """Session manager should isolate gateway users even when config peer_name is static."""
    from plugins.memory.honcho.session import HonchoSessionManager

    # Keyword arguments to MagicMock configure plain attributes on the mock.
    static_config = MagicMock(
        peer_name="static-user",
        ai_peer="hermes",
        write_frequency="sync",
        dialectic_reasoning_level="low",
        dialectic_dynamic=True,
        dialectic_max_chars=600,
        observation_mode="directional",
        user_observe_me=True,
        user_observe_others=True,
        ai_observe_me=True,
        ai_observe_others=True,
    )

    manager = HonchoSessionManager(
        honcho=MagicMock(),
        config=static_config,
        runtime_user_peer_name="discord_user_789",
    )

    # Stub both internal lookups with canned mock return values.
    peer_stub = patch.object(manager, "_get_or_create_peer", return_value=MagicMock())
    session_stub = patch.object(
        manager,
        "_get_or_create_honcho_session",
        return_value=(MagicMock(), []),
    )
    with peer_stub, session_stub:
        session = manager.get_or_create("discord:channel-1")

    assert session.user_peer_id == "discord_user_789"
|
||||
|
||||
def test_no_user_id_preserves_config_peer_name(self):
|
||||
"""Without user_id, the config peer_name should be preserved."""
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ class TestBuildChildProgressCallback:
|
|||
parent._delegate_spinner = None
|
||||
parent.tool_progress_callback = None
|
||||
|
||||
cb = _build_child_progress_callback(0, parent)
|
||||
cb = _build_child_progress_callback(0, "test goal", parent)
|
||||
assert cb is None
|
||||
|
||||
def test_cli_spinner_tool_event(self):
|
||||
|
|
@ -93,7 +93,7 @@ class TestBuildChildProgressCallback:
|
|||
parent._delegate_spinner = spinner
|
||||
parent.tool_progress_callback = None
|
||||
|
||||
cb = _build_child_progress_callback(0, parent)
|
||||
cb = _build_child_progress_callback(0, "test goal", parent)
|
||||
assert cb is not None
|
||||
|
||||
cb("tool.started", "web_search", "quantum computing", {})
|
||||
|
|
@ -113,7 +113,7 @@ class TestBuildChildProgressCallback:
|
|||
parent._delegate_spinner = spinner
|
||||
parent.tool_progress_callback = None
|
||||
|
||||
cb = _build_child_progress_callback(0, parent)
|
||||
cb = _build_child_progress_callback(0, "test goal", parent)
|
||||
cb("_thinking", "I'll search for papers first")
|
||||
|
||||
output = buf.getvalue()
|
||||
|
|
@ -121,54 +121,64 @@ class TestBuildChildProgressCallback:
|
|||
assert "search for papers" in output
|
||||
|
||||
def test_gateway_batched_progress(self):
|
||||
"""Gateway path should batch tool calls and flush at BATCH_SIZE."""
|
||||
"""Gateway path: each tool.started relays a subagent.tool event, and a
|
||||
subagent.progress summary fires once BATCH_SIZE tools accumulate."""
|
||||
parent = MagicMock()
|
||||
parent._delegate_spinner = None
|
||||
parent_cb = MagicMock()
|
||||
parent.tool_progress_callback = parent_cb
|
||||
|
||||
cb = _build_child_progress_callback(0, parent)
|
||||
|
||||
# Send 4 tool calls — shouldn't flush yet (BATCH_SIZE = 5)
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent)
|
||||
|
||||
# Each tool.started relays a subagent.tool event immediately (per-tool relay).
|
||||
for i in range(4):
|
||||
cb("tool.started", f"tool_{i}", f"arg_{i}", {})
|
||||
parent_cb.assert_not_called()
|
||||
|
||||
# 5th call should trigger flush
|
||||
cb("tool.started", "tool_4", "arg_4", {})
|
||||
parent_cb.assert_called_once()
|
||||
call_args = parent_cb.call_args
|
||||
assert "tool_0" in call_args[0][1]
|
||||
assert "tool_4" in call_args[0][1]
|
||||
# 4 per-tool relays so far, no batch summary yet (BATCH_SIZE=5)
|
||||
events = [c.args[0] for c in parent_cb.call_args_list]
|
||||
assert events == ["subagent.tool"] * 4
|
||||
|
||||
def test_thinking_not_relayed_to_gateway(self):
|
||||
"""Thinking events should NOT be sent to gateway (too noisy)."""
|
||||
# 5th call triggers another per-tool relay PLUS the batch-size summary
|
||||
cb("tool.started", "tool_4", "arg_4", {})
|
||||
events = [c.args[0] for c in parent_cb.call_args_list]
|
||||
assert events == ["subagent.tool"] * 5 + ["subagent.progress"]
|
||||
summary_call = parent_cb.call_args_list[-1]
|
||||
summary_text = summary_call.kwargs.get("preview") or summary_call.args[2]
|
||||
assert "tool_0" in summary_text
|
||||
assert "tool_4" in summary_text
|
||||
|
||||
def test_thinking_relayed_to_gateway(self):
|
||||
"""Thinking events are relayed as subagent.thinking events."""
|
||||
parent = MagicMock()
|
||||
parent._delegate_spinner = None
|
||||
parent_cb = MagicMock()
|
||||
parent.tool_progress_callback = parent_cb
|
||||
|
||||
cb = _build_child_progress_callback(0, parent)
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent)
|
||||
cb("_thinking", "some reasoning text")
|
||||
|
||||
parent_cb.assert_not_called()
|
||||
|
||||
parent_cb.assert_called_once()
|
||||
assert parent_cb.call_args.args[0] == "subagent.thinking"
|
||||
assert parent_cb.call_args.args[2] == "some reasoning text"
|
||||
|
||||
def test_parallel_callbacks_independent(self):
|
||||
"""Each child's callback should have independent batch state."""
|
||||
"""Each child's callback batches tool names independently."""
|
||||
parent = MagicMock()
|
||||
parent._delegate_spinner = None
|
||||
parent_cb = MagicMock()
|
||||
parent.tool_progress_callback = parent_cb
|
||||
|
||||
cb0 = _build_child_progress_callback(0, parent)
|
||||
cb1 = _build_child_progress_callback(1, parent)
|
||||
|
||||
# Send 3 calls to each — neither should flush (batch size = 5)
|
||||
|
||||
cb0 = _build_child_progress_callback(0, "goal a", parent)
|
||||
cb1 = _build_child_progress_callback(1, "goal b", parent)
|
||||
|
||||
# 3 tool.started per child = 6 per-tool relays; neither should hit
|
||||
# the batch-size summary (batch size = 5, counted per-child).
|
||||
for i in range(3):
|
||||
cb0(f"tool_{i}")
|
||||
cb1(f"other_{i}")
|
||||
|
||||
parent_cb.assert_not_called()
|
||||
cb0("tool.started", f"tool_{i}", f"a_{i}", {})
|
||||
cb1("tool.started", f"other_{i}", f"b_{i}", {})
|
||||
|
||||
events = [c.args[0] for c in parent_cb.call_args_list]
|
||||
assert events.count("subagent.tool") == 6
|
||||
assert "subagent.progress" not in events
|
||||
|
||||
def test_task_index_prefix_in_batch_mode(self):
|
||||
"""Batch mode (task_count > 1) should show 1-indexed prefix for all tasks."""
|
||||
|
|
@ -182,7 +192,7 @@ class TestBuildChildProgressCallback:
|
|||
parent.tool_progress_callback = None
|
||||
|
||||
# task_index=0 in a batch of 3 → prefix "[1]"
|
||||
cb0 = _build_child_progress_callback(0, parent, task_count=3)
|
||||
cb0 = _build_child_progress_callback(0, "test goal", parent, task_count=3)
|
||||
cb0("web_search", "test")
|
||||
output = buf.getvalue()
|
||||
assert "[1]" in output
|
||||
|
|
@ -190,7 +200,7 @@ class TestBuildChildProgressCallback:
|
|||
# task_index=2 in a batch of 3 → prefix "[3]"
|
||||
buf.truncate(0)
|
||||
buf.seek(0)
|
||||
cb2 = _build_child_progress_callback(2, parent, task_count=3)
|
||||
cb2 = _build_child_progress_callback(2, "test goal", parent, task_count=3)
|
||||
cb2("web_search", "test")
|
||||
output = buf.getvalue()
|
||||
assert "[3]" in output
|
||||
|
|
@ -206,7 +216,7 @@ class TestBuildChildProgressCallback:
|
|||
parent._delegate_spinner = spinner
|
||||
parent.tool_progress_callback = None
|
||||
|
||||
cb = _build_child_progress_callback(0, parent, task_count=1)
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
cb("tool.started", "web_search", "test", {})
|
||||
|
||||
output = buf.getvalue()
|
||||
|
|
@ -321,26 +331,31 @@ class TestBatchFlush:
|
|||
"""Tests for gateway batch flush on subagent completion."""
|
||||
|
||||
def test_flush_sends_remaining_batch(self):
|
||||
"""_flush should send remaining tool names to gateway."""
|
||||
"""_flush should send a final subagent.progress summary of any unsent
|
||||
tool names in the batch (less than BATCH_SIZE)."""
|
||||
parent = MagicMock()
|
||||
parent._delegate_spinner = None
|
||||
parent_cb = MagicMock()
|
||||
parent.tool_progress_callback = parent_cb
|
||||
|
||||
cb = _build_child_progress_callback(0, parent)
|
||||
cb = _build_child_progress_callback(0, "test goal", parent)
|
||||
|
||||
# Send 3 tools (below batch size of 5)
|
||||
# Send 3 tools (below batch size of 5) — each relays subagent.tool
|
||||
cb("tool.started", "web_search", "query1", {})
|
||||
cb("tool.started", "read_file", "file.txt", {})
|
||||
cb("tool.started", "write_file", "out.txt", {})
|
||||
parent_cb.assert_not_called()
|
||||
events = [c.args[0] for c in parent_cb.call_args_list]
|
||||
assert events == ["subagent.tool"] * 3 # per-tool relays so far
|
||||
assert "subagent.progress" not in events # no batch-size summary yet
|
||||
|
||||
# Flush should send the remaining 3
|
||||
# Flush should send the remaining 3 as a summary
|
||||
cb._flush()
|
||||
parent_cb.assert_called_once()
|
||||
summary = parent_cb.call_args[0][1]
|
||||
assert "web_search" in summary
|
||||
assert "write_file" in summary
|
||||
events = [c.args[0] for c in parent_cb.call_args_list]
|
||||
assert events[-1] == "subagent.progress"
|
||||
summary_call = parent_cb.call_args_list[-1]
|
||||
summary_text = summary_call.kwargs.get("preview") or summary_call.args[2]
|
||||
assert "web_search" in summary_text
|
||||
assert "write_file" in summary_text
|
||||
|
||||
def test_flush_noop_when_batch_empty(self):
|
||||
"""_flush should not send anything when batch is empty."""
|
||||
|
|
@ -349,7 +364,7 @@ class TestBatchFlush:
|
|||
parent_cb = MagicMock()
|
||||
parent.tool_progress_callback = parent_cb
|
||||
|
||||
cb = _build_child_progress_callback(0, parent)
|
||||
cb = _build_child_progress_callback(0, "test goal", parent)
|
||||
cb._flush()
|
||||
parent_cb.assert_not_called()
|
||||
|
||||
|
|
@ -364,7 +379,7 @@ class TestBatchFlush:
|
|||
parent._delegate_spinner = spinner
|
||||
parent.tool_progress_callback = None
|
||||
|
||||
cb = _build_child_progress_callback(0, parent)
|
||||
cb = _build_child_progress_callback(0, "test goal", parent)
|
||||
cb("tool.started", "web_search", "test", {})
|
||||
cb._flush() # Should not crash
|
||||
|
||||
|
|
|
|||
|
|
@ -237,6 +237,13 @@ class TestCLIStatusBar:
|
|||
cli_obj._spinner_text = ""
|
||||
assert cli_obj._spinner_widget_height(width=90) == 0
|
||||
|
||||
def test_spinner_height_uses_display_width_for_wide_characters(self):
    """Forty wide ("你") characters at width 64 must report a spinner height of 2."""
    cli = _make_cli()
    wide_text = "你" * 40
    cli._spinner_text = wide_text
    cli._tool_start_time = 0

    height = cli._spinner_widget_height(width=64)

    assert height == 2
|
||||
|
||||
def test_voice_status_bar_compacts_on_narrow_terminals(self):
|
||||
cli_obj = _make_cli()
|
||||
cli_obj._voice_mode = True
|
||||
|
|
|
|||
21
tests/cli/test_gquota_command.py
Normal file
21
tests/cli/test_gquota_command.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def test_gquota_uses_chat_console_when_tui_is_live():
    """With ``_app`` set, ``/gquota`` prints through ``ChatConsole``.

    The token lookup is forced to raise ``GoogleOAuthError``; the test then
    expects exactly two prints on the patched ``ChatConsole`` and none on
    ``cli.console``.
    """
    from agent.google_oauth import GoogleOAuthError
    from cli import HermesCLI

    # Bypass __init__ and set only the attributes this test assigns.
    cli = HermesCLI.__new__(HermesCLI)
    cli.console = MagicMock()
    cli._app = object()

    live_console = MagicMock()

    console_patch = patch("cli.ChatConsole", return_value=live_console)
    token_patch = patch(
        "agent.google_oauth.get_valid_access_token",
        side_effect=GoogleOAuthError("No Google OAuth credentials found"),
    )
    credentials_patch = patch("agent.google_oauth.load_credentials", return_value=None)
    quota_patch = patch("agent.google_code_assist.retrieve_user_quota")

    with console_patch, token_patch, credentials_patch, quota_patch:
        cli._handle_gquota_command("/gquota")

    assert live_console.print.call_count == 2
    cli.console.print.assert_not_called()
|
||||
|
|
@ -33,6 +33,20 @@ class TestCLIQuickCommands:
|
|||
printed = self._printed_plain(cli.console.print.call_args[0][0])
|
||||
assert printed == "daily-note"
|
||||
|
||||
def test_exec_command_uses_chat_console_when_tui_is_live(self):
    """With ``_app`` set, exec quick-command output goes to ``ChatConsole``."""
    quick_commands = {"dn": {"type": "exec", "command": "echo daily-note"}}
    cli = self._make_cli(quick_commands)
    cli._app = object()
    live_console = MagicMock()

    with patch("cli.ChatConsole", return_value=live_console):
        handled = cli.process_command("/dn")

    assert handled is True
    live_console.print.assert_called_once()
    first_positional = live_console.print.call_args[0][0]
    assert self._printed_plain(first_positional) == "daily-note"
    cli.console.print.assert_not_called()
|
||||
|
||||
def test_exec_command_stderr_shown_on_no_stdout(self):
|
||||
cli = self._make_cli({"err": {"type": "exec", "command": "echo error >&2"}})
|
||||
result = cli.process_command("/err")
|
||||
|
|
|
|||
|
|
@ -473,6 +473,7 @@ class TestInlineThinkBlockExtraction(unittest.TestCase):
|
|||
agent.verbose_logging = False
|
||||
agent.reasoning_callback = None
|
||||
agent.stream_delta_callback = None # non-streaming by default
|
||||
agent._stream_callback = None # non-streaming by default
|
||||
return agent
|
||||
|
||||
def test_single_think_block_extracted(self):
|
||||
|
|
@ -619,6 +620,7 @@ class TestReasoningDeltasFiredFlag(unittest.TestCase):
|
|||
agent = AIAgent.__new__(AIAgent)
|
||||
agent.reasoning_callback = None
|
||||
agent.stream_delta_callback = None
|
||||
agent._stream_callback = None
|
||||
agent.verbose_logging = False
|
||||
return agent
|
||||
|
||||
|
|
|
|||
|
|
@ -344,6 +344,127 @@ class TestDisplayResumedHistory:
|
|||
assert "Just thinking" not in output
|
||||
assert "Hi there!" in output
|
||||
|
||||
def test_think_tags_stripped(self):
    """<think>...</think> blocks should be stripped from display (#11316)."""
    cli = _make_cli()
    assistant_text = "<think>\nI need to reason carefully here.\n</think>\n\nThe answer is 7."
    cli.conversation_history = [
        {"role": "user", "content": "Solve this"},
        {"role": "assistant", "content": assistant_text},
    ]

    rendered = self._capture_display(cli)

    assert "<think>" not in rendered
    assert "</think>" not in rendered
    assert "I need to reason carefully here" not in rendered
    assert "The answer is 7" in rendered
|
||||
|
||||
def test_thinking_tags_stripped(self):
    """<thinking>...</thinking> blocks should be stripped from display."""
    cli = _make_cli()
    assistant_text = "<thinking>\nLet me compute: 2 + 2 = 4\n</thinking>\n\nThe answer is 4."
    cli.conversation_history = [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": assistant_text},
    ]

    rendered = self._capture_display(cli)

    assert "<thinking>" not in rendered
    assert "Let me compute" not in rendered
    assert "The answer is 4" in rendered
|
||||
|
||||
def test_reasoning_tags_stripped(self):
    """<reasoning>...</reasoning> blocks should be stripped from display."""
    cli = _make_cli()
    assistant_text = (
        "<reasoning>\nGravity is a fundamental force...\n</reasoning>\n\n"
        "Gravity pulls objects together."
    )
    cli.conversation_history = [
        {"role": "user", "content": "Explain gravity"},
        {"role": "assistant", "content": assistant_text},
    ]

    rendered = self._capture_display(cli)

    assert "<reasoning>" not in rendered
    assert "fundamental force" not in rendered
    assert "Gravity pulls objects together" in rendered
|
||||
|
||||
def test_thought_tags_stripped(self):
    """<thought>...</thought> blocks (Gemma 4) should be stripped."""
    cli = _make_cli()
    assistant_text = "<thought>\nInternal thought here.\n</thought>\n\nHello!"
    cli.conversation_history = [
        {"role": "user", "content": "Say hello"},
        {"role": "assistant", "content": assistant_text},
    ]

    rendered = self._capture_display(cli)

    assert "<thought>" not in rendered
    assert "Internal thought here" not in rendered
    assert "Hello!" in rendered
|
||||
|
||||
def test_unclosed_think_tag_stripped(self):
    """Unclosed <think> (truncated generation) should not leak reasoning."""
    cli = _make_cli()
    assistant_text = "Some text before.\n<think>\nUnfinished reasoning..."
    cli.conversation_history = [
        {"role": "user", "content": "Truncated response"},
        {"role": "assistant", "content": assistant_text},
    ]

    rendered = self._capture_display(cli)

    assert "<think>" not in rendered
    assert "Unfinished reasoning" not in rendered
    assert "Some text before" in rendered
|
||||
|
||||
def test_multiple_reasoning_blocks_all_stripped(self):
    """Multiple interleaved reasoning blocks are all stripped."""
    cli = _make_cli()
    assistant_text = (
        "<think>\nFirst thought.\n</think>\n"
        "Partial text.\n"
        "<reasoning>\nSecond thought.\n</reasoning>\n"
        "Final answer."
    )
    cli.conversation_history = [
        {"role": "user", "content": "Complex question"},
        {"role": "assistant", "content": assistant_text},
    ]

    rendered = self._capture_display(cli)

    # Reasoning bodies must be hidden; the visible text must remain.
    assert "First thought" not in rendered
    assert "Second thought" not in rendered
    assert "Partial text" in rendered
    assert "Final answer" in rendered
|
||||
|
||||
def test_orphan_closing_think_tag_stripped(self):
    """A stray </think> with no matching open should not render to user."""
    cli = _make_cli()
    assistant_text = "some leftover reasoning</think>Visible answer."
    cli.conversation_history = [
        {"role": "user", "content": "Broken output"},
        {"role": "assistant", "content": assistant_text},
    ]

    rendered = self._capture_display(cli)

    assert "</think>" not in rendered
    assert "Visible answer" in rendered
|
||||
|
||||
def test_assistant_with_text_and_tool_calls(self):
|
||||
"""When an assistant message has both text content AND tool_calls."""
|
||||
cli = _make_cli()
|
||||
|
|
|
|||
|
|
@ -1024,7 +1024,7 @@ class TestRunJobSkillBacked:
|
|||
"id": "multi-skill-job",
|
||||
"name": "multi skill test",
|
||||
"prompt": "Combine the results.",
|
||||
"skills": ["blogwatcher", "find-nearby"],
|
||||
"skills": ["blogwatcher", "maps"],
|
||||
}
|
||||
|
||||
fake_db = MagicMock()
|
||||
|
|
@ -1057,12 +1057,12 @@ class TestRunJobSkillBacked:
|
|||
assert error is None
|
||||
assert final_response == "ok"
|
||||
assert skill_view_mock.call_count == 2
|
||||
assert [call.args[0] for call in skill_view_mock.call_args_list] == ["blogwatcher", "find-nearby"]
|
||||
assert [call.args[0] for call in skill_view_mock.call_args_list] == ["blogwatcher", "maps"]
|
||||
|
||||
prompt_arg = mock_agent.run_conversation.call_args.args[0]
|
||||
assert prompt_arg.index("blogwatcher") < prompt_arg.index("find-nearby")
|
||||
assert prompt_arg.index("blogwatcher") < prompt_arg.index("maps")
|
||||
assert "Instructions for blogwatcher." in prompt_arg
|
||||
assert "Instructions for find-nearby." in prompt_arg
|
||||
assert "Instructions for maps." in prompt_arg
|
||||
assert "Combine the results." in prompt_arg
|
||||
|
||||
|
||||
|
|
@ -1175,6 +1175,180 @@ class TestBuildJobPromptSilentHint:
|
|||
assert system_pos < prompt_pos
|
||||
|
||||
|
||||
class TestParseWakeGate:
    """Unit tests for _parse_wake_gate — pure function, no side effects."""

    def test_empty_output_wakes(self):
        """Empty string and ``None`` both wake the agent."""
        from cron.scheduler import _parse_wake_gate

        for script_output in ("", None):
            assert _parse_wake_gate(script_output) is True

    def test_whitespace_only_wakes(self):
        """Output made only of whitespace wakes the agent."""
        from cron.scheduler import _parse_wake_gate

        blank_output = " \n\n \t\n"
        assert _parse_wake_gate(blank_output) is True

    def test_non_json_last_line_wakes(self):
        """A plain-text final line wakes the agent."""
        from cron.scheduler import _parse_wake_gate

        for script_output in ("hello world", "line 1\nline 2\nplain text"):
            assert _parse_wake_gate(script_output) is True

    def test_json_non_dict_wakes(self):
        """Bare arrays, numbers, strings must not be interpreted as a gate."""
        from cron.scheduler import _parse_wake_gate

        for script_output in ("[1, 2, 3]", "42", '"wakeAgent"'):
            assert _parse_wake_gate(script_output) is True

    def test_wake_gate_false_skips(self):
        """``{"wakeAgent": false}`` yields ``False``."""
        from cron.scheduler import _parse_wake_gate

        gate_line = '{"wakeAgent": false}'
        assert _parse_wake_gate(gate_line) is False

    def test_wake_gate_true_wakes(self):
        """``{"wakeAgent": true}`` yields ``True``."""
        from cron.scheduler import _parse_wake_gate

        gate_line = '{"wakeAgent": true}'
        assert _parse_wake_gate(gate_line) is True

    def test_wake_gate_missing_wakes(self):
        """A JSON dict without a wakeAgent key defaults to waking."""
        from cron.scheduler import _parse_wake_gate

        unrelated_dict = '{"data": {"foo": "bar"}}'
        assert _parse_wake_gate(unrelated_dict) is True

    def test_non_boolean_false_still_wakes(self):
        """Only strict ``False`` skips — truthy/falsy shortcuts are too risky."""
        from cron.scheduler import _parse_wake_gate

        falsy_but_not_false = (
            '{"wakeAgent": 0}',
            '{"wakeAgent": null}',
            '{"wakeAgent": ""}',
        )
        for gate_line in falsy_but_not_false:
            assert _parse_wake_gate(gate_line) is True

    def test_only_last_non_empty_line_parsed(self):
        """Log lines before a final gate line do not prevent the skip."""
        from cron.scheduler import _parse_wake_gate

        script_output = 'some log output\nmore output\n{"wakeAgent": false}'
        assert _parse_wake_gate(script_output) is False

    def test_trailing_blank_lines_ignored(self):
        """Blank lines after the gate line are ignored."""
        from cron.scheduler import _parse_wake_gate

        script_output = '{"wakeAgent": false}\n\n\n'
        assert _parse_wake_gate(script_output) is False

    def test_non_last_json_line_does_not_gate(self):
        """A JSON gate on an earlier line with plain text after it does NOT trigger."""
        from cron.scheduler import _parse_wake_gate

        script_output = '{"wakeAgent": false}\nactually this is the real output'
        assert _parse_wake_gate(script_output) is True
|
||||
|
||||
|
||||
class TestRunJobWakeGate:
    """Tests of ``scheduler.run_job`` around the script wake gate.

    ``_run_job_script`` and ``run_agent.AIAgent`` are patched in every test,
    so neither a real script nor a real agent is executed.
    """

    @staticmethod
    def _stub_agent():
        """Return a MagicMock agent whose ``run_conversation`` gives a canned reply."""
        stub = MagicMock()
        stub.run_conversation = MagicMock(
            return_value={"final_response": "ok", "messages": []}
        )
        return stub

    def _make_job(self, name="wake-gate-test", script="check.py"):
        """Build the small job dict these tests pass to ``run_job``."""
        job = {}
        job["id"] = f"job_{name}"
        job["name"] = name
        job["prompt"] = "Do a thing"
        job["schedule"] = "*/5 * * * *"
        job["script"] = script
        return job

    def test_wake_false_skips_agent_and_returns_silent(self, caplog):
        """A successful script ending in ``{"wakeAgent": false}`` must not
        construct the agent, and ``run_job`` must return the SILENT marker
        as its final response."""
        from cron.scheduler import SILENT_MARKER
        import cron.scheduler as scheduler

        script_patch = patch.object(
            scheduler,
            "_run_job_script",
            return_value=(True, '{"wakeAgent": false}'),
        )
        agent_patch = patch("run_agent.AIAgent")
        with script_patch, agent_patch as agent_cls:
            success, doc, final, err = scheduler.run_job(self._make_job())

        assert success is True
        assert err is None
        assert final == SILENT_MARKER
        assert "Script gate returned `wakeAgent=false`" in doc
        agent_cls.assert_not_called()

    def test_wake_true_runs_agent_with_injected_output(self):
        """With ``wakeAgent: true`` the agent is constructed once and the raw
        script output appears in the prompt given to ``run_conversation``."""
        import cron.scheduler as scheduler

        script_output = '{"wakeAgent": true, "data": {"new": 3}}'
        agent = self._stub_agent()
        script_patch = patch.object(
            scheduler, "_run_job_script", return_value=(True, script_output)
        )
        agent_patch = patch("run_agent.AIAgent", return_value=agent)
        with script_patch, agent_patch as agent_cls:
            success, doc, final, err = scheduler.run_job(self._make_job())

        agent_cls.assert_called_once()
        # The prompt may have been passed positionally or as ``user_message``.
        recorded_call = agent.run_conversation.call_args
        if recorded_call.args:
            prompt_arg = recorded_call.args[0]
        else:
            prompt_arg = recorded_call.kwargs.get("user_message", "")
        assert script_output in prompt_arg
        assert success is True
        assert err is None

    def test_script_runs_only_once_on_wake(self):
        """On the wake path the script stub must be invoked exactly once
        (a second invocation would duplicate work and side effects)."""
        import cron.scheduler as scheduler

        seen_paths = []

        def _script_stub(path):
            seen_paths.append(path)
            return (True, "regular output")

        agent = self._stub_agent()
        script_patch = patch.object(
            scheduler, "_run_job_script", side_effect=_script_stub
        )
        agent_patch = patch("run_agent.AIAgent", return_value=agent)
        with script_patch, agent_patch:
            scheduler.run_job(self._make_job())

        call_count = len(seen_paths)
        assert call_count == 1, f"script ran {call_count}x, expected exactly 1"

    def test_script_failure_does_not_trigger_gate(self):
        """When the script stub reports failure, gate-looking output is not
        honoured and the agent is still constructed."""
        import cron.scheduler as scheduler

        agent = self._stub_agent()
        # The failed run's output happens to look like a gate; because the
        # first tuple element is False it must be ignored.
        script_patch = patch.object(
            scheduler,
            "_run_job_script",
            return_value=(False, '{"wakeAgent": false}'),
        )
        agent_patch = patch("run_agent.AIAgent", return_value=agent)
        with script_patch, agent_patch as agent_cls:
            success, doc, final, err = scheduler.run_job(self._make_job())

        # The agent woke despite the gate-like text.
        agent_cls.assert_called_once()

    def test_no_script_path_runs_agent_normally(self):
        """Regression: a job dict with no ``script`` key never calls the
        script runner and still constructs the agent."""
        import cron.scheduler as scheduler

        agent = self._stub_agent()
        job = self._make_job(script=None)
        job.pop("script", None)
        script_patch = patch.object(scheduler, "_run_job_script")
        agent_patch = patch("run_agent.AIAgent", return_value=agent)
        with script_patch as script_fn, agent_patch as agent_cls:
            scheduler.run_job(job)

        script_fn.assert_not_called()
        agent_cls.assert_called_once()
|
||||
|
||||
|
||||
class TestBuildJobPromptMissingSkill:
|
||||
"""Verify that a missing skill logs a warning and does not crash the job."""
|
||||
|
||||
|
|
|
|||
148
tests/gateway/test_cancel_background_drain.py
Normal file
148
tests/gateway/test_cancel_background_drain.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
"""Regression test: cancel_background_tasks must drain late-arrival tasks.
|
||||
|
||||
During gateway shutdown, a message arriving while
|
||||
cancel_background_tasks is mid-await can spawn a fresh
|
||||
_process_message_background task via handle_message, which is added
|
||||
to self._background_tasks. Without the re-drain loop, the subsequent
|
||||
_background_tasks.clear() drops the reference; the task runs
|
||||
untracked against a disconnecting adapter.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class _StubAdapter(BasePlatformAdapter):
    """Concrete adapter for tests: each overridden coroutine does nothing."""

    async def connect(self):
        """No-op connect."""
        return None

    async def disconnect(self):
        """No-op disconnect."""
        return None

    async def send(self, chat_id, text, **kwargs):
        """Accept any message and return ``None``."""
        return None

    async def get_chat_info(self, chat_id):
        """Return an empty info dict for every chat."""
        return dict()
|
||||
|
||||
|
||||
def _make_adapter():
    """Return a Telegram ``_StubAdapter`` with ``_send_with_retry`` mocked out."""
    config = PlatformConfig(enabled=True, token="t")
    stub = _StubAdapter(config, Platform.TELEGRAM)
    stub._send_with_retry = AsyncMock(return_value=None)
    return stub
|
||||
|
||||
|
||||
def _event(text, cid="42"):
    """Build a Telegram DM text ``MessageEvent`` for chat ``cid``."""
    source = SessionSource(platform=Platform.TELEGRAM, chat_id=cid, chat_type="dm")
    return MessageEvent(text=text, message_type=MessageType.TEXT, source=source)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cancel_background_tasks_drains_late_arrivals():
    """A task spawned while shutdown is still awaiting must be cancelled too.

    Sequence: M1 starts; shutdown begins and cancels M1; M1 lingers in a
    shielded cleanup; M2 is injected in that window. After shutdown returns,
    M2 must have observed cancellation and no tasks may remain tracked.
    """
    adapter = _make_adapter()
    session_key = build_session_key(
        SessionSource(platform=Platform.TELEGRAM, chat_id="42", chat_type="dm")
    )

    m1_started, m1_cleanup_running, m2_started, m2_cancelled = (
        asyncio.Event() for _ in range(4)
    )

    async def handler(event):
        # "M1" is the first message; anything else is the late arrival M2.
        is_first = event.text == "M1"
        (m1_started if is_first else m2_started).set()
        try:
            await asyncio.sleep(10)
        except asyncio.CancelledError:
            if is_first:
                m1_cleanup_running.set()
                # Shielded delay keeps shutdown waiting long enough for M2
                # to be injected.
                await asyncio.shield(asyncio.sleep(0.2))
            else:
                m2_cancelled.set()
            raise

    adapter._message_handler = handler

    # Start M1 and wait for its handler to be running.
    await adapter.handle_message(_event("M1"))
    await asyncio.wait_for(m1_started.wait(), timeout=1.0)

    # Begin shutdown in the background; it cancels M1 and awaits it.
    cancel_task = asyncio.create_task(adapter.cancel_background_tasks())

    # Race window: M1 is parked in its shielded cleanup while cancel_task
    # is still waiting on it.
    await asyncio.wait_for(m1_cleanup_running.wait(), timeout=1.0)

    # Remove the active-session entry by hand so the next message takes the
    # spawn path deterministically.
    # NOTE(review): presumably this mirrors what production looks like once
    # M1's own cleanup has released the session — confirm against base.py.
    adapter._active_sessions.pop(session_key, None)

    # Late arrival: expected to create a new background task while
    # cancel_task has not finished yet.
    await adapter.handle_message(_event("M2"))
    await asyncio.wait_for(m2_started.wait(), timeout=1.0)

    # Shutdown must finish on its own and must have picked up M2 as well.
    await asyncio.wait_for(cancel_task, timeout=5.0)

    assert m2_cancelled.is_set(), (
        "Late-arrival M2 was NOT cancelled by cancel_background_tasks — "
        "the re-drain loop is missing and the task leaked"
    )
    assert adapter._background_tasks == set()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cancel_background_tasks_handles_no_tasks():
    """With nothing tracked, cancellation returns and the task set stays empty."""
    idle_adapter = _make_adapter()

    await idle_adapter.cancel_background_tasks()

    remaining = idle_adapter._background_tasks
    assert remaining == set()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cancel_background_tasks_bounded_rounds():
    """Baseline termination check for the drain loop.

    Only the simplest case is exercised: one tracked task that cancels
    cleanly. The call must return, the task must be done, and nothing may
    remain tracked. Repeated late-arrival spawning is not simulated here.
    """
    adapter = _make_adapter()

    async def quick():
        # A long sleep that simply propagates CancelledError when cancelled.
        await asyncio.sleep(10)

    tracked = asyncio.create_task(quick())
    adapter._background_tasks.add(tracked)

    await adapter.cancel_background_tasks()

    assert tracked.done()
    assert adapter._background_tasks == set()
|
||||
|
|
@ -200,6 +200,25 @@ class TestCommandBypassActiveSession:
|
|||
"/background response was not sent back to the user"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
async def test_steer_bypasses_guard(self):
    """``/steer`` sent while a session is marked active must be dispatched.

    It must not land in ``_pending_messages``, and a response containing
    ``handled:steer`` must have been sent.
    """
    adapter = _make_adapter()
    session_key = _session_key()
    # Mark the session as busy before the command arrives.
    adapter._active_sessions[session_key] = asyncio.Event()

    steer_event = _make_event("/steer also check auth.log")
    await adapter.handle_message(steer_event)

    was_queued = session_key in adapter._pending_messages
    assert not was_queued, (
        "/steer was queued as a pending message instead of being dispatched"
    )
    steer_replies = [r for r in adapter.sent_responses if "handled:steer" in r]
    assert steer_replies, (
        "/steer response was not sent back to the user"
    )
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_bypasses_guard(self):
|
||||
"""/help must bypass so it is not silently dropped as pending slash text."""
|
||||
|
|
@ -249,6 +268,82 @@ class TestCommandBypassActiveSession:
|
|||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: non-bypass-set commands (no dedicated Level-2 handler) also bypass
|
||||
# instead of interrupting + being discarded. Regression for the Discord
|
||||
# ghost-slash-command bug where /model, /reasoning, /voice, /insights, /title,
|
||||
# /resume, /retry, /undo, /compress, /usage, /provider, /reload-mcp,
|
||||
# /sethome, /reset silently interrupted the running agent.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAllResolvableCommandsBypassGuard:
    """Recognized slash commands must get past the active-session guard.

    The tests assert that such a command, sent while a session is marked
    active, is neither queued as pending nor left without a response, and
    that ``should_bypass_active_session`` agrees for the listed names.
    """

    @pytest.mark.parametrize(
        "command_text,canonical",
        [
            ("/model claude-sonnet-4", "model"),
            ("/model", "model"),
            ("/reasoning high", "reasoning"),
            ("/personality default", "personality"),
            ("/voice on", "voice"),
            ("/insights 7", "insights"),
            ("/title my session", "title"),
            ("/resume yesterday", "resume"),
            ("/retry", "retry"),
            ("/undo", "undo"),
            ("/compress", "compress"),
            ("/usage", "usage"),
            ("/provider", "provider"),
            ("/reload-mcp", "reload-mcp"),
            ("/sethome", "sethome"),
        ],
    )
    @pytest.mark.asyncio
    async def test_command_bypasses_guard(self, command_text, canonical):
        """Each parametrized command is dispatched rather than queued."""
        adapter = _make_adapter()
        session_key = _session_key()
        # Mark the session as busy before the command arrives.
        adapter._active_sessions[session_key] = asyncio.Event()

        incoming = _make_event(command_text)
        await adapter.handle_message(incoming)

        was_queued = session_key in adapter._pending_messages
        assert not was_queued, (
            f"{command_text} was queued as pending — it should bypass the guard"
        )
        response_count = len(adapter.sent_responses)
        assert response_count > 0, (
            f"{command_text} produced no response — it should be dispatched, "
            "not silently discarded"
        )

    def test_should_bypass_returns_true_for_every_registered_command(self):
        """Spot-check a fixed list of command names: each must bypass."""
        from hermes_cli.commands import should_bypass_active_session

        spot_checked = (
            "model", "reasoning", "personality", "voice", "insights", "title",
            "resume", "retry", "undo", "compress", "usage", "provider",
            "reload-mcp", "sethome", "reset",
        )
        for cmd in spot_checked:
            assert should_bypass_active_session(cmd) is True, (
                f"/{cmd} must bypass the active-session guard"
            )

    def test_should_bypass_returns_false_for_unknown(self):
        """Unknown words, ``None``, the empty string and a path-like token
        (a file path with its leading slash stripped) do not bypass."""
        from hermes_cli.commands import should_bypass_active_session

        for not_a_command in ("foobar", None, "", "path/to/file.py"):
            assert should_bypass_active_session(not_a_command) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: non-bypass messages still get queued
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
122
tests/gateway/test_discord_race_polish.py
Normal file
122
tests/gateway/test_discord_race_polish.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""Regression tests for the Discord adapter race-polish fix.
|
||||
|
||||
Two races are addressed:
|
||||
1. on_message allowlist check racing on_ready's _resolve_allowed_usernames
|
||||
resolution window. Username-based entries in DISCORD_ALLOWED_USERS
|
||||
appear in the set as raw strings for several seconds after
|
||||
connect/reconnect; author.id is always numeric, so legitimate users
|
||||
are silently rejected until resolution finishes.
|
||||
2. join_voice_channel check-and-connect: concurrent /voice channel
|
||||
invocations both see _voice_clients.get(guild_id) is None, both call
|
||||
channel.connect(), second raises ClientException ('Already connected').
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
def _make_adapter():
    """Create a ``DiscordAdapter`` without running ``__init__``.

    Only the attributes assigned below are populated; the client is a
    ``MagicMock``.
    """
    from gateway.platforms.discord import DiscordAdapter

    adapter = object.__new__(DiscordAdapter)
    adapter._platform = Platform.DISCORD
    adapter.config = PlatformConfig(enabled=True, token="t")
    adapter._ready_event = asyncio.Event()

    # Allowlists start empty.
    for set_attr in ("_allowed_user_ids", "_allowed_role_ids"):
        setattr(adapter, set_attr, set())

    # Per-guild voice bookkeeping starts empty.
    voice_maps = (
        "_voice_clients",
        "_voice_locks",
        "_voice_receivers",
        "_voice_listen_tasks",
        "_voice_timeout_tasks",
        "_voice_text_channels",
        "_voice_sources",
    )
    for map_attr in voice_maps:
        setattr(adapter, map_attr, {})

    adapter._client = MagicMock()
    return adapter
|
||||
|
||||
|
||||
class TestJoinVoiceSerialization:
    """Concurrency check for ``DiscordAdapter.join_voice_channel``."""

    @pytest.mark.asyncio
    async def test_concurrent_joins_do_not_double_connect(self):
        """Two simultaneous joins for one guild must yield exactly one
        ``channel.connect()`` call, both must return ``True``, and the guild
        must end up in ``_voice_clients``."""
        adapter = _make_adapter()

        connect_count = [0]
        connect_event = asyncio.Event()

        class FakeVC:
            """Stand-in voice client that always reports as connected."""

            def __init__(self, channel):
                self.channel = channel

            def is_connected(self):
                return True

            async def move_to(self, _channel):
                return None

            async def disconnect(self):
                return None

        async def slow_connect(target_channel):
            connect_count[0] += 1
            # Hold the connect open until released, widening the race window.
            await connect_event.wait()
            return FakeVC(target_channel)

        channel = MagicMock()
        channel.id = 111
        channel.guild.id = 42
        channel.connect = lambda: slow_connect(channel)

        # Replace VoiceReceiver and ensure_future so no real audio machinery
        # is set up during the join.
        from gateway.platforms import discord as discord_mod

        receiver_patch = patch.object(
            discord_mod,
            "VoiceReceiver",
            MagicMock(return_value=MagicMock(start=lambda: None)),
        )
        future_patch = patch.object(
            discord_mod.asyncio,
            "ensure_future",
            lambda _c: asyncio.create_task(asyncio.sleep(0)),
        )
        with receiver_patch, future_patch:
            # Launch both joins, give them time to block on the event, then
            # release the connect and collect both results.
            first = asyncio.create_task(adapter.join_voice_channel(channel))
            second = asyncio.create_task(adapter.join_voice_channel(channel))
            await asyncio.sleep(0.05)
            connect_event.set()
            r1, r2 = await asyncio.gather(first, second)

        assert connect_count[0] == 1, (
            f"Expected exactly 1 channel.connect() call, got {connect_count[0]} — "
            "per-guild voice lock is not serializing join_voice_channel"
        )
        assert r1 is True and r2 is True
        assert 42 in adapter._voice_clients
|
||||
|
||||
|
||||
class TestOnMessageWaitsForReadyEvent:
    """Static check that the adapter's connect code gates on ``_ready_event``."""

    @pytest.mark.asyncio
    async def test_on_message_blocks_until_ready_event_set(self):
        """The source of ``DiscordAdapter.connect`` must contain the
        ready-event wait pattern.

        This inspects source text only; it does not run ``on_message``.
        Driving the real handler would need the discord.py client
        machinery, which is not practical here.
        """
        import inspect
        from gateway.platforms import discord as discord_mod

        src = inspect.getsource(discord_mod.DiscordAdapter.connect)

        checks_ready_flag = "_ready_event.is_set()" in src
        assert checks_ready_flag, (
            "on_message must gate on _ready_event so username-based "
            "allowlist entries are resolved before the allowlist check"
        )

        uses_wait_for = "await asyncio.wait_for(" in src
        waits_on_event = "_ready_event.wait()" in src
        assert uses_wait_for and waits_on_event, (
            "Expected asyncio.wait_for(_ready_event.wait(), timeout=...) "
            "pattern in on_message"
        )
|
||||
|
|
@ -645,3 +645,54 @@ def test_group_topic_chat_id_int_string_coercion():
|
|||
|
||||
assert event.auto_skill == "hermes-agent-dev"
|
||||
assert event.source.chat_topic == "Dev"
|
||||
|
||||
|
||||
# ── _build_message_event: from_user=None fallback in DMs ──
|
||||
|
||||
|
||||
def test_build_message_event_dm_from_user_none_falls_back_to_chat_id():
    """In a DM with ``from_user`` missing, the event's user id is the chat id."""
    from gateway.platforms.base import MessageType

    adapter = _make_adapter()
    dm_message = _make_mock_message(chat_id=12345, user_id=42, user_name="Alice")
    # Drop the sender to simulate the edge case under test.
    dm_message.from_user = None

    source = adapter._build_message_event(dm_message, MessageType.TEXT).source

    # Both the id and the name come from the chat, not the absent sender.
    assert (source.user_id, source.user_name) == ("12345", "Alice")
|
||||
|
||||
|
||||
def test_build_message_event_group_from_user_none_stays_none():
    """In a supergroup with ``from_user`` missing, user id and name stay ``None``."""
    from gateway.platforms.base import MessageType

    adapter = _make_adapter()
    group_message = _make_mock_message(
        chat_id=-1001234567890,
        chat_type=_ChatType.SUPERGROUP,
        user_id=42,
        user_name="Alice",
    )
    group_message.from_user = None

    source = adapter._build_message_event(group_message, MessageType.TEXT).source

    # No chat-based fallback for groups: the sender stays anonymous.
    assert source.user_id is None
    assert source.user_name is None
|
||||
|
||||
|
||||
def test_build_message_event_dm_from_user_present_uses_user():
    """In a DM with ``from_user`` set, the sender's id and name are used."""
    from gateway.platforms.base import MessageType

    adapter = _make_adapter()
    dm_message = _make_mock_message(chat_id=12345, user_id=99999, user_name="Bob")

    source = adapter._build_message_event(dm_message, MessageType.TEXT).source

    # The sender's values win over the chat's.
    assert (source.user_id, source.user_name) == ("99999", "Bob")
|
||||
|
|
|
|||
|
|
@ -2370,6 +2370,134 @@ class TestAdapterBehavior(unittest.TestCase):
|
|||
elements = payload["zh_cn"]["content"][0]
|
||||
self.assertEqual(elements, [{"tag": "md", "text": "可以用 **粗体** 和 *斜体*。"}])
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
def test_send_splits_fenced_code_blocks_into_separate_post_rows(self):
    """``send`` must emit a ``post`` message whose rows are the text before
    the fence, the fenced block itself, and the text after it."""
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    adapter = FeishuAdapter(PlatformConfig())
    captured = {}

    class _MessageAPI:
        """Records the outgoing request and reports success."""

        def create(self, request):
            captured["request"] = request
            reply_data = SimpleNamespace(message_id="om_codeblock")
            return SimpleNamespace(success=lambda: True, data=reply_data)

    message_api = _MessageAPI()
    adapter._client = SimpleNamespace(
        im=SimpleNamespace(v1=SimpleNamespace(message=message_api))
    )

    async def _direct(func, *args, **kwargs):
        # Stand-in for asyncio.to_thread: call the function inline.
        return func(*args, **kwargs)

    # The three sections the message is expected to be split into.
    intro = (
        "确认已入库 ✓\n"
        "文件路径:`/root/.hermes/profiles/agent_cto/cron/jobs.json`\n"
        "**解码后的内容:**"
    )
    fenced = "```json\n" '{"cron": "list"}\n' "```"
    outro = "后续说明仍应保留。"
    content = "\n".join([intro, fenced, outro])

    with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct):
        result = asyncio.run(adapter.send(chat_id="oc_chat", content=content))

    self.assertTrue(result.success)
    request_body = captured["request"].request_body
    self.assertEqual(request_body.msg_type, "post")
    rows = json.loads(request_body.content)["zh_cn"]["content"]
    expected_rows = [
        [{"tag": "md", "text": section}] for section in (intro, fenced, outro)
    ]
    self.assertEqual(rows, expected_rows)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
def test_build_post_payload_keeps_fence_like_code_lines_inside_code_block(self):
    """A code line that merely starts with a fence (```` ```oops ````) must
    stay inside the surrounding code block's row."""
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    adapter = FeishuAdapter(PlatformConfig())
    # One expected row per section; the input is the sections joined by newlines.
    sections = ["before", "```python\n```oops\n```", "after"]

    raw_payload = adapter._build_post_payload("\n".join(sections))
    rows = json.loads(raw_payload)["zh_cn"]["content"]

    expected_rows = [[{"tag": "md", "text": section}] for section in sections]
    self.assertEqual(rows, expected_rows)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
def test_build_post_payload_preserves_trailing_spaces_in_code_block(self):
    """Trailing whitespace on a line inside a fenced block must be kept."""
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    adapter = FeishuAdapter(PlatformConfig())
    # One expected row per section; the input is the sections joined by newlines.
    sections = ["before", "```python\nline with two spaces \n```", "after"]

    raw_payload = adapter._build_post_payload("\n".join(sections))
    rows = json.loads(raw_payload)["zh_cn"]["content"]

    expected_rows = [[{"tag": "md", "text": section}] for section in sections]
    self.assertEqual(rows, expected_rows)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
def test_build_post_payload_splits_multiple_fenced_code_blocks(self):
    """Two fenced blocks separated by prose produce five rows in order."""
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    adapter = FeishuAdapter(PlatformConfig())
    # One expected row per section; the input is the sections joined by newlines.
    sections = [
        "before",
        "```python\nprint(1)\n```",
        "middle",
        "```json\n{}\n```",
        "after",
    ]

    raw_payload = adapter._build_post_payload("\n".join(sections))
    rows = json.loads(raw_payload)["zh_cn"]["content"]

    expected_rows = [[{"tag": "md", "text": section}] for section in sections]
    self.assertEqual(rows, expected_rows)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_send_falls_back_to_text_when_post_payload_is_rejected(self):
|
||||
from gateway.config import PlatformConfig
|
||||
|
|
|
|||
212
tests/gateway/test_pending_drain_race.py
Normal file
212
tests/gateway/test_pending_drain_race.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
"""Regression tests: pending-drain + finally-cleanup races must not spawn
|
||||
duplicate agents OR silently drop messages that arrived during cleanup.
|
||||
|
||||
Two related races in gateway/platforms/base.py:_process_message_background:
|
||||
|
||||
1. Pending-drain path (previous line 1931):
|
||||
``del self._active_sessions[session_key]`` opened a window where a
|
||||
concurrent inbound message could pass the Level-1 guard, spawn its
|
||||
own _process_message_background, and run simultaneously with the
|
||||
recursive drain. Two agents on one session_key = duplicate responses.
|
||||
|
||||
2. Finally-cleanup path (previous line 1990-1991):
|
||||
Between the awaits in finally (typing_task, stop_typing) and the
|
||||
``del self._active_sessions[session_key]``, a new message could
|
||||
land in _pending_messages. The del ran anyway, and the message was
|
||||
silently dropped — user never got a reply.
|
||||
|
||||
Fix: keep the _active_sessions entry live across the turn chain and
|
||||
clear the Event instead of deleting; in finally, drain any
|
||||
late-arrival pending message by spawning a task instead of
|
||||
dropping it.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
)
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class _StubAdapter(BasePlatformAdapter):
    """Concrete adapter for tests: each overridden coroutine does nothing."""

    async def connect(self):
        """No-op connect."""
        return None

    async def disconnect(self):
        """No-op disconnect."""
        return None

    async def send(self, chat_id, text, **kwargs):
        """Accept any message and return ``None``."""
        return None

    async def get_chat_info(self, chat_id):
        """Return an empty info dict for every chat."""
        return dict()
|
||||
|
||||
|
||||
def _make_adapter():
    """Return a Telegram ``_StubAdapter`` with ``_send_with_retry`` mocked out."""
    config = PlatformConfig(enabled=True, token="t")
    stub = _StubAdapter(config, Platform.TELEGRAM)
    stub._send_with_retry = AsyncMock(return_value=None)
    return stub
|
||||
|
||||
|
||||
def _make_event(text="hi", chat_id="42"):
    """Build a Telegram DM text ``MessageEvent`` for ``chat_id``."""
    source = SessionSource(
        platform=Platform.TELEGRAM, chat_id=chat_id, chat_type="dm"
    )
    return MessageEvent(text=text, message_type=MessageType.TEXT, source=source)
|
||||
|
||||
|
||||
def _sk(chat_id="42"):
    """Return the session key for a Telegram DM with ``chat_id``."""
    dm_source = SessionSource(
        platform=Platform.TELEGRAM, chat_id=chat_id, chat_type="dm"
    )
    return build_session_key(dm_source)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pending_drain_keeps_active_session_guard_live():
    """While a queued message is being drained, the session's entry in
    ``_active_sessions`` must remain present and be the same Event object.

    Deleting or replacing it would let a concurrent inbound message start a
    second ``_process_message_background`` for the same session.
    """
    adapter = _make_adapter()
    session_key = _sk()

    # The handler blocks until released, so M1 is "mid-processing" when the
    # pending message is queued.
    first_started = asyncio.Event()
    release_first = asyncio.Event()

    async def handler(event):
        first_started.set()
        await release_first.wait()
        return "done"

    adapter._message_handler = handler

    # Start M1 and wait until it is inside the handler.
    await adapter.handle_message(_make_event(text="M1"))
    await asyncio.wait_for(first_started.wait(), timeout=1.0)

    # The session is now marked active; remember the exact Event object.
    assert session_key in adapter._active_sessions
    guard_before = adapter._active_sessions[session_key]

    # Queue M2 as if it had arrived while M1 was running.
    adapter._pending_messages[session_key] = _make_event(text="M2")

    # Let M1 finish so the pending-drain path starts.
    release_first.set()

    # Yield twice: enough for the drain to begin, not enough for it to
    # complete the follow-up turn.
    for _ in range(2):
        await asyncio.sleep(0)

    # A missing key means the entry was deleted; a different object means
    # it was replaced. Both are failures.
    assert session_key in adapter._active_sessions, (
        "_active_sessions[session_key] was deleted during pending-drain — "
        "opens a window for duplicate-agent spawn"
    )
    assert adapter._active_sessions[session_key] is guard_before, (
        "_active_sessions[session_key] was replaced during pending-drain — "
        "the old Event may have waiters that now won't be signaled"
    )

    # Let the drain run to completion, then clean up.
    await asyncio.sleep(0.1)
    await adapter.cancel_background_tasks()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_finally_cleanup_drains_late_arrival_pending():
    """A message queued while cleanup is awaiting must still reach the handler.

    Regression test for R6. ``stop_typing`` is wrapped so that its first call
    drops a "LATE" event into ``_pending_messages``; the test then expects the
    handler to have seen both "M1" and "LATE".
    """
    adapter = _make_adapter()
    session_key = _sk()
    seen_texts = []

    async def handler(event):
        seen_texts.append(event.text)
        return "ok"

    adapter._message_handler = handler

    # Remember the adapter's own stop_typing (None when it has none) so the
    # wrapper below can still delegate to it after injecting the late message.
    wrapped_stop = getattr(adapter, "stop_typing", None)
    late_injected = False

    async def stop_typing_injects_pending(*args, **kwargs):
        nonlocal late_injected
        # Yield once so the injection lands in the middle of the await.
        await asyncio.sleep(0)
        if not late_injected:
            adapter._pending_messages[session_key] = _make_event(text="LATE")
            late_injected = True
        if not wrapped_stop:
            return None
        return await wrapped_stop(*args, **kwargs)

    adapter.stop_typing = stop_typing_injects_pending

    # Kick off M1.
    await adapter.handle_message(_make_event(text="M1"))

    # Poll (at most 50 times, 10 ms apart, ~0.5s) until LATE has been handled.
    polls_left = 50
    while polls_left and "LATE" not in seen_texts:
        await asyncio.sleep(0.01)
        polls_left -= 1

    await adapter.cancel_background_tasks()

    assert "M1" in seen_texts, "M1 was not processed"
    assert "LATE" in seen_texts, (
        "Late-arrival pending message was silently dropped — finally "
        "cleanup should have spawned a drain task"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_pending_cleans_up_normally():
    """With nothing pending, the session entry must be removed after the turn.

    Regression guard against leaking ``_active_sessions`` entries.
    """
    adapter = _make_adapter()
    session_key = _sk()

    async def handler(event):
        return "ok"

    adapter._message_handler = handler
    await adapter.handle_message(_make_event(text="solo"))

    # Poll (at most 50 times, 10 ms apart) until the background task has
    # dropped the session entry.
    polls_left = 50
    while polls_left and session_key in adapter._active_sessions:
        await asyncio.sleep(0.01)
        polls_left -= 1

    assert session_key not in adapter._active_sessions, (
        "_active_sessions was not cleaned up after a normal turn with no pending"
    )
    assert session_key not in adapter._pending_messages

    await adapter.cancel_background_tasks()
|
||||
|
|
@ -1,13 +1,18 @@
|
|||
"""Tests for the pending_event None guard in recursive _run_agent calls.
|
||||
"""Tests for pending follow-up extraction in recursive _run_agent calls.
|
||||
|
||||
When pending_event is None (Path B: pending comes from interrupt_message),
|
||||
accessing pending_event.channel_prompt previously raised AttributeError.
|
||||
This verifies the fix: channel_prompt is captured inside the
|
||||
`if pending_event is not None:` block and falls back to None otherwise.
|
||||
|
||||
Also verifies that internal control interrupt reasons like "Stop requested"
|
||||
do not get recycled into the pending-user-message follow-up path.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from gateway.run import _is_control_interrupt_message
|
||||
|
||||
|
||||
def _extract_channel_prompt(pending_event):
|
||||
"""Reproduce the fixed logic from gateway/run.py.
|
||||
|
|
@ -21,6 +26,15 @@ def _extract_channel_prompt(pending_event):
|
|||
return next_channel_prompt
|
||||
|
||||
|
||||
def _extract_pending_text(interrupted, pending_event, interrupt_message):
|
||||
"""Reproduce the fixed pending-text selection from gateway/run.py."""
|
||||
if interrupted and pending_event is None and interrupt_message:
|
||||
if _is_control_interrupt_message(interrupt_message):
|
||||
return None
|
||||
return interrupt_message
|
||||
return None
|
||||
|
||||
|
||||
class TestPendingEventNoneChannelPrompt:
|
||||
"""Guard against AttributeError when pending_event is None."""
|
||||
|
||||
|
|
@ -40,3 +54,19 @@ class TestPendingEventNoneChannelPrompt:
|
|||
event = SimpleNamespace()
|
||||
result = _extract_channel_prompt(event)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestControlInterruptMessages:
    """Internal control interrupt reasons must never be replayed as user input."""

    def test_stop_requested_is_not_treated_as_pending_user_message(self):
        outcome = _extract_pending_text(
            interrupted=True,
            pending_event=None,
            interrupt_message="Stop requested",
        )
        assert outcome is None

    def test_session_reset_requested_is_not_treated_as_pending_user_message(self):
        outcome = _extract_pending_text(
            interrupted=True,
            pending_event=None,
            interrupt_message="Session reset requested",
        )
        assert outcome is None

    def test_real_user_interrupt_message_still_requeues(self):
        outcome = _extract_pending_text(
            interrupted=True,
            pending_event=None,
            interrupt_message="actually use postgres instead",
        )
        assert outcome == "actually use postgres instead"
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ def _make_runner(proxy_url=None):
|
|||
runner.config = MagicMock()
|
||||
runner.config.streaming = StreamingConfig()
|
||||
runner._running_agents = {}
|
||||
runner._session_run_generation = {}
|
||||
runner._session_model_overrides = {}
|
||||
runner._agent_cache = {}
|
||||
runner._agent_cache_lock = None
|
||||
|
|
@ -160,10 +161,12 @@ class TestRunAgentProxyDispatch:
|
|||
source=source,
|
||||
session_id="test-session-123",
|
||||
session_key="test-key",
|
||||
run_generation=7,
|
||||
)
|
||||
|
||||
assert result["final_response"] == "Hello from remote!"
|
||||
runner._run_agent_via_proxy.assert_called_once()
|
||||
assert runner._run_agent_via_proxy.call_args.kwargs["run_generation"] == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_skips_proxy_when_not_configured(self, monkeypatch):
|
||||
|
|
@ -370,6 +373,40 @@ class TestRunAgentViaProxy:
|
|||
assert "session_id" in result
|
||||
assert result["session_id"] == "sess-123"
|
||||
|
||||
@pytest.mark.asyncio
async def test_proxy_stale_generation_returns_empty_result(self, monkeypatch):
    """A proxied run with an out-of-date generation must yield an empty result.

    The runner records generation 2 for "test-key" while the call passes
    ``run_generation=1``; the streamed "stale" content must not surface.
    """
    monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642")
    monkeypatch.delenv("GATEWAY_PROXY_KEY", raising=False)
    runner = _make_runner()
    source = _make_source()
    runner._session_run_generation["test-key"] = 2

    stale_stream = [
        'data: {"choices":[{"delta":{"content":"stale"}}]}\n\n',
        "data: [DONE]\n\n",
    ]
    session = _FakeSession(_FakeSSEResponse(status=200, sse_chunks=stale_stream))

    # One combined ``with`` enters the three patches in the same order as
    # the nested form would.
    with patch("gateway.run._load_gateway_config", return_value={}), \
            _patch_aiohttp(session), \
            patch("aiohttp.ClientTimeout"):
        result = await runner._run_agent_via_proxy(
            message="hi",
            context_prompt="",
            history=[],
            source=source,
            session_id="sess-123",
            session_key="test-key",
            run_generation=1,
        )

    assert result["final_response"] == ""
    assert result["messages"] == []
    assert result["api_calls"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_auth_header_without_key(self, monkeypatch):
|
||||
monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642")
|
||||
|
|
|
|||
247
tests/gateway/test_restart_redelivery_dedup.py
Normal file
247
tests/gateway/test_restart_redelivery_dedup.py
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
"""Tests for /restart idempotency guard against Telegram update re-delivery.
|
||||
|
||||
When PTB's graceful-shutdown ACK call (the final `get_updates` on exit) fails
|
||||
with a network error, Telegram re-delivers the `/restart` message to the new
|
||||
gateway process. Without a dedup guard, the new gateway would process
|
||||
`/restart` again and immediately restart — a self-perpetuating loop.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import gateway.run as gateway_run
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
|
||||
|
||||
|
||||
def _make_restart_event(update_id: int | None = 100) -> MessageEvent:
    """Build a ``/restart`` text event whose platform_update_id is *update_id*."""
    event_fields = dict(
        text="/restart",
        message_type=MessageType.TEXT,
        source=make_restart_source(),
        message_id="m1",
        platform_update_id=update_id,
    )
    return MessageEvent(**event_fields)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_restart_handler_writes_dedup_marker_with_update_id(tmp_path, monkeypatch):
    """Handling /restart records the triggering update_id in .restart_last_processed.json."""
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv("INVOCATION_ID", raising=False)

    runner, _ = make_restart_runner()
    runner.request_restart = MagicMock(return_value=True)

    restart_event = _make_restart_event(update_id=12345)
    reply = await runner._handle_restart_command(restart_event)

    assert "Restarting gateway" in reply
    marker_file = tmp_path / ".restart_last_processed.json"
    assert marker_file.exists()
    recorded = json.loads(marker_file.read_text())
    assert recorded["platform"] == "telegram"
    assert recorded["update_id"] == 12345
    assert isinstance(recorded["requested_at"], (int, float))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_redelivered_restart_with_same_update_id_is_ignored(tmp_path, monkeypatch):
    """A /restart whose update_id equals the recorded one is dropped as a redelivery."""
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv("INVOCATION_ID", raising=False)

    # Marker as a previous gateway would have left it five seconds ago.
    previous_restart = {
        "platform": "telegram",
        "update_id": 12345,
        "requested_at": time.time() - 5,
    }
    marker_file = tmp_path / ".restart_last_processed.json"
    marker_file.write_text(json.dumps(previous_restart))

    runner, _ = make_restart_runner()
    runner.request_restart = MagicMock()

    # Identical update_id, so this is the redelivered message.
    reply = await runner._handle_restart_command(_make_restart_event(update_id=12345))

    assert reply == ""  # silently ignored
    runner.request_restart.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_redelivered_restart_with_older_update_id_is_ignored(tmp_path, monkeypatch):
    """An update_id strictly below the recorded one is also treated as stale."""
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv("INVOCATION_ID", raising=False)

    previous_restart = {
        "platform": "telegram",
        "update_id": 12345,
        "requested_at": time.time() - 5,
    }
    marker_file = tmp_path / ".restart_last_processed.json"
    marker_file.write_text(json.dumps(previous_restart))

    runner, _ = make_restart_runner()
    runner.request_restart = MagicMock()

    # An older update should not normally arrive, but if Telegram does
    # re-deliver something older it must be treated as stale.
    older_event = _make_restart_event(update_id=12344)
    reply = await runner._handle_restart_command(older_event)

    assert reply == ""
    runner.request_restart.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fresh_restart_with_higher_update_id_is_processed(tmp_path, monkeypatch):
    """A genuinely new /restart (higher update_id) passes the dedup guard."""
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv("INVOCATION_ID", raising=False)

    # The earlier restart was recorded with update_id=12345.
    previous_restart = {
        "platform": "telegram",
        "update_id": 12345,
        "requested_at": time.time() - 5,
    }
    marker_file = tmp_path / ".restart_last_processed.json"
    marker_file.write_text(json.dumps(previous_restart))

    runner, _ = make_restart_runner()
    runner.request_restart = MagicMock(return_value=True)

    # Strictly higher update_id, so this is a fresh request.
    reply = await runner._handle_restart_command(_make_restart_event(update_id=12346))

    assert "Restarting gateway" in reply
    runner.request_restart.assert_called_once()

    # The marker must now carry the newer update_id.
    rewritten = json.loads(marker_file.read_text())
    assert rewritten["update_id"] == 12346
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stale_marker_older_than_5min_does_not_block(tmp_path, monkeypatch):
    """A marker outside the 5-minute window is disregarded, so /restart goes ahead."""
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv("INVOCATION_ID", raising=False)

    ten_minutes_ago = time.time() - 600
    expired_marker = {
        "platform": "telegram",
        "update_id": 12345,
        "requested_at": ten_minutes_ago,
    }
    marker_file = tmp_path / ".restart_last_processed.json"
    marker_file.write_text(json.dumps(expired_marker))

    runner, _ = make_restart_runner()
    runner.request_restart = MagicMock(return_value=True)

    # Same update_id as the marker, but the marker is too old to be trusted.
    reply = await runner._handle_restart_command(_make_restart_event(update_id=12345))

    assert "Restarting gateway" in reply
    runner.request_restart.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_marker_file_allows_restart(tmp_path, monkeypatch):
    """With no marker on disk (clean start), /restart is handled normally."""
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv("INVOCATION_ID", raising=False)

    runner, _ = make_restart_runner()
    runner.request_restart = MagicMock(return_value=True)

    reply = await runner._handle_restart_command(_make_restart_event(update_id=100))

    assert "Restarting gateway" in reply
    runner.request_restart.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_corrupt_marker_file_is_treated_as_absent(tmp_path, monkeypatch):
    """A marker containing invalid JSON must not crash the handler; /restart proceeds."""
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv("INVOCATION_ID", raising=False)

    broken_marker = tmp_path / ".restart_last_processed.json"
    broken_marker.write_text("not-json{")

    runner, _ = make_restart_runner()
    runner.request_restart = MagicMock(return_value=True)

    reply = await runner._handle_restart_command(_make_restart_event(update_id=100))

    assert "Restarting gateway" in reply
    runner.request_restart.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_event_without_update_id_bypasses_dedup(tmp_path, monkeypatch):
    """An event carrying no platform_update_id is never gated by the marker."""
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv("INVOCATION_ID", raising=False)

    recent_marker = {
        "platform": "telegram",
        "update_id": 999999,
        "requested_at": time.time(),
    }
    marker_file = tmp_path / ".restart_last_processed.json"
    marker_file.write_text(json.dumps(recent_marker))

    runner, _ = make_restart_runner()
    runner.request_restart = MagicMock(return_value=True)

    # Without an update_id there is nothing to compare against the marker.
    reply = await runner._handle_restart_command(_make_restart_event(update_id=None))

    assert "Restarting gateway" in reply
    runner.request_restart.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_different_platform_bypasses_dedup(tmp_path, monkeypatch):
    """A Telegram marker must not suppress a /restart arriving from Discord."""
    from gateway.config import Platform
    from gateway.session import SessionSource

    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv("INVOCATION_ID", raising=False)

    telegram_marker = {
        "platform": "telegram",
        "update_id": 12345,
        "requested_at": time.time(),
    }
    marker_file = tmp_path / ".restart_last_processed.json"
    marker_file.write_text(json.dumps(telegram_marker))

    runner, _ = make_restart_runner()
    runner.request_restart = MagicMock(return_value=True)

    # Same numeric update_id as the marker, but from Discord, so it is not
    # a redelivery candidate.
    discord_event = MessageEvent(
        text="/restart",
        message_type=MessageType.TEXT,
        source=SessionSource(
            platform=Platform.DISCORD,
            chat_id="discord-chan",
            chat_type="dm",
            user_id="u1",
        ),
        message_id="m1",
        platform_update_id=12345,
    )
    reply = await runner._handle_restart_command(discord_event)

    assert "Restarting gateway" in reply
    runner.request_restart.assert_called_once()
|
||||
688
tests/gateway/test_restart_resume_pending.py
Normal file
688
tests/gateway/test_restart_resume_pending.py
Normal file
|
|
@ -0,0 +1,688 @@
|
|||
"""Tests for the resume_pending session continuity path.
|
||||
|
||||
Covers the behaviour introduced to fix the ``Gateway shutting down ...
|
||||
task will be interrupted`` follow-up bug (spec: PR #11852, builds on
|
||||
PRs #9850, #9934, #7536):
|
||||
|
||||
1. When a gateway restart drain times out and agents are force-interrupted,
|
||||
the affected sessions are flagged ``resume_pending=True`` — not
|
||||
``suspended`` — so the next user message on the same session_key
|
||||
auto-resumes from the existing transcript instead of getting routed
|
||||
through ``suspend_recently_active()`` and converted into a fresh
|
||||
session.
|
||||
|
||||
2. ``suspended=True`` (from ``/stop`` or stuck-loop escalation) still
|
||||
wins over ``resume_pending`` — the forced-wipe path is preserved.
|
||||
|
||||
3. The restart-resume system note injected into the next user message is
|
||||
a superset of the existing tool-tail auto-continue note (from
|
||||
PR #9934), using session-entry metadata rather than just transcript
|
||||
shape so it fires even when the interrupted transcript does NOT end
|
||||
with a ``tool`` role.
|
||||
|
||||
4. The existing ``.restart_failure_counts`` stuck-loop counter from
|
||||
PR #7536 remains the single source of escalation — no parallel
|
||||
counter is added on ``SessionEntry``.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.session import SessionEntry, SessionSource, SessionStore
|
||||
from tests.gateway.restart_test_helpers import (
|
||||
make_restart_runner,
|
||||
make_restart_source,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_source(platform=Platform.TELEGRAM, chat_id="123", user_id="u1"):
    """Build a SessionSource from the given platform, chat id and user id."""
    source_fields = {"platform": platform, "chat_id": chat_id, "user_id": user_id}
    return SessionSource(**source_fields)
|
||||
|
||||
|
||||
def _make_store(tmp_path):
    """Build a SessionStore rooted at *tmp_path* with a default GatewayConfig."""
    default_config = GatewayConfig()
    return SessionStore(sessions_dir=tmp_path, config=default_config)
|
||||
|
||||
|
||||
def _simulate_note_injection(
|
||||
agent_history: list,
|
||||
user_message: str,
|
||||
resume_entry: SessionEntry | None,
|
||||
) -> str:
|
||||
"""Mirror the note-injection logic in gateway/run.py _run_agent().
|
||||
|
||||
Matches the production code in the ``run_sync`` closure so we can
|
||||
test the decision tree without a full gateway runner.
|
||||
"""
|
||||
message = user_message
|
||||
is_resume_pending = bool(
|
||||
resume_entry is not None and getattr(resume_entry, "resume_pending", False)
|
||||
)
|
||||
|
||||
if is_resume_pending:
|
||||
reason = getattr(resume_entry, "resume_reason", None) or "restart_timeout"
|
||||
reason_phrase = (
|
||||
"a gateway restart"
|
||||
if reason == "restart_timeout"
|
||||
else "a gateway shutdown"
|
||||
if reason == "shutdown_timeout"
|
||||
else "a gateway interruption"
|
||||
)
|
||||
message = (
|
||||
f"[System note: Your previous turn in this session was interrupted "
|
||||
f"by {reason_phrase}. The conversation history below is intact. "
|
||||
f"If it contains unfinished tool result(s), process them first and "
|
||||
f"summarize what was accomplished, then address the user's new "
|
||||
f"message below.]\n\n"
|
||||
+ message
|
||||
)
|
||||
elif agent_history and agent_history[-1].get("role") == "tool":
|
||||
message = (
|
||||
"[System note: Your previous turn was interrupted before you could "
|
||||
"process the last tool result(s). The conversation history contains "
|
||||
"tool outputs you haven't responded to yet. Please finish processing "
|
||||
"those results and summarize what was accomplished, then address the "
|
||||
"user's new message below.]\n\n"
|
||||
+ message
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionEntry field + serialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionEntryResumeFields:
    """Defaults and (de)serialization of the resume_* fields on SessionEntry."""

    def test_defaults(self):
        stamp = datetime.now()
        fresh = SessionEntry(
            session_key="agent:main:telegram:dm:1",
            session_id="sid",
            created_at=stamp,
            updated_at=stamp,
        )
        assert fresh.resume_pending is False
        assert fresh.resume_reason is None
        assert fresh.last_resume_marked_at is None

    def test_roundtrip_with_resume_fields(self):
        stamp = datetime(2026, 4, 18, 12, 0, 0)
        flagged = SessionEntry(
            session_key="agent:main:telegram:dm:1",
            session_id="sid",
            created_at=stamp,
            updated_at=stamp,
            resume_pending=True,
            resume_reason="restart_timeout",
            last_resume_marked_at=stamp,
        )
        serialized = flagged.to_dict()
        restored = SessionEntry.from_dict(serialized)
        assert restored.resume_pending is True
        assert restored.resume_reason == "restart_timeout"
        assert restored.last_resume_marked_at == stamp

    def test_from_dict_legacy_without_resume_fields(self):
        """A sessions.json record predating the resume fields still loads."""
        stamp_iso = datetime.now().isoformat()
        legacy_record = {
            "session_key": "agent:main:telegram:dm:1",
            "session_id": "sid",
            "created_at": stamp_iso,
            "updated_at": stamp_iso,
            "chat_type": "dm",
        }
        restored = SessionEntry.from_dict(legacy_record)
        assert restored.resume_pending is False
        assert restored.resume_reason is None
        assert restored.last_resume_marked_at is None

    def test_malformed_timestamp_is_tolerated(self):
        stamp_iso = datetime.now().isoformat()
        record = {
            "session_key": "k",
            "session_id": "sid",
            "created_at": stamp_iso,
            "updated_at": stamp_iso,
            "resume_pending": True,
            "resume_reason": "restart_timeout",
            "last_resume_marked_at": "not-a-timestamp",
        }
        restored = SessionEntry.from_dict(record)
        # Only the unparseable timestamp is dropped; the flag and reason stay.
        assert restored.resume_pending is True
        assert restored.resume_reason == "restart_timeout"
        assert restored.last_resume_marked_at is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionStore.mark_resume_pending / clear_resume_pending
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMarkResumePending:
    """Behaviour of ``SessionStore.mark_resume_pending``."""

    def test_marks_existing_session(self, tmp_path):
        store = _make_store(tmp_path)
        key = store.get_or_create_session(_make_source()).session_key

        was_marked = store.mark_resume_pending(key)

        assert was_marked is True
        current = store._entries[key]
        assert current.resume_pending is True
        assert current.resume_reason == "restart_timeout"
        assert current.last_resume_marked_at is not None

    def test_custom_reason_persists(self, tmp_path):
        store = _make_store(tmp_path)
        key = store.get_or_create_session(_make_source()).session_key

        store.mark_resume_pending(key, reason="shutdown_timeout")

        assert store._entries[key].resume_reason == "shutdown_timeout"

    def test_returns_false_for_unknown_key(self, tmp_path):
        was_marked = _make_store(tmp_path).mark_resume_pending("no-such-key")
        assert was_marked is False

    def test_does_not_override_suspended(self, tmp_path):
        """A suspended entry stays suspended; marking it is refused."""
        store = _make_store(tmp_path)
        key = store.get_or_create_session(_make_source()).session_key
        store.suspend_session(key)

        assert store.mark_resume_pending(key) is False
        current = store._entries[key]
        assert current.suspended is True
        assert current.resume_pending is False

    def test_survives_roundtrip_through_json(self, tmp_path):
        store = _make_store(tmp_path)
        key = store.get_or_create_session(_make_source()).session_key
        store.mark_resume_pending(key, reason="restart_timeout")

        # A second store on the same directory has to read the flag from disk.
        fresh_store = _make_store(tmp_path)
        fresh_store._ensure_loaded()
        from_disk = fresh_store._entries[key]
        assert from_disk.resume_pending is True
        assert from_disk.resume_reason == "restart_timeout"
|
||||
|
||||
|
||||
class TestClearResumePending:
    """Behaviour of ``SessionStore.clear_resume_pending``."""

    def test_clears_flag(self, tmp_path):
        store = _make_store(tmp_path)
        key = store.get_or_create_session(_make_source()).session_key
        store.mark_resume_pending(key)

        was_cleared = store.clear_resume_pending(key)

        assert was_cleared is True
        current = store._entries[key]
        assert current.resume_pending is False
        assert current.resume_reason is None
        assert current.last_resume_marked_at is None

    def test_returns_false_when_not_pending(self, tmp_path):
        store = _make_store(tmp_path)
        key = store.get_or_create_session(_make_source()).session_key
        # The entry was never marked, so there is nothing to clear.
        assert store.clear_resume_pending(key) is False

    def test_returns_false_for_unknown_key(self, tmp_path):
        was_cleared = _make_store(tmp_path).clear_resume_pending("no-such-key")
        assert was_cleared is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionStore.get_or_create_session resume_pending behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetOrCreateResumePending:
    """How ``get_or_create_session`` treats resume_pending versus suspended entries."""

    def test_resume_pending_preserves_session_id(self, tmp_path):
        """The core fix: a resume-pending session is reused, not replaced."""
        store = _make_store(tmp_path)
        source = _make_source()
        initial = store.get_or_create_session(source)
        sid_before = initial.session_id
        store.mark_resume_pending(initial.session_key)

        reopened = store.get_or_create_session(source)

        assert reopened.session_id == sid_before
        assert reopened.was_auto_reset is False
        assert reopened.auto_reset_reason is None
        # Reading must not clear the flag; only a completed turn does that.
        assert reopened.resume_pending is True

    def test_suspended_still_creates_new_session(self, tmp_path):
        """Regression guard: a suspended session still gets a clean slate."""
        store = _make_store(tmp_path)
        source = _make_source()
        initial = store.get_or_create_session(source)
        sid_before = initial.session_id
        store.suspend_session(initial.session_key)

        reopened = store.get_or_create_session(source)

        assert reopened.session_id != sid_before
        assert reopened.was_auto_reset is True
        assert reopened.auto_reset_reason == "suspended"

    def test_suspended_overrides_resume_pending(self, tmp_path):
        """An entry carrying BOTH flags must behave like ``suspended``:
        forced wipe plus the "suspended" auto_reset_reason."""
        store = _make_store(tmp_path)
        source = _make_source()
        initial = store.get_or_create_session(source)
        sid_before = initial.session_id

        # Write the pathological state by hand. mark_resume_pending normally
        # refuses a suspended entry, but a stuck-loop escalation can set
        # suspended=True AFTER resume_pending was set.
        with store._lock:
            raw_entry = store._entries[initial.session_key]
            raw_entry.resume_pending = True
            raw_entry.resume_reason = "restart_timeout"
            raw_entry.suspended = True
            store._save()

        reopened = store.get_or_create_session(source)

        assert reopened.session_id != sid_before
        assert reopened.was_auto_reset is True
        assert reopened.auto_reset_reason == "suspended"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionStore.suspend_recently_active skip behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSuspendRecentlyActiveSkipsResumePending:
    """``suspend_recently_active`` must leave resume-pending entries untouched."""

    def test_resume_pending_entries_not_suspended(self, tmp_path):
        store = _make_store(tmp_path)
        key = store.get_or_create_session(_make_source()).session_key
        store.mark_resume_pending(key)

        suspended_count = store.suspend_recently_active()

        assert suspended_count == 0
        current = store._entries[key]
        assert current.suspended is False
        assert current.resume_pending is True

    def test_non_resume_pending_still_suspended(self, tmp_path):
        """Entries without the flag still get the crash-recovery suspension."""
        store = _make_store(tmp_path)
        source_a = _make_source(chat_id="a")
        source_b = _make_source(chat_id="b")
        key_a = store.get_or_create_session(source_a).session_key
        key_b = store.get_or_create_session(source_b).session_key
        store.mark_resume_pending(key_a)

        suspended_count = store.suspend_recently_active()

        assert suspended_count == 1
        assert store._entries[key_a].suspended is False
        assert store._entries[key_b].suspended is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Restart-resume system-note injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResumePendingSystemNote:
|
||||
def _pending_entry(self, reason="restart_timeout") -> SessionEntry:
    """Return a SessionEntry flagged resume_pending with *reason* as its resume_reason."""
    stamp = datetime.now()
    entry_fields = dict(
        session_key="agent:main:telegram:dm:1",
        session_id="sid",
        created_at=stamp,
        updated_at=stamp,
        resume_pending=True,
        resume_reason=reason,
        last_resume_marked_at=stamp,
    )
    return SessionEntry(**entry_fields)
|
||||
|
||||
def test_resume_pending_restart_note_mentions_restart(self):
|
||||
entry = self._pending_entry(reason="restart_timeout")
|
||||
result = _simulate_note_injection(
|
||||
agent_history=[{"role": "assistant", "content": "in progress"}],
|
||||
user_message="what happened?",
|
||||
resume_entry=entry,
|
||||
)
|
||||
assert "[System note:" in result
|
||||
assert "gateway restart" in result
|
||||
assert "what happened?" in result
|
||||
|
||||
def test_resume_pending_shutdown_note_mentions_shutdown(self):
|
||||
entry = self._pending_entry(reason="shutdown_timeout")
|
||||
result = _simulate_note_injection(
|
||||
agent_history=[{"role": "assistant", "content": "in progress"}],
|
||||
user_message="ping",
|
||||
resume_entry=entry,
|
||||
)
|
||||
assert "gateway shutdown" in result
|
||||
|
||||
def test_resume_pending_fires_without_tool_tail(self):
|
||||
"""Key improvement over PR #9934: the restart-resume note fires
|
||||
even when the transcript's last role is NOT ``tool``."""
|
||||
entry = self._pending_entry()
|
||||
history = [
|
||||
{"role": "user", "content": "run a long thing"},
|
||||
{"role": "assistant", "content": "ok, starting..."},
|
||||
]
|
||||
result = _simulate_note_injection(history, "ping", resume_entry=entry)
|
||||
assert "[System note:" in result
|
||||
assert "gateway restart" in result
|
||||
|
||||
def test_resume_pending_subsumes_tool_tail_note(self):
|
||||
"""When BOTH conditions are true, the restart-resume note wins —
|
||||
no duplicate notes."""
|
||||
entry = self._pending_entry()
|
||||
history = [
|
||||
{"role": "assistant", "content": None, "tool_calls": [
|
||||
{"id": "c1", "function": {"name": "x", "arguments": "{}"}},
|
||||
]},
|
||||
{"role": "tool", "tool_call_id": "c1", "content": "result"},
|
||||
]
|
||||
result = _simulate_note_injection(history, "ping", resume_entry=entry)
|
||||
assert result.count("[System note:") == 1
|
||||
assert "gateway restart" in result
|
||||
# Old tool-tail wording absent
|
||||
assert "haven't responded to yet" not in result
|
||||
|
||||
def test_no_resume_pending_preserves_tool_tail_note(self):
|
||||
"""Regression: the old PR #9934 tool-tail behaviour is unchanged."""
|
||||
history = [
|
||||
{"role": "assistant", "content": None, "tool_calls": [
|
||||
{"id": "c1", "function": {"name": "x", "arguments": "{}"}},
|
||||
]},
|
||||
{"role": "tool", "tool_call_id": "c1", "content": "result"},
|
||||
]
|
||||
result = _simulate_note_injection(history, "ping", resume_entry=None)
|
||||
assert "[System note:" in result
|
||||
assert "tool result" in result
|
||||
|
||||
def test_no_note_when_nothing_to_resume(self):
|
||||
history = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
]
|
||||
result = _simulate_note_injection(history, "ping", resume_entry=None)
|
||||
assert result == "ping"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Drain-timeout path marks sessions resume_pending
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_timeout_marks_resume_pending():
    """End-to-end: a drain timeout during gateway stop should flag every
    active session as resume_pending BEFORE the interrupt fires, so the
    next startup's suspend_recently_active() does not destroy them.

    NOTE(review): this test checks which sessions are marked and with what
    reason; the relative ordering against the interrupt is not asserted here.
    """
    runner, adapter = make_restart_runner()
    adapter.disconnect = AsyncMock()
    # Tiny timeout so the never-finishing mock agents force the timeout path.
    runner._restart_drain_timeout = 0.05

    running_agent = MagicMock()
    session_key_one = "agent:main:telegram:dm:A"
    session_key_two = "agent:main:telegram:dm:B"
    runner._running_agents = {
        session_key_one: running_agent,
        session_key_two: MagicMock(),
    }

    # Plug a mock session_store that records marks.
    session_store = MagicMock()
    session_store.mark_resume_pending = MagicMock(return_value=True)
    runner.session_store = session_store

    with patch("gateway.status.remove_pid_file"), patch(
        "gateway.status.write_runtime_status"
    ):
        await runner.stop()

    # Both active sessions were marked with the shutdown_timeout reason.
    calls = session_store.mark_resume_pending.call_args_list
    marked = {call.args[0] for call in calls}
    assert marked == {session_key_one, session_key_two}
    for call in calls:
        # Second positional argument is the reason string.
        assert call.args[1] == "shutdown_timeout"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_timeout_uses_restart_reason_when_restarting():
    """When the stop is a restart, timed-out sessions are marked ``restart_timeout``."""
    runner, adapter = make_restart_runner()
    adapter.disconnect = AsyncMock()
    runner._restart_drain_timeout = 0.05
    runner._restart_requested = True

    running_agent = MagicMock()
    runner._running_agents = {"agent:main:telegram:dm:A": running_agent}

    session_store = MagicMock()
    session_store.mark_resume_pending = MagicMock(return_value=True)
    runner.session_store = session_store

    with patch("gateway.status.remove_pid_file"), patch(
        "gateway.status.write_runtime_status"
    ):
        await runner.stop(restart=True, detached_restart=False, service_restart=True)

    calls = session_store.mark_resume_pending.call_args_list
    assert calls, "expected at least one mark_resume_pending call"
    for call in calls:
        # Second positional argument is the reason string.
        assert call.args[1] == "restart_timeout"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_clean_drain_does_not_mark_resume_pending():
    """If the drain completes within timeout (no force-interrupt), no
    sessions should be flagged — the normal shutdown path is unchanged."""
    runner, adapter = make_restart_runner()
    adapter.disconnect = AsyncMock()

    running_agent = MagicMock()
    runner._running_agents = {"agent:main:telegram:dm:A": running_agent}

    # Finish the agent before the (generous) drain deadline
    async def finish_agent():
        await asyncio.sleep(0.05)
        runner._running_agents.clear()

    # Hold a reference: the event loop keeps only weak references to tasks,
    # so an unreferenced task may be garbage-collected before it runs.
    finisher = asyncio.create_task(finish_agent())

    session_store = MagicMock()
    session_store.mark_resume_pending = MagicMock(return_value=True)
    runner.session_store = session_store

    with patch("gateway.status.remove_pid_file"), patch(
        "gateway.status.write_runtime_status"
    ):
        await runner.stop()

    # Surface any exception from the helper task instead of losing it.
    await finisher

    session_store.mark_resume_pending.assert_not_called()
    running_agent.interrupt.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_timeout_only_marks_still_running_sessions():
    """A session that finished gracefully during the drain window must
    NOT be marked ``resume_pending`` — it completed cleanly and its
    next turn should be a normal fresh turn, not one prefixed with the
    restart-interruption system note.

    Regression guard for using ``self._running_agents`` at timeout
    rather than the ``active_agents`` drain-start snapshot.
    """
    runner, adapter = make_restart_runner()
    adapter.disconnect = AsyncMock()
    # Long enough for the finisher to exit, short enough to still time out
    # with the stuck session still present.
    runner._restart_drain_timeout = 0.3

    session_key_finisher = "agent:main:telegram:dm:A"
    session_key_stuck = "agent:main:telegram:dm:B"
    runner._running_agents = {
        session_key_finisher: MagicMock(),
        session_key_stuck: MagicMock(),
    }

    async def finish_one():
        await asyncio.sleep(0.05)
        runner._running_agents.pop(session_key_finisher, None)

    # Hold a reference: the event loop keeps only weak references to tasks,
    # so an unreferenced task may be garbage-collected before it runs.
    finisher = asyncio.create_task(finish_one())

    session_store = MagicMock()
    session_store.mark_resume_pending = MagicMock(return_value=True)
    runner.session_store = session_store

    with patch("gateway.status.remove_pid_file"), patch(
        "gateway.status.write_runtime_status"
    ):
        await runner.stop()

    # Surface any exception from the helper task instead of losing it.
    await finisher

    calls = session_store.mark_resume_pending.call_args_list
    marked = {call.args[0] for call in calls}
    # Only the session still running at timeout is marked; the finisher is not.
    assert marked == {session_key_stuck}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_timeout_skips_pending_sentinel_sessions():
    """Pending sentinels — sessions whose AIAgent construction hasn't
    produced a real agent yet — are skipped by
    ``_interrupt_running_agents()``. The resume_pending marking must
    mirror that: no agent started means no turn was interrupted.
    """
    from gateway.run import _AGENT_PENDING_SENTINEL

    runner, adapter = make_restart_runner()
    adapter.disconnect = AsyncMock()
    runner._restart_drain_timeout = 0.05

    session_key_real = "agent:main:telegram:dm:A"
    session_key_sentinel = "agent:main:telegram:dm:B"
    runner._running_agents = {
        session_key_real: MagicMock(),
        session_key_sentinel: _AGENT_PENDING_SENTINEL,
    }

    session_store = MagicMock()
    session_store.mark_resume_pending = MagicMock(return_value=True)
    runner.session_store = session_store

    with patch("gateway.status.remove_pid_file"), patch(
        "gateway.status.write_runtime_status"
    ):
        await runner.stop()

    calls = session_store.mark_resume_pending.call_args_list
    marked = {call.args[0] for call in calls}
    # The sentinel-backed session must be absent from the marked set.
    assert marked == {session_key_real}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shutdown banner wording
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_restart_banner_uses_try_to_resume_wording():
    """The notification sent before drain should hedge the resume promise
    — the session-continuity fix is best-effort (stuck-loop counter can
    still escalate to suspended)."""
    runner, adapter = make_restart_runner()
    runner._restart_requested = True
    runner._running_agents["agent:main:telegram:dm:999"] = MagicMock()

    await runner._notify_active_sessions_of_shutdown()

    # Exactly one banner for the single active session.
    assert len(adapter.sent) == 1
    msg = adapter.sent[0]
    assert "restarting" in msg
    assert "try to resume" in msg
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stuck-loop escalation integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStuckLoopEscalation:
    """The existing .restart_failure_counts counter (PR #7536) remains the
    single source of terminal escalation — no parallel counter on
    SessionEntry was added. After the configured threshold, the startup
    path flips suspended=True which overrides resume_pending."""

    @staticmethod
    def _seed_interrupted_session(tmp_path, monkeypatch, failures):
        """Create a resume-pending session with *failures* recorded on disk.

        Returns ``(store, source, entry, counts_file, runner)`` where
        ``runner`` is a ``GatewayRunner`` built without ``__init__`` and
        wired only to ``store``.
        """
        import json

        from gateway.run import GatewayRunner

        store = _make_store(tmp_path)
        source = _make_source()
        entry = store.get_or_create_session(source)
        store.mark_resume_pending(entry.session_key, reason="restart_timeout")

        counts_file = tmp_path / ".restart_failure_counts"
        counts_file.write_text(json.dumps({entry.session_key: failures}))

        # Point the gateway at the temp dir so it reads the seeded counter file.
        monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
        runner = object.__new__(GatewayRunner)
        runner.session_store = store
        return store, source, entry, counts_file, runner

    def test_escalation_via_stuck_loop_counter_overrides_resume_pending(
        self, tmp_path, monkeypatch
    ):
        """Simulate a session that keeps getting restart-interrupted and
        hits the stuck-loop threshold: next startup should force it to
        fresh-session despite resume_pending being set."""
        from gateway.run import GatewayRunner

        # Simulate counter already at threshold (3 consecutive interrupted
        # restarts). _suspend_stuck_loop_sessions will flip suspended=True.
        store, source, entry, _counts_file, runner = self._seed_interrupted_session(
            tmp_path, monkeypatch, failures=3
        )

        suspended_count = GatewayRunner._suspend_stuck_loop_sessions(runner)
        assert suspended_count == 1
        assert store._entries[entry.session_key].suspended is True
        # resume_pending is still set on the entry, but suspended wins in
        # get_or_create_session so the next message still gets a new sid.
        second = store.get_or_create_session(source)
        assert second.session_id != entry.session_id
        assert second.auto_reset_reason == "suspended"

    def test_successful_turn_flow_clears_both_counter_and_resume_pending(
        self, tmp_path, monkeypatch
    ):
        """The gateway's post-turn cleanup should clear both signals so a
        future restart-interrupt starts with a fresh counter."""
        from gateway.run import GatewayRunner

        store, _source, entry, counts_file, runner = self._seed_interrupted_session(
            tmp_path, monkeypatch, failures=2
        )

        GatewayRunner._clear_restart_failure_count(runner, entry.session_key)
        store.clear_resume_pending(entry.session_key)

        assert store._entries[entry.session_key].resume_pending is False
        assert not counts_file.exists()
|
||||
|
|
@ -51,6 +51,9 @@ class ProgressCaptureAdapter(BasePlatformAdapter):
|
|||
async def send_typing(self, chat_id, metadata=None) -> None:
    """Record a typing-indicator request in ``self.typing`` instead of sending it."""
    self.typing.append({"chat_id": chat_id, "metadata": metadata})
|
||||
|
||||
async def stop_typing(self, chat_id) -> None:
    """Record a stop-typing event, tagged ``{"stopped": True}`` so tests can tell it apart."""
    self.typing.append({"chat_id": chat_id, "metadata": {"stopped": True}})
|
||||
|
||||
async def get_chat_info(self, chat_id: str):
    """Return a minimal chat-info dict containing only the chat id."""
    return {"id": chat_id}
|
||||
|
||||
|
|
@ -90,6 +93,40 @@ class LongPreviewAgent:
|
|||
}
|
||||
|
||||
|
||||
class DelayedProgressAgent:
    """Fake agent that emits two ``tool.started`` progress events with a pause between them.

    The pause leaves a window in which a test can invalidate the run
    before the second event is emitted.
    """

    def __init__(self, **kwargs):
        # Only the progress callback is used; other keyword arguments are ignored.
        self.tool_progress_callback = kwargs.get("tool_progress_callback")
        self.tools = []

    def run_conversation(self, message, conversation_history=None, task_id=None):
        """Emit the two progress events, then return a canned result dict."""
        self.tool_progress_callback("tool.started", "terminal", "first command", {})
        time.sleep(0.45)
        self.tool_progress_callback("tool.started", "terminal", "second command", {})
        time.sleep(0.1)
        return {
            "final_response": "done",
            "messages": [],
            "api_calls": 1,
        }
|
||||
|
||||
|
||||
class DelayedInterimAgent:
    """Fake agent that emits two interim assistant messages with a pause between them.

    The pause leaves a window in which a test can invalidate the run
    before the second message is emitted.
    """

    def __init__(self, **kwargs):
        # Only the interim callback is used; other keyword arguments are ignored.
        self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
        self.tools = []

    def run_conversation(self, message, conversation_history=None, task_id=None):
        """Emit the two interim messages, then return a canned result dict."""
        self.interim_assistant_callback("first interim")
        time.sleep(0.45)
        self.interim_assistant_callback("second interim")
        time.sleep(0.1)
        return {
            "final_response": "done",
            "messages": [],
            "api_calls": 1,
        }
|
||||
|
||||
|
||||
def _make_runner(adapter):
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
GatewayRunner = gateway_run.GatewayRunner
|
||||
|
|
@ -104,6 +141,7 @@ def _make_runner(adapter):
|
|||
runner._fallback_model = None
|
||||
runner._session_db = None
|
||||
runner._running_agents = {}
|
||||
runner._session_run_generation = {}
|
||||
runner.hooks = SimpleNamespace(loaded_hooks=False)
|
||||
runner.config = SimpleNamespace(
|
||||
thread_sessions_per_user=False,
|
||||
|
|
@ -744,6 +782,154 @@ async def test_base_processing_releases_post_delivery_callback_after_main_send()
|
|||
assert released == [True]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_agent_drops_tool_progress_after_generation_invalidation(monkeypatch, tmp_path):
    """Tool-progress events emitted after the run generation is invalidated are dropped.

    The first progress message triggers the invalidation; the second one
    (emitted by the fake agent ~0.45s later) must never reach the adapter.
    """
    import yaml

    (tmp_path / "config.yaml").write_text(
        yaml.dump({"display": {"tool_progress": "all"}}),
        encoding="utf-8",
    )

    # Stub out dotenv and the real agent module so no real model is constructed.
    fake_dotenv = types.ModuleType("dotenv")
    fake_dotenv.load_dotenv = lambda *args, **kwargs: None
    monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)

    fake_run_agent = types.ModuleType("run_agent")
    fake_run_agent.AIAgent = DelayedProgressAgent
    monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
    import tools.terminal_tool  # noqa: F401 - register terminal tool metadata

    adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
    runner = _make_runner(adapter)
    gateway_run = importlib.import_module("gateway.run")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})

    source = SessionSource(
        platform=Platform.DISCORD,
        chat_id="dm-1",
        chat_type="dm",
        thread_id=None,
    )
    session_key = "agent:main:discord:dm:dm-1"
    runner._session_run_generation[session_key] = 1

    original_send = adapter.send
    invalidated = {"done": False}

    async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None):
        # Deliver normally, then invalidate the generation once the first
        # progress message has gone out.
        result = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata)
        if "first command" in content and not invalidated["done"]:
            invalidated["done"] = True
            runner._invalidate_session_run_generation(session_key, reason="test_stop")
        return result

    adapter.send = send_and_invalidate

    result = await runner._run_agent(
        message="hello",
        context_prompt="",
        history=[],
        source=source,
        session_id="sess-progress-stop",
        session_key=session_key,
        run_generation=1,
    )

    # Join sends and edits with one separator throughout so the last sent
    # message and the first edit are not glued together.
    all_progress_text = " ".join(
        call["content"] for call in [*adapter.sent, *adapter.edits]
    )
    assert result["final_response"] == "done"
    assert 'first command' in all_progress_text
    assert 'second command' not in all_progress_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_agent_drops_interim_commentary_after_generation_invalidation(monkeypatch, tmp_path):
    """Interim assistant messages emitted after the run generation is invalidated are dropped.

    Sending "first interim" triggers the invalidation; "second interim"
    (emitted by the fake agent ~0.45s later) must never reach the adapter.
    """
    import yaml

    (tmp_path / "config.yaml").write_text(
        yaml.dump({"display": {"tool_progress": "off", "interim_assistant_messages": True}}),
        encoding="utf-8",
    )

    # Stub out dotenv and the real agent module so no real model is constructed.
    fake_dotenv = types.ModuleType("dotenv")
    fake_dotenv.load_dotenv = lambda *args, **kwargs: None
    monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)

    fake_run_agent = types.ModuleType("run_agent")
    fake_run_agent.AIAgent = DelayedInterimAgent
    monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)

    adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
    runner = _make_runner(adapter)
    gateway_run = importlib.import_module("gateway.run")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})

    source = SessionSource(
        platform=Platform.DISCORD,
        chat_id="dm-2",
        chat_type="dm",
        thread_id=None,
    )
    session_key = "agent:main:discord:dm:dm-2"
    runner._session_run_generation[session_key] = 1

    original_send = adapter.send
    invalidated = {"done": False}

    async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None):
        # Deliver normally, then invalidate the generation once the first
        # interim message has gone out.
        result = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata)
        if content == "first interim" and not invalidated["done"]:
            invalidated["done"] = True
            runner._invalidate_session_run_generation(session_key, reason="test_stop")
        return result

    adapter.send = send_and_invalidate

    result = await runner._run_agent(
        message="hello",
        context_prompt="",
        history=[],
        source=source,
        session_id="sess-commentary-stop",
        session_key=session_key,
        run_generation=1,
    )

    sent_texts = [call["content"] for call in adapter.sent]
    assert result["final_response"] == "done"
    assert "first interim" in sent_texts
    assert "second interim" not in sent_texts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_keep_typing_stops_immediately_when_interrupt_event_is_set():
    """Setting the stop event ends ``_keep_typing`` well before its 30s interval elapses."""
    adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
    stop_event = asyncio.Event()

    task = asyncio.create_task(
        adapter._keep_typing(
            "dm-typing-stop",
            interval=30.0,
            stop_event=stop_event,
        )
    )
    await asyncio.sleep(0.05)
    stop_event.set()
    # The loop must exit promptly; waiting out the interval would time out here.
    await asyncio.wait_for(task, timeout=0.5)

    # The capture adapter tags stop events with metadata {"stopped": True}.
    stopped_marker = {"stopped": True}
    normal_typing_calls = [
        call for call in adapter.typing if call.get("metadata") != stopped_marker
    ]
    stopped_calls = [
        call for call in adapter.typing if call.get("metadata") == stopped_marker
    ]
    assert len(normal_typing_calls) == 1
    assert len(stopped_calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verbose_mode_does_not_truncate_args_by_default(monkeypatch, tmp_path):
|
||||
"""Verbose mode with default tool_preview_length (0) should NOT truncate args.
|
||||
|
|
|
|||
|
|
@ -319,3 +319,23 @@ async def test_start_gateway_replace_clears_marker_on_permission_denied(
|
|||
assert ok is False
|
||||
# Marker must NOT be left behind
|
||||
assert not (tmp_path / ".gateway-takeover.json").exists()
|
||||
|
||||
|
||||
def test_runner_warns_when_docker_gateway_lacks_explicit_output_mount(monkeypatch, tmp_path, caplog):
    """Constructing a runner with a Docker terminal and only a localtime volume logs a warning."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    monkeypatch.setenv("TERMINAL_ENV", "docker")
    # The only configured volume is read-only localtime.
    monkeypatch.setenv("TERMINAL_DOCKER_VOLUMES", '["/etc/localtime:/etc/localtime:ro"]')
    config = GatewayConfig(
        platforms={
            Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")
        },
        sessions_dir=tmp_path / "sessions",
    )

    with caplog.at_level("WARNING"):
        GatewayRunner(config)

    # getMessage() formats on demand and does not depend on a handler having
    # populated ``record.message`` first.
    assert any(
        "host-visible output mount" in record.getMessage()
        for record in caplog.records
    )
|
||||
|
|
|
|||
59
tests/gateway/test_safe_adapter_disconnect.py
Normal file
59
tests/gateway/test_safe_adapter_disconnect.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
"""Regression tests: failed-connect path must call adapter.disconnect().
|
||||
|
||||
When adapter.connect() returns False or raises, the adapter may have
|
||||
allocated resources (aiohttp.ClientSession, poll tasks, child
|
||||
subprocesses) before giving up. Without a defensive disconnect() call
|
||||
these leak and surface as "Unclosed client session" warnings at
|
||||
process exit (seen on the 2026-04-18 18:08:16 gateway restart).
|
||||
|
||||
The fix: gateway/run.py wraps each adapter connect() with a safety-net
|
||||
call to _safe_adapter_disconnect() in the failure branches.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
@pytest.fixture
def bare_runner():
    """A GatewayRunner shell that only needs to support _safe_adapter_disconnect."""
    # object.__new__ bypasses GatewayRunner.__init__, so no attributes are set.
    return object.__new__(GatewayRunner)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_safe_disconnect_calls_adapter_disconnect(bare_runner):
    """The helper forwards to adapter.disconnect()."""
    adapter = MagicMock()
    adapter.disconnect = AsyncMock(return_value=None)

    await bare_runner._safe_adapter_disconnect(adapter, Platform.TELEGRAM)

    adapter.disconnect.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_safe_disconnect_swallows_exceptions(bare_runner):
    """An exception in adapter.disconnect() must not propagate — the
    caller is already on an error path."""
    adapter = MagicMock()
    adapter.disconnect = AsyncMock(side_effect=RuntimeError("partial init"))

    # Must NOT raise
    await bare_runner._safe_adapter_disconnect(adapter, Platform.TELEGRAM)

    # The disconnect was still attempted exactly once.
    adapter.disconnect.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_safe_disconnect_handles_none_platform(bare_runner):
    """Logging path must tolerate platform=None."""
    adapter = MagicMock()
    # Force the failure branch so the code path that reports the platform runs.
    adapter.disconnect = AsyncMock(side_effect=ValueError("nope"))

    await bare_runner._safe_adapter_disconnect(adapter, None)

    adapter.disconnect.assert_awaited_once()
|
||||
|
|
@ -24,10 +24,18 @@ class _FakeAdapter:
|
|||
|
||||
def __init__(self):
    """Start with no pending messages, no active sessions and no recorded interrupts."""
    self._pending_messages = {}
    # session_key -> event object; set by interrupt_session_activity().
    self._active_sessions = {}
    # (session_key, chat_id) tuples, appended in call order.
    self.interrupted_sessions = []
|
||||
|
||||
async def send(self, chat_id, text, **kwargs):
    """Accept and discard an outgoing message (no-op test double)."""
    pass
|
||||
|
||||
async def interrupt_session_activity(self, session_key, chat_id):
    """Record the interrupt and set the session's event, if one is registered."""
    self.interrupted_sessions.append((session_key, chat_id))
    event = self._active_sessions.get(session_key)
    # Unknown sessions are recorded but otherwise ignored.
    if event is not None:
        event.set()
|
||||
|
||||
|
||||
def _make_runner():
|
||||
runner = object.__new__(GatewayRunner)
|
||||
|
|
@ -37,6 +45,7 @@ def _make_runner():
|
|||
runner.adapters = {Platform.TELEGRAM: _FakeAdapter()}
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._session_run_generation = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._voice_mode = {}
|
||||
|
|
@ -81,7 +90,7 @@ async def test_sentinel_placed_before_agent_setup():
|
|||
# Patch _handle_message_with_agent to capture state at entry
|
||||
sentinel_was_set = False
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
async def mock_inner(self_inner, ev, src, qk, generation):
|
||||
nonlocal sentinel_was_set
|
||||
sentinel_was_set = runner._running_agents.get(qk) is _AGENT_PENDING_SENTINEL
|
||||
return "ok"
|
||||
|
|
@ -105,7 +114,7 @@ async def test_sentinel_cleaned_up_after_handler_returns():
|
|||
event = _make_event()
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
async def mock_inner(self_inner, ev, src, qk, generation):
|
||||
return "ok"
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
||||
|
|
@ -127,7 +136,7 @@ async def test_sentinel_cleaned_up_on_exception():
|
|||
event = _make_event()
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
async def mock_inner(self_inner, ev, src, qk, generation):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
||||
|
|
@ -154,7 +163,7 @@ async def test_second_message_during_sentinel_queued_not_duplicate():
|
|||
|
||||
barrier = asyncio.Event()
|
||||
|
||||
async def slow_inner(self_inner, ev, src, qk):
|
||||
async def slow_inner(self_inner, ev, src, qk, generation):
|
||||
# Simulate slow setup — wait until test tells us to proceed
|
||||
await barrier.wait()
|
||||
return "ok"
|
||||
|
|
@ -333,7 +342,7 @@ async def test_stop_during_sentinel_force_cleans_session():
|
|||
|
||||
barrier = asyncio.Event()
|
||||
|
||||
async def slow_inner(self_inner, ev, src, qk):
|
||||
async def slow_inner(self_inner, ev, src, qk, generation):
|
||||
await barrier.wait()
|
||||
return "ok"
|
||||
|
||||
|
|
@ -381,6 +390,7 @@ async def test_stop_hard_kills_running_agent():
|
|||
fake_agent = MagicMock()
|
||||
fake_agent.get_activity_summary.return_value = {"seconds_since_activity": 0}
|
||||
runner._running_agents[session_key] = fake_agent
|
||||
runner.adapters[Platform.TELEGRAM]._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
# Send /stop
|
||||
stop_event = _make_event(text="/stop")
|
||||
|
|
@ -393,6 +403,10 @@ async def test_stop_hard_kills_running_agent():
|
|||
assert session_key not in runner._running_agents, (
|
||||
"/stop must remove the agent from _running_agents so the session is unlocked"
|
||||
)
|
||||
assert runner.adapters[Platform.TELEGRAM].interrupted_sessions == [
|
||||
(session_key, "12345")
|
||||
]
|
||||
assert runner.adapters[Platform.TELEGRAM]._active_sessions[session_key].is_set()
|
||||
|
||||
# Must return a confirmation
|
||||
assert result is not None
|
||||
|
|
|
|||
|
|
@ -740,3 +740,140 @@ class TestSignalStopTyping:
|
|||
await adapter.stop_typing("+155****4567")
|
||||
|
||||
adapter._stop_typing_indicator.assert_awaited_once_with("+155****4567")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Typing-indicator backoff on repeated failures (Signal RPC spam fix)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalTypingBackoff:
|
||||
"""When base.py's _keep_typing refresh loop calls send_typing every ~2s
|
||||
and the recipient is unreachable (NETWORK_FAILURE), the adapter must:
|
||||
|
||||
- log WARNING only for the first failure (subsequent failures use DEBUG
|
||||
via log_failures=False on the _rpc call)
|
||||
- after 3 consecutive failures, skip the RPC entirely during an
|
||||
exponential cooldown window instead of hammering signal-cli every 2s
|
||||
- reset counters on a successful sendTyping
|
||||
- reset counters when _stop_typing_indicator() is called for the chat
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
async def test_first_failure_logs_at_warning_subsequent_at_debug(
    self, monkeypatch
):
    """The first failed sendTyping passes ``log_failures=True``; the next passes ``False``."""
    adapter = _make_signal_adapter(monkeypatch)
    calls = []

    async def _fake_rpc(method, params, rpc_id=None, *, log_failures=True):
        calls.append({"log_failures": log_failures})
        return None  # simulate NETWORK_FAILURE

    adapter._rpc = _fake_rpc

    await adapter.send_typing("+155****4567")
    await adapter.send_typing("+155****4567")

    assert len(calls) == 2
    assert calls[0]["log_failures"] is True  # first failure — warn
    assert calls[1]["log_failures"] is False  # subsequent — debug
|
||||
|
||||
@pytest.mark.asyncio
async def test_three_consecutive_failures_trigger_cooldown(
    self, monkeypatch
):
    """After three failed RPCs the chat enters cooldown and further calls skip the RPC."""
    adapter = _make_signal_adapter(monkeypatch)
    call_count = {"n": 0}

    async def _fake_rpc(method, params, rpc_id=None, *, log_failures=True):
        call_count["n"] += 1
        return None

    adapter._rpc = _fake_rpc

    # Three failures engage the cooldown.
    for _ in range(3):
        await adapter.send_typing("+155****4567")
    assert call_count["n"] == 3
    assert "+155****4567" in adapter._typing_skip_until

    # Fourth, fifth, ... calls during the cooldown window are short-
    # circuited — the RPC is not issued at all.
    await adapter.send_typing("+155****4567")
    await adapter.send_typing("+155****4567")
    assert call_count["n"] == 3
|
||||
|
||||
@pytest.mark.asyncio
async def test_cooldown_is_per_chat_not_global(self, monkeypatch):
    """Driving one chat into cooldown must not suppress RPCs for another chat."""
    adapter = _make_signal_adapter(monkeypatch)
    call_log = []

    async def _fake_rpc(method, params, rpc_id=None, *, log_failures=True):
        call_log.append(params.get("recipient") or params.get("groupId"))
        return None

    adapter._rpc = _fake_rpc

    # Drive chat A into cooldown.
    for _ in range(3):
        await adapter.send_typing("+155****4567")
    assert "+155****4567" in adapter._typing_skip_until

    # Chat B is unaffected — still makes RPCs.
    await adapter.send_typing("+155****9999")
    await adapter.send_typing("+155****9999")
    # Three RPCs for chat A plus two for chat B; without this check the
    # "still makes RPCs" claim would go unverified.
    assert len(call_log) == 5
    assert "+155****9999" not in adapter._typing_skip_until
    # Chat A cooldown untouched
    assert "+155****4567" in adapter._typing_skip_until
|
||||
|
||||
@pytest.mark.asyncio
async def test_success_resets_failure_counter_and_cooldown(
    self, monkeypatch
):
    """A successful sendTyping clears the failure count and any cooldown."""
    chat = "+155****4567"
    signal = _make_signal_adapter(monkeypatch)
    # Scripted RPC outcomes: two failures, then one success.
    outcomes = [None, None, {"timestamp": 12345}]
    flags = []

    async def _scripted_rpc(method, params, rpc_id=None, *, log_failures=True):
        flags.append(log_failures)
        return outcomes.pop(0)

    signal._rpc = _scripted_rpc

    for _ in range(3):
        await signal.send_typing(chat)

    assert signal._typing_failures.get(chat, 0) == 0
    assert chat not in signal._typing_skip_until

    async def _always_fail(method, params, rpc_id=None, *, log_failures=True):
        flags.append(log_failures)
        return None

    # After recovery the counter is fresh, so the next failure is treated
    # as a "first" failure again and is logged.
    signal._rpc = _always_fail
    await signal.send_typing(chat)
    assert flags[-1] is True
|
||||
|
||||
@pytest.mark.asyncio
async def test_stop_typing_indicator_clears_backoff_state(
    self, monkeypatch
):
    """``_stop_typing_indicator`` drops the failure count and the cooldown."""
    chat = "+155****4567"
    signal = _make_signal_adapter(monkeypatch)

    async def _always_fail(method, params, rpc_id=None, *, log_failures=True):
        return None

    signal._rpc = _always_fail

    # Accumulate three failures so both backoff maps hold an entry.
    attempts = 0
    while attempts < 3:
        await signal.send_typing(chat)
        attempts += 1
    assert signal._typing_failures.get(chat) == 3
    assert chat in signal._typing_skip_until

    await signal._stop_typing_indicator(chat)

    for backoff_state in (signal._typing_failures, signal._typing_skip_until):
        assert chat not in backoff_state
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ def _make_runner(session_entry: SessionEntry):
|
|||
runner.session_store.rewrite_transcript = MagicMock()
|
||||
runner.session_store.update_session = MagicMock()
|
||||
runner._running_agents = {}
|
||||
runner._session_run_generation = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = MagicMock()
|
||||
|
|
@ -223,6 +224,121 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_message_discards_stale_result_after_session_invalidation(monkeypatch):
    """A result from an invalidated run is dropped without touching the session."""
    import gateway.run as gateway_run

    entry = SessionEntry(
        session_key=build_session_key(_make_source()),
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )
    runner = _make_runner(entry)
    store = runner.session_store
    store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
    key = entry.session_key
    runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks = {key: object()}

    async def _invalidate_then_return(**kwargs):
        # Invalidate the run generation while the agent is still "running",
        # so the result returned below is already stale.
        runner._invalidate_session_run_generation(
            kwargs["session_key"], reason="test_stale_result"
        )
        return dict(
            final_response="late reply",
            messages=[],
            tools=[],
            history_offset=0,
            last_prompt_tokens=80,
            input_tokens=120,
            output_tokens=45,
            model="openai/test-model",
        )

    runner._run_agent = AsyncMock(side_effect=_invalidate_then_return)

    monkeypatch.setattr(
        gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"}
    )
    monkeypatch.setattr(
        "agent.model_metadata.get_model_context_length",
        lambda *_args, **_kwargs: 100000,
    )

    reply = await runner._handle_message(_make_event("hello"))

    assert reply is None
    # Nothing from the stale run may be persisted.
    store.append_to_transcript.assert_not_called()
    store.update_session.assert_not_called()
    # The callback registered for this session has been removed.
    assert key not in runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_message_stale_result_keeps_newer_generation_callback(monkeypatch):
    """A stale run must not remove a callback registered by a newer generation."""
    import gateway.run as gateway_run

    class _Adapter:
        """Minimal adapter stub with generation-aware callback popping."""

        def __init__(self):
            self._post_delivery_callbacks = {}

        async def send(self, *args, **kwargs):
            return None

        def pop_post_delivery_callback(self, session_key, *, generation=None):
            stored = self._post_delivery_callbacks.get(session_key)
            if stored is None:
                return None
            if not isinstance(stored, tuple):
                # Bare (non-tuple) entries carry no generation, so they are
                # only popped when the caller does not ask for one.
                if generation is not None:
                    return None
                return self._post_delivery_callbacks.pop(session_key, None)
            stored_generation, callback = stored
            if generation is not None and stored_generation != generation:
                return None
            self._post_delivery_callbacks.pop(session_key, None)
            return callback

    entry = SessionEntry(
        session_key=build_session_key(_make_source()),
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )
    runner = _make_runner(entry)
    runner.session_store.load_transcript.return_value = [
        {"role": "user", "content": "earlier"}
    ]
    key = entry.session_key
    stub = _Adapter()
    runner.adapters[Platform.TELEGRAM] = stub

    async def _superseded_result(**kwargs):
        # Simulate a newer run claiming the callback slot before the stale
        # run unwinds.
        runner._session_run_generation[key] = 2
        stub._post_delivery_callbacks[key] = (2, lambda: None)
        return dict(
            final_response="late reply",
            messages=[],
            tools=[],
            history_offset=0,
            last_prompt_tokens=80,
            input_tokens=120,
            output_tokens=45,
            model="openai/test-model",
        )

    runner._run_agent = AsyncMock(side_effect=_superseded_result)

    monkeypatch.setattr(
        gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"}
    )
    monkeypatch.setattr(
        "agent.model_metadata.get_model_context_length",
        lambda *_args, **_kwargs: 100000,
    )

    reply = await runner._handle_message(_make_event("hello"))

    assert reply is None
    # The generation-2 callback must still be registered.
    assert key in stub._post_delivery_callbacks
    assert stub._post_delivery_callbacks[key][0] == 2
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_command_bypasses_active_session_guard():
|
||||
|
|
|
|||
191
tests/gateway/test_steer_command.py
Normal file
191
tests/gateway/test_steer_command.py
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
"""Tests for the gateway /steer command handler.
|
||||
|
||||
/steer injects a user message into the agent's next tool result without
|
||||
interrupting. The gateway runner must:
|
||||
|
||||
1. When an agent IS running → call ``agent.steer(text)``, do NOT set
|
||||
``_interrupt_requested``, do NOT touch ``_pending_messages``.
|
||||
2. When the agent is the PENDING sentinel → fall back to /queue
|
||||
semantics (store in ``adapter._pending_messages``).
|
||||
3. When no agent is active → strip the slash prefix and let the normal
|
||||
prompt pipeline handle it as a regular user message.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
    """Return the fixed Telegram DM source shared by every test in this module."""
    fields = {
        "platform": Platform.TELEGRAM,
        "user_id": "u1",
        "chat_id": "c1",
        "user_name": "tester",
        "chat_type": "dm",
    }
    return SessionSource(**fields)
|
||||
|
||||
|
||||
def _make_event(text: str) -> MessageEvent:
    """Wrap *text* in a ``MessageEvent`` coming from the shared test source."""
    source = _make_source()
    return MessageEvent(text=text, source=source, message_id="m1")
|
||||
|
||||
|
||||
def _make_runner(session_entry: SessionEntry):
    """Build a ``GatewayRunner`` stub wired to a mocked Telegram adapter.

    ``__init__`` is bypassed with ``object.__new__``, so the instance only
    carries the attributes assigned here.  Returns ``(runner, adapter)``.
    """
    from gateway.run import GatewayRunner

    # Mocked platform adapter with an inspectable pending-message queue.
    telegram = MagicMock()
    telegram.send = AsyncMock()
    telegram._pending_messages = {}

    # Session store that always hands back the supplied entry.
    store = MagicMock()
    store.get_or_create_session.return_value = session_entry
    store.load_transcript.return_value = []
    store.has_any_sessions.return_value = True

    session_db = MagicMock()
    session_db.get_session_title.return_value = None

    runner = object.__new__(GatewayRunner)
    runner.config = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
    )
    runner.adapters = {Platform.TELEGRAM: telegram}
    runner._voice_mode = {}
    runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
    runner.session_store = store

    # Per-session bookkeeping starts out empty.
    runner._running_agents = {}
    runner._running_agents_ts = {}
    runner._pending_messages = {}
    runner._pending_approvals = {}
    runner._session_db = session_db

    # Model / routing configuration.
    runner._reasoning_config = None
    runner._provider_routing = {}
    runner._fallback_model = None
    runner._show_reasoning = False

    # Collaborators replaced with inert stand-ins.
    runner._is_user_authorized = lambda _source: True
    runner._set_session_env = lambda _context: None
    runner._should_send_voice_reply = lambda *_args, **_kwargs: False
    runner._send_voice_reply = AsyncMock()
    runner._capture_gateway_honcho_if_configured = lambda *args, **kwargs: None
    runner._emit_gateway_run_progress = AsyncMock()
    return runner, telegram
|
||||
|
||||
|
||||
def _session_entry() -> SessionEntry:
    """Return a fresh Telegram DM ``SessionEntry`` keyed to the shared source."""
    key = build_session_key(_make_source())
    return SessionEntry(
        session_key=key,
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
        total_tokens=0,
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_steer_calls_agent_steer_and_does_not_interrupt():
    """With a running agent, /steer calls ``agent.steer(text)`` and leaves
    interrupt and queue state alone."""
    runner, adapter = _make_runner(_session_entry())
    key = build_session_key(_make_source())

    agent = MagicMock()
    agent.steer.return_value = True
    runner._running_agents[key] = agent

    reply = await runner._handle_message(_make_event("/steer also check auth.log"))

    # The handler answered with some confirmation text.
    assert reply is not None
    lowered = reply.lower()
    assert any(word in lowered for word in ("steer", "queued"))
    # steer() received the payload with the slash prefix stripped.
    agent.steer.assert_called_once_with("also check auth.log")
    # No interrupt was requested.
    agent.interrupt.assert_not_called()
    # Nothing was queued: putting the text into a pending-message map
    # would be turn-boundary /queue semantics, not steering.
    assert runner._pending_messages == {}
    assert adapter._pending_messages == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_steer_without_payload_returns_usage():
    """A bare /steer returns usage text and never reaches the agent."""
    runner, _adapter = _make_runner(_session_entry())
    agent = MagicMock()
    runner._running_agents[build_session_key(_make_source())] = agent

    reply = await runner._handle_message(_make_event("/steer"))

    assert reply is not None
    assert any(spelling in reply for spelling in ("Usage", "usage"))
    agent.steer.assert_not_called()
    agent.interrupt.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_steer_with_pending_sentinel_falls_back_to_queue():
    """While the agent slot holds the pending sentinel, /steer is queued as a
    turn-boundary follow-up instead of crashing."""
    from gateway.run import _AGENT_PENDING_SENTINEL

    runner, adapter = _make_runner(_session_entry())
    key = build_session_key(_make_source())
    runner._running_agents[key] = _AGENT_PENDING_SENTINEL

    reply = await runner._handle_message(_make_event("/steer wait up"))

    assert reply is not None
    lowered = reply.lower()
    assert any(word in lowered for word in ("queued", "starting"))
    # The fallback stored the payload in the adapter's pending queue.
    queued = adapter._pending_messages
    assert key in queued
    assert queued[key].text == "wait up"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_steer_agent_without_steer_method_falls_back():
    """An agent lacking ``steer()`` (older build, test stub) must not break
    the handler; the text is queued instead."""
    runner, adapter = _make_runner(_session_entry())
    key = build_session_key(_make_source())

    # ``spec=[]`` yields a mock with no attributes, so
    # ``hasattr(agent, "steer")`` is False.
    runner._running_agents[key] = MagicMock(spec=[])

    reply = await runner._handle_message(_make_event("/steer fallback"))

    assert reply is not None
    # The reply has to mention queueing because steering was unavailable.
    assert "queued" in reply.lower()
    queued = adapter._pending_messages
    assert key in queued
    assert queued[key].text == "fallback"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_steer_rejected_payload_returns_rejection_message():
    """When ``agent.steer()`` returns False the handler surfaces a rejection
    message (the gateway already guards the empty-payload case)."""
    runner, _adapter = _make_runner(_session_entry())

    agent = MagicMock()
    agent.steer.return_value = False
    runner._running_agents[build_session_key(_make_source())] = agent

    reply = await runner._handle_message(_make_event("/steer hello"))

    assert reply is not None
    lowered = reply.lower()
    assert any(word in lowered for word in ("rejected", "empty"))
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
pytest.main([__file__, "-v"])
|
||||
|
|
@ -502,11 +502,13 @@ class TestSegmentBreakOnToolBoundary:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_segment_break_clears_failed_edit_fallback_state(self):
|
||||
"""A tool boundary after edit failure must not duplicate the next segment."""
|
||||
"""A tool boundary after edit failure must flush the undelivered tail
|
||||
without duplicating the prefix the user already saw (#8124)."""
|
||||
adapter = MagicMock()
|
||||
send_results = [
|
||||
SimpleNamespace(success=True, message_id="msg_1"),
|
||||
SimpleNamespace(success=True, message_id="msg_2"),
|
||||
SimpleNamespace(success=True, message_id="msg_3"),
|
||||
]
|
||||
adapter.send = AsyncMock(side_effect=send_results)
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
|
||||
|
|
@ -526,7 +528,60 @@ class TestSegmentBreakOnToolBoundary:
|
|||
await task
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert sent_texts == ["Hello ▉", "Next segment"]
|
||||
# The undelivered "world" tail must reach the user, and the next
|
||||
# segment must not duplicate "Hello" that was already visible.
|
||||
assert sent_texts == ["Hello ▉", "world", "Next segment"]
|
||||
|
||||
@pytest.mark.asyncio
async def test_segment_break_after_mid_stream_edit_failure_preserves_tail(self):
    """Regression for #8124: text streamed before a tool boundary must still
    be delivered when an earlier edit succeeded but later edits keep failing
    with flood control and the boundary arrives before the fallback
    threshold is reached."""
    adapter = MagicMock()
    # Three sends are available: the initial partial, the flushed tail,
    # and the post-boundary segment.
    adapter.send = AsyncMock(
        side_effect=[
            SimpleNamespace(success=True, message_id=f"msg_{n}") for n in (1, 2, 3)
        ]
    )

    # Exactly one edit succeeds ("Hello world ▉"); every edit after that
    # ("Hello world more ▉", the finalize edit at the segment break, the
    # cursor-strip attempt, and any further retries) hits flood control.
    scripted_edits = [SimpleNamespace(success=True)]
    scripted_edits.extend(
        SimpleNamespace(success=False, error="flood_control:6.0") for _ in range(13)
    )
    adapter.edit_message = AsyncMock(side_effect=scripted_edits)
    adapter.MAX_MESSAGE_LENGTH = 4096

    consumer = GatewayStreamConsumer(
        adapter,
        "chat_123",
        StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉"),
    )

    consumer.on_delta("Hello")
    runner_task = asyncio.create_task(consumer.run())
    await asyncio.sleep(0.08)
    for chunk in (" world", " more"):
        consumer.on_delta(chunk)
        await asyncio.sleep(0.08)
    consumer.on_delta(None)  # tool boundary
    consumer.on_delta("Here is the tool result.")
    consumer.finish()
    await runner_task

    sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
    delivered = " ".join(sent_texts)
    # "more" must have been delivered, not dropped.
    assert "more" in delivered, (
        f"Pre-boundary tail 'more' was silently dropped: sends={sent_texts}"
    )
    # Post-boundary text must also reach the user.
    assert "Here is the tool result." in delivered
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_message_id_enters_fallback_mode(self):
|
||||
|
|
@ -1161,3 +1216,87 @@ class TestBufferOnlyMode:
|
|||
# text, the consumer may send then edit, or just send once at got_done.
|
||||
# The key assertion: this doesn't break.
|
||||
assert adapter.send.call_count >= 1
|
||||
|
||||
|
||||
# ── Cursor stripping on fallback (#7183) ────────────────────────────────────
|
||||
|
||||
|
||||
class TestCursorStrippingOnFallback:
    """Regression: cursor must be stripped when fallback continuation is empty (#7183).

    ``_send_fallback_final`` can be called when there is nothing new to
    deliver because the visible partial already matches ``final_text``.  If
    fallback mode was entered after a failed edit, that partial may still end
    with the cursor character; before the fix the message stayed frozen with
    a visible ▉.
    """

    @staticmethod
    def _consumer_with_partial(edit_message, *, cursor, visible):
        """Return ``(adapter, consumer)`` with *visible* already on screen.

        The consumer is put into the state under test: a message id is
        known, ``_last_sent_text`` is *visible*, and no fallback final send
        has happened yet.
        """
        adapter = MagicMock()
        adapter.MAX_MESSAGE_LENGTH = 4096
        adapter.edit_message = edit_message

        consumer = GatewayStreamConsumer(
            adapter, "chat-1",
            config=StreamConsumerConfig(cursor=cursor),
        )
        consumer._message_id = "msg-1"
        consumer._last_sent_text = visible
        consumer._fallback_final_send = False
        return adapter, consumer

    @pytest.mark.asyncio
    async def test_cursor_stripped_when_continuation_empty(self):
        """_send_fallback_final must attempt a final edit to strip the cursor."""
        adapter, consumer = self._consumer_with_partial(
            AsyncMock(return_value=SimpleNamespace(success=True, message_id="msg-1")),
            cursor=" ▉",
            visible="Hello world ▉",
        )

        await consumer._send_fallback_final("Hello world")

        adapter.edit_message.assert_called_once()
        edit_call = adapter.edit_message.call_args
        assert edit_call.kwargs["content"] == "Hello world"
        assert consumer._already_sent is True
        # After a successful strip the tracked text is the cleaned text.
        assert consumer._last_sent_text == "Hello world"

    @pytest.mark.asyncio
    async def test_cursor_not_stripped_when_no_cursor_configured(self):
        """No edit attempted when cursor is not configured."""
        adapter, consumer = self._consumer_with_partial(
            AsyncMock(),
            cursor="",
            visible="Hello world",
        )

        await consumer._send_fallback_final("Hello world")

        adapter.edit_message.assert_not_called()
        assert consumer._already_sent is True

    @pytest.mark.asyncio
    async def test_cursor_strip_edit_failure_handled(self):
        """If the cursor-stripping edit itself fails, it must not crash and
        must not corrupt _last_sent_text."""
        _adapter, consumer = self._consumer_with_partial(
            AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control")),
            cursor=" ▉",
            visible="Hello ▉",
        )

        await consumer._send_fallback_final("Hello")

        # The response still counts as sent even though the strip edit failed.
        assert consumer._already_sent is True
        # A failed edit must leave the tracked text untouched.
        assert consumer._last_sent_text == "Hello ▉"
|
||||
|
|
|
|||
|
|
@ -483,6 +483,32 @@ class TestSendDocument:
|
|||
assert "not found" in result.error.lower()
|
||||
connected_adapter._bot.send_document.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_document_workspace_path_has_docker_hint(self, connected_adapter):
    """A missing ``/workspace/...`` file yields the Docker-specific hint."""
    outcome = await connected_adapter.send_document(
        chat_id="12345", file_path="/workspace/report.txt"
    )

    assert outcome.success is False
    message = outcome.error.lower()
    for hint in ("docker sandbox", "host-visible path"):
        assert hint in message
    # The bot API must not have been called for the missing file.
    connected_adapter._bot.send_document.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_document_outputs_path_has_docker_hint(self, connected_adapter):
    """Legacy ``/outputs/...`` paths get the same Docker hint."""
    outcome = await connected_adapter.send_document(
        chat_id="12345", file_path="/outputs/report.txt"
    )

    assert outcome.success is False
    message = outcome.error.lower()
    for hint in ("docker sandbox", "host-visible path"):
        assert hint in message
    # The bot API must not have been called for the missing file.
    connected_adapter._bot.send_document.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_not_connected(self, adapter):
|
||||
"""If bot is None, returns not connected error."""
|
||||
|
|
@ -665,6 +691,17 @@ class TestSendVideo:
|
|||
assert result.success is False
|
||||
assert "not found" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_video_workspace_path_has_docker_hint(self, connected_adapter):
    """A missing ``/workspace/...`` video yields the Docker-specific hint."""
    outcome = await connected_adapter.send_video(
        chat_id="12345", video_path="/workspace/video.mp4"
    )

    assert outcome.success is False
    message = outcome.error.lower()
    for hint in ("docker sandbox", "host-visible path"):
        assert hint in message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_not_connected(self, adapter):
|
||||
result = await adapter.send_video(
|
||||
|
|
|
|||
|
|
@ -148,6 +148,70 @@ class TestDiscordTextBatching:
|
|||
await asyncio.sleep(0.25)
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
async def test_shield_protects_handle_message_from_cancel(self):
    """Regression guard: a follow-up chunk arriving while
    handle_message is mid-flight must NOT cancel the running
    dispatch.  _enqueue_text_event fires prior_task.cancel() on
    every new chunk; without asyncio.shield around handle_message
    the cancel propagates into the agent's streaming request and
    aborts the response.
    """
    adapter = _make_discord_adapter()

    handle_started = asyncio.Event()
    release_handle = asyncio.Event()
    first_handle_cancelled = asyncio.Event()
    first_handle_completed = asyncio.Event()
    call_count = [0]

    async def slow_handle(event):
        call_count[0] += 1
        # Only the first call (batch 1) is the one we're protecting.
        # Later calls return immediately; they are not the subject of
        # this test.
        if call_count[0] == 1:
            handle_started.set()
            try:
                await release_handle.wait()
                first_handle_completed.set()
            except asyncio.CancelledError:
                first_handle_cancelled.set()
                raise

    adapter.handle_message = slow_handle

    try:
        # Prime batch 1 and wait for it to land inside handle_message.
        adapter._enqueue_text_event(_make_event("batch 1", Platform.DISCORD))
        await asyncio.wait_for(handle_started.wait(), timeout=1.0)

        # A new chunk arrives: _enqueue_text_event fires
        # prior_task.cancel() on batch 1's flush task, which is
        # currently awaiting inside handle_message.
        adapter._enqueue_text_event(_make_event("batch 2 follow-up", Platform.DISCORD))

        # Let the cancel propagate.
        await asyncio.sleep(0.05)

        # CRITICAL ASSERTION: batch 1's handle_message must NOT have
        # been cancelled.  Without asyncio.shield this assertion fails
        # because CancelledError propagates from the flush task's
        # `await self.handle_message(event)` into slow_handle.
        assert not first_handle_cancelled.is_set(), (
            "handle_message for batch 1 was cancelled by a follow-up "
            "chunk — asyncio.shield is missing or broken"
        )

        # Release batch 1's handle_message and let it complete.
        release_handle.set()
        await asyncio.wait_for(first_handle_completed.wait(), timeout=1.0)
        assert first_handle_completed.is_set()
    finally:
        # Cleanup runs even when an assertion or timeout above fails, so a
        # failing run does not leave slow_handle blocked on release_handle
        # or batch tasks pending when the event loop is torn down.
        release_handle.set()
        for task in list(adapter._pending_text_batch_tasks.values()):
            task.cancel()
        await asyncio.sleep(0.01)
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Matrix text batching
|
||||
|
|
|
|||
|
|
@ -758,7 +758,7 @@ class TestVoiceChannelCommands:
|
|||
result = await runner._handle_voice_channel_join(event)
|
||||
|
||||
assert "voice dependencies are missing" in result.lower()
|
||||
assert "hermes-agent[messaging]" in result
|
||||
assert "PyNaCl" in result
|
||||
|
||||
# -- _handle_voice_channel_leave --
|
||||
|
||||
|
|
|
|||
473
tests/gateway/test_webhook_deliver_only.py
Normal file
473
tests/gateway/test_webhook_deliver_only.py
Normal file
|
|
@ -0,0 +1,473 @@
|
|||
"""Tests for the webhook adapter's ``deliver_only`` route mode.
|
||||
|
||||
``deliver_only`` lets external services (Supabase webhooks, monitoring
|
||||
alerts, background jobs, other agents) push plain-text notifications to
|
||||
a user's chat via the webhook adapter WITHOUT invoking the agent. The
|
||||
rendered prompt template becomes the literal message body.
|
||||
|
||||
Covers:
|
||||
- Agent is NOT invoked (``handle_message`` never called)
|
||||
- Rendered content is delivered to the target platform adapter
|
||||
- HTTP returns 200 OK on success, 502 on delivery failure
|
||||
- Startup validation rejects ``deliver_only`` without a real delivery target
|
||||
- HMAC auth, rate limiting, and idempotency still apply
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, SendResult
|
||||
from gateway.platforms.webhook import WebhookAdapter, _INSECURE_NO_AUTH
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_adapter(routes, **extra_kw) -> WebhookAdapter:
    """Create a ``WebhookAdapter`` for *routes*.

    ``extra_kw`` entries are merged into the platform ``extra`` mapping and
    override the default ``host`` / ``port`` / ``routes`` values.
    """
    settings = {"host": "0.0.0.0", "port": 0, "routes": routes, **extra_kw}
    return WebhookAdapter(PlatformConfig(enabled=True, extra=settings))
|
||||
|
||||
|
||||
def _create_app(adapter: WebhookAdapter) -> web.Application:
    """Return an aiohttp app exposing the adapter's health and webhook handlers."""
    application = web.Application()
    router = application.router
    router.add_get("/health", adapter._handle_health)
    router.add_post("/webhooks/{route_name}", adapter._handle_webhook)
    return application
|
||||
|
||||
|
||||
def _wire_mock_target(adapter: WebhookAdapter, platform_name: str = "telegram"):
    """Attach a mocked gateway runner to *adapter* and return its target adapter.

    The runner exposes a single adapter for ``Platform(platform_name)`` whose
    ``send`` resolves to a successful ``SendResult``; its config reports no
    home channel.
    """
    target = AsyncMock()
    target.send = AsyncMock(return_value=SendResult(success=True))

    runner = MagicMock()
    runner.adapters = {Platform(platform_name): target}
    runner.config.get_home_channel.return_value = None

    adapter.gateway_runner = runner
    return target
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Core behaviour: agent bypass
|
||||
# ===================================================================
|
||||
|
||||
class TestDeliverOnlyBypassesAgent:
|
||||
"""The whole point of the feature — handle_message must not be called."""
|
||||
|
||||
@pytest.mark.asyncio
async def test_post_delivers_directly_without_agent(self):
    """A deliver_only POST sends the rendered prompt to the target adapter
    and never calls ``handle_message``."""
    route = {
        "secret": _INSECURE_NO_AUTH,
        "deliver": "telegram",
        "deliver_only": True,
        "deliver_extra": {"chat_id": "12345"},
        "prompt": "{payload.user} matched with {payload.other}!",
    }
    adapter = _make_adapter({"match-alert": route})
    target = _wire_mock_target(adapter)

    # Guard: record any event that would have been handed to the agent.
    agent_events: list[MessageEvent] = []

    async def _record_agent_call(event):
        agent_events.append(event)

    adapter.handle_message = _record_agent_call

    app = _create_app(adapter)
    raw_body = json.dumps(
        {"payload": {"user": "alice", "other": "bob"}}
    ).encode()
    request_headers = {
        "Content-Type": "application/json",
        "X-GitHub-Delivery": "delivery-1",
    }

    async with TestClient(TestServer(app)) as cli:
        resp = await cli.post(
            "/webhooks/match-alert",
            data=raw_body,
            headers=request_headers,
        )
        assert resp.status == 200
        payload = await resp.json()
        assert payload["status"] == "delivered"
        assert payload["route"] == "match-alert"
        assert payload["target"] == "telegram"

        # Give any background task a moment to run before checking that
        # the agent was never invoked.
        await asyncio.sleep(0.05)

    # Agent was NOT invoked.
    assert agent_events == []

    # The target adapter received the rendered template for the chat.
    target.send.assert_awaited_once()
    positional = target.send.await_args.args
    assert positional[0] == "12345"
    assert positional[1] == "alice matched with bob!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_template_rendering_works(self):
|
||||
"""Dot-notation template variables resolve in deliver_only mode."""
|
||||
routes = {
|
||||
"alert": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"deliver": "telegram",
|
||||
"deliver_only": True,
|
||||
"deliver_extra": {"chat_id": "chat-1"},
|
||||
"prompt": "Build {build.number} status: {build.status}",
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes)
|
||||
mock_target = _wire_mock_target(adapter)
|
||||
app = _create_app(adapter)
|
||||
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/alert",
|
||||
json={"build": {"number": 77, "status": "FAILED"}},
|
||||
headers={"X-GitHub-Delivery": "d-render-1"},
|
||||
)
|
||||
assert resp.status == 200
|
||||
|
||||
mock_target.send.assert_awaited_once()
|
||||
content_arg = mock_target.send.await_args.args[1]
|
||||
assert content_arg == "Build 77 status: FAILED"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thread_id_passed_through(self):
|
||||
"""deliver_extra.thread_id flows through to the target adapter."""
|
||||
routes = {
|
||||
"r": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"deliver": "telegram",
|
||||
"deliver_only": True,
|
||||
"deliver_extra": {"chat_id": "c-1", "thread_id": "topic-42"},
|
||||
"prompt": "hi",
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes)
|
||||
mock_target = _wire_mock_target(adapter)
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/r",
|
||||
json={},
|
||||
headers={"X-GitHub-Delivery": "d-thread-1"},
|
||||
)
|
||||
assert resp.status == 200
|
||||
|
||||
assert mock_target.send.await_args.kwargs["metadata"] == {
|
||||
"thread_id": "topic-42"
|
||||
}
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# HTTP status codes
|
||||
# ===================================================================
|
||||
|
||||
class TestDeliverOnlyStatusCodes:
    """HTTP status codes returned by deliver_only routes when delivery fails."""

    @pytest.mark.asyncio
    async def test_delivery_failure_returns_502(self):
        """If the target adapter returns SendResult(success=False), 502."""
        adapter = _make_adapter(
            {
                "r": {
                    "secret": _INSECURE_NO_AUTH,
                    "deliver": "telegram",
                    "deliver_only": True,
                    "deliver_extra": {"chat_id": "c-1"},
                    "prompt": "hi",
                }
            }
        )
        target = _wire_mock_target(adapter)
        failed_send = SendResult(success=False, error="rate limited by tg")
        target.send = AsyncMock(return_value=failed_send)

        async with TestClient(TestServer(_create_app(adapter))) as cli:
            resp = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "d-fail-1"},
            )
            assert resp.status == 502
            reply = await resp.json()
            # Generic error — no adapter-level detail leaks
            assert reply["error"] == "Delivery failed"
            assert "rate limited" not in json.dumps(reply)

    @pytest.mark.asyncio
    async def test_delivery_exception_returns_502(self):
        """If adapter.send() raises, we return 502 (not 500)."""
        adapter = _make_adapter(
            {
                "r": {
                    "secret": _INSECURE_NO_AUTH,
                    "deliver": "telegram",
                    "deliver_only": True,
                    "deliver_extra": {"chat_id": "c-1"},
                    "prompt": "hi",
                }
            }
        )
        target = _wire_mock_target(adapter)
        target.send = AsyncMock(side_effect=RuntimeError("tg exploded"))

        async with TestClient(TestServer(_create_app(adapter))) as cli:
            resp = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "d-exc-1"},
            )
            assert resp.status == 502
            reply = await resp.json()
            assert reply["error"] == "Delivery failed"
            # Exception message must not leak
            assert "exploded" not in json.dumps(reply)

    @pytest.mark.asyncio
    async def test_target_platform_not_connected_returns_502(self):
        """deliver_only to a platform the gateway doesn't have → 502."""
        adapter = _make_adapter(
            {
                "r": {
                    "secret": _INSECURE_NO_AUTH,
                    "deliver": "discord",  # not configured in mock runner
                    "deliver_only": True,
                    "deliver_extra": {"chat_id": "c-1"},
                    "prompt": "hi",
                }
            }
        )
        _wire_mock_target(adapter, platform_name="telegram")  # only TG wired

        async with TestClient(TestServer(_create_app(adapter))) as cli:
            resp = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "d-no-platform-1"},
            )
            assert resp.status == 502
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Startup validation
|
||||
# ===================================================================
|
||||
|
||||
class TestDeliverOnlyStartupValidation:
    """Route validation performed by ``connect()`` for deliver_only routes."""

    @pytest.mark.asyncio
    async def test_deliver_only_with_log_deliver_rejected(self):
        """deliver_only=true + deliver=log is nonsense — reject at connect()."""
        bad_route = {
            "secret": _INSECURE_NO_AUTH,
            "deliver": "log",
            "deliver_only": True,
            "prompt": "hi",
        }
        adapter = _make_adapter({"bad": bad_route})

        with pytest.raises(ValueError, match="deliver_only=true but deliver is 'log'"):
            await adapter.connect()

    @pytest.mark.asyncio
    async def test_deliver_only_with_missing_deliver_rejected(self):
        """deliver_only=true with no deliver field defaults to 'log' → reject."""
        bad_route = {
            "secret": _INSECURE_NO_AUTH,
            # no deliver field
            "deliver_only": True,
            "prompt": "hi",
        }
        adapter = _make_adapter({"bad": bad_route})

        with pytest.raises(ValueError, match="deliver_only=true"):
            await adapter.connect()

    @pytest.mark.asyncio
    async def test_deliver_only_with_real_target_accepted(self):
        """Sanity check — a valid deliver_only config passes validation."""
        good_route = {
            "secret": _INSECURE_NO_AUTH,
            "deliver": "telegram",
            "deliver_only": True,
            "deliver_extra": {"chat_id": "c-1"},
            "prompt": "hi",
        }
        adapter = _make_adapter({"good": good_route})

        # connect() does more than validation (binds a socket) — we only care
        # that validation does not raise, so tear down right after a
        # successful start.
        try:
            if await adapter.connect():
                await adapter.disconnect()
        except ValueError:
            pytest.fail("valid deliver_only config should not raise ValueError")
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Security + reliability invariants still hold
|
||||
# ===================================================================
|
||||
|
||||
class TestDeliverOnlySecurityInvariants:
    """HMAC, idempotency and rate limiting still apply to deliver_only routes."""

    @pytest.mark.asyncio
    async def test_hmac_still_enforced(self):
        """deliver_only does NOT bypass HMAC validation."""
        secret = "real-secret-123"
        adapter = _make_adapter(
            {
                "r": {
                    "secret": secret,
                    "deliver": "telegram",
                    "deliver_only": True,
                    "deliver_extra": {"chat_id": "c-1"},
                    "prompt": "hi",
                }
            }
        )
        target = _wire_mock_target(adapter)

        async with TestClient(TestServer(_create_app(adapter))) as cli:
            # No signature header → reject
            resp = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "d-noauth-1"},
            )
            assert resp.status == 401

        # Target never called
        target.send.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_idempotency_still_applies(self):
        """Same delivery_id posted twice → second is suppressed."""
        adapter = _make_adapter(
            {
                "r": {
                    "secret": _INSECURE_NO_AUTH,
                    "deliver": "telegram",
                    "deliver_only": True,
                    "deliver_extra": {"chat_id": "c-1"},
                    "prompt": "hi",
                }
            }
        )
        target = _wire_mock_target(adapter)

        async with TestClient(TestServer(_create_app(adapter))) as cli:
            first = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "dup-1"},
            )
            assert first.status == 200

            second = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "dup-1"},
            )
            # Existing webhook adapter treats duplicates as 200 + status=duplicate
            assert second.status == 200
            second_reply = await second.json()
            assert second_reply["status"] == "duplicate"

        # Target was called exactly once
        assert target.send.await_count == 1

    @pytest.mark.asyncio
    async def test_rate_limit_still_applies(self):
        """Route-level rate limit caps deliver_only POSTs too."""
        adapter = _make_adapter(
            {
                "r": {
                    "secret": _INSECURE_NO_AUTH,
                    "deliver": "telegram",
                    "deliver_only": True,
                    "deliver_extra": {"chat_id": "c-1"},
                    "prompt": "hi",
                }
            },
            rate_limit=2,
        )
        _wire_mock_target(adapter)

        async with TestClient(TestServer(_create_app(adapter))) as cli:
            # The first two requests fit within the limit of 2.
            for attempt in range(2):
                allowed = await cli.post(
                    "/webhooks/r",
                    json={},
                    headers={"X-GitHub-Delivery": f"rl-{attempt}"},
                )
                assert allowed.status == 200

            # Third within the window → 429
            throttled = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "rl-3"},
            )
            assert throttled.status == 429
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Unit: _direct_deliver dispatch
|
||||
# ===================================================================
|
||||
|
||||
class TestDirectDeliverUnit:
    """Unit-level checks of ``_direct_deliver`` dispatch by ``deliver`` value."""

    @pytest.mark.asyncio
    async def test_dispatches_to_cross_platform_for_messaging_targets(self):
        """A messaging target is sent through the wired platform adapter."""
        adapter = _make_adapter({})
        target = _wire_mock_target(adapter, "telegram")
        route = {"deliver": "telegram", "deliver_extra": {"chat_id": "c-1"}}

        outcome = await adapter._direct_deliver("hello", route)

        assert outcome.success is True
        target.send.assert_awaited_once_with("c-1", "hello", metadata=None)

    @pytest.mark.asyncio
    async def test_dispatches_to_github_comment(self):
        """deliver=github_comment is routed to ``_deliver_github_comment``."""
        adapter = _make_adapter({})
        route = {
            "deliver": "github_comment",
            "deliver_extra": {"repo": "org/r", "pr_number": "1"},
        }
        fake_comment = AsyncMock(return_value=SendResult(success=True))

        with patch.object(adapter, "_deliver_github_comment", new=fake_comment) as mock_gh:
            outcome = await adapter._direct_deliver("review body", route)
            assert outcome.success is True
            mock_gh.assert_awaited_once()
|
||||
|
|
@ -14,7 +14,6 @@ from hermes_cli.auth import (
|
|||
PROVIDER_REGISTRY,
|
||||
_read_codex_tokens,
|
||||
_save_codex_tokens,
|
||||
_write_codex_cli_tokens,
|
||||
_import_codex_cli_tokens,
|
||||
get_codex_auth_status,
|
||||
get_provider_auth_state,
|
||||
|
|
@ -182,98 +181,6 @@ def test_codex_tokens_not_written_to_shared_file(tmp_path, monkeypatch):
|
|||
assert data["tokens"]["access_token"] == "hermes-at"
|
||||
|
||||
|
||||
def test_write_codex_cli_tokens_creates_file(tmp_path, monkeypatch):
    """_write_codex_cli_tokens creates ~/.codex/auth.json with refreshed tokens."""
    cli_home = tmp_path / "codex-cli"
    monkeypatch.setenv("CODEX_HOME", str(cli_home))

    _write_codex_cli_tokens("new-access", "new-refresh", last_refresh="2026-04-12T00:00:00Z")

    auth_file = cli_home / "auth.json"
    assert auth_file.exists()

    stored = json.loads(auth_file.read_text())
    tokens = stored["tokens"]
    assert tokens["access_token"] == "new-access"
    assert tokens["refresh_token"] == "new-refresh"
    assert stored["last_refresh"] == "2026-04-12T00:00:00Z"

    # Verify file permissions are restricted
    permission_bits = auth_file.stat().st_mode & 0o777
    assert permission_bits == 0o600
|
||||
|
||||
|
||||
def test_write_codex_cli_tokens_preserves_existing(tmp_path, monkeypatch):
    """_write_codex_cli_tokens preserves extra fields in existing auth.json."""
    cli_home = tmp_path / "codex-cli"
    cli_home.mkdir(parents=True, exist_ok=True)
    monkeypatch.setenv("CODEX_HOME", str(cli_home))
    auth_file = cli_home / "auth.json"

    # Seed a file carrying keys the writer does not own.
    seeded = {
        "tokens": {
            "access_token": "old-access",
            "refresh_token": "old-refresh",
            "extra_field": "preserved",
        },
        "last_refresh": "2026-01-01T00:00:00Z",
        "custom_key": "keep_me",
    }
    auth_file.write_text(json.dumps(seeded))

    _write_codex_cli_tokens("updated-access", "updated-refresh")

    stored = json.loads(auth_file.read_text())
    tokens = stored["tokens"]
    assert tokens["access_token"] == "updated-access"
    assert tokens["refresh_token"] == "updated-refresh"
    assert tokens["extra_field"] == "preserved"
    assert stored["custom_key"] == "keep_me"
    # last_refresh not updated since we didn't pass it
    assert stored["last_refresh"] == "2026-01-01T00:00:00Z"
|
||||
|
||||
|
||||
def test_write_codex_cli_tokens_handles_missing_dir(tmp_path, monkeypatch):
    """_write_codex_cli_tokens creates parent directories if missing."""
    cli_home = tmp_path / "does" / "not" / "exist"
    monkeypatch.setenv("CODEX_HOME", str(cli_home))

    _write_codex_cli_tokens("at", "rt")

    auth_file = cli_home / "auth.json"
    assert auth_file.exists()
    stored = json.loads(auth_file.read_text())
    assert stored["tokens"]["access_token"] == "at"
|
||||
|
||||
|
||||
def test_refresh_codex_auth_tokens_writes_back_to_cli(tmp_path, monkeypatch):
    """After refreshing, _refresh_codex_auth_tokens writes back to ~/.codex/auth.json."""
    from hermes_cli.auth import _refresh_codex_auth_tokens

    hermes_home = tmp_path / "hermes"
    cli_home = tmp_path / "codex-cli"
    for directory in (hermes_home, cli_home):
        directory.mkdir(parents=True, exist_ok=True)

    (hermes_home / "auth.json").write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))
    monkeypatch.setenv("CODEX_HOME", str(cli_home))

    # Write initial CLI tokens
    stale_tokens = {"access_token": "old-at", "refresh_token": "old-rt"}
    cli_auth_file = cli_home / "auth.json"
    cli_auth_file.write_text(json.dumps({"tokens": stale_tokens}))

    # Mock the pure refresh to return new tokens
    def _fake_refresh(*a, **kw):
        return {
            "access_token": "refreshed-at",
            "refresh_token": "refreshed-rt",
            "last_refresh": "2026-04-12T01:00:00Z",
        }

    monkeypatch.setattr("hermes_cli.auth.refresh_codex_oauth_pure", _fake_refresh)

    _refresh_codex_auth_tokens(
        {"access_token": "old-at", "refresh_token": "old-rt"},
        timeout_seconds=10,
    )

    # Verify CLI file was updated
    refreshed = json.loads(cli_auth_file.read_text())["tokens"]
    assert refreshed["access_token"] == "refreshed-at"
    assert refreshed["refresh_token"] == "refreshed-rt"
|
||||
|
||||
|
||||
def test_resolve_returns_hermes_auth_store_source(tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
_setup_hermes_auth(hermes_home)
|
||||
|
|
|
|||
|
|
@ -124,29 +124,23 @@ class TestCmdUpdateBranchFallback:
|
|||
if call.args and call.args[0][0] == "/usr/bin/npm"
|
||||
]
|
||||
|
||||
# cmd_update runs npm commands in three locations:
|
||||
# 1. repo root — slash-command / TUI bridge deps
|
||||
# 2. ui-tui/ — Ink TUI deps
|
||||
# 3. web/ — install + "npm run build" for the web frontend
|
||||
full_flags = [
|
||||
"/usr/bin/npm",
|
||||
"install",
|
||||
"--silent",
|
||||
"--no-fund",
|
||||
"--no-audit",
|
||||
"--progress=false",
|
||||
]
|
||||
assert npm_calls == [
|
||||
(
|
||||
[
|
||||
"/usr/bin/npm",
|
||||
"install",
|
||||
"--silent",
|
||||
"--no-fund",
|
||||
"--no-audit",
|
||||
"--progress=false",
|
||||
],
|
||||
PROJECT_ROOT,
|
||||
),
|
||||
(
|
||||
[
|
||||
"/usr/bin/npm",
|
||||
"install",
|
||||
"--silent",
|
||||
"--no-fund",
|
||||
"--no-audit",
|
||||
"--progress=false",
|
||||
],
|
||||
PROJECT_ROOT / "ui-tui",
|
||||
),
|
||||
(full_flags, PROJECT_ROOT),
|
||||
(full_flags, PROJECT_ROOT / "ui-tui"),
|
||||
(["/usr/bin/npm", "install", "--silent"], PROJECT_ROOT / "web"),
|
||||
(["/usr/bin/npm", "run", "build"], PROJECT_ROOT / "web"),
|
||||
]
|
||||
|
||||
def test_update_non_interactive_skips_migration_prompt(self, mock_args, capsys):
|
||||
|
|
|
|||
|
|
@ -459,7 +459,7 @@ class TestCustomProviderCompatibility:
|
|||
migrate_config(interactive=False, quiet=True)
|
||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert raw["_config_version"] == 18
|
||||
assert raw["_config_version"] == 19
|
||||
assert raw["providers"]["openai-direct"] == {
|
||||
"api": "https://api.openai.com/v1",
|
||||
"api_key": "test-key",
|
||||
|
|
@ -606,7 +606,7 @@ class TestInterimAssistantMessageConfig:
|
|||
migrate_config(interactive=False, quiet=True)
|
||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert raw["_config_version"] == 18
|
||||
assert raw["_config_version"] == 19
|
||||
assert raw["display"]["tool_progress"] == "off"
|
||||
assert raw["display"]["interim_assistant_messages"] is True
|
||||
|
||||
|
|
@ -626,6 +626,6 @@ class TestDiscordChannelPromptsConfig:
|
|||
migrate_config(interactive=False, quiet=True)
|
||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert raw["_config_version"] == 18
|
||||
assert raw["_config_version"] == 19
|
||||
assert raw["discord"]["auto_thread"] is True
|
||||
assert raw["discord"]["channel_prompts"] == {}
|
||||
|
|
|
|||
|
|
@ -54,12 +54,12 @@ class TestCronCommandLifecycle:
|
|||
deliver=None,
|
||||
repeat=None,
|
||||
skill=None,
|
||||
skills=["find-nearby", "blogwatcher"],
|
||||
skills=["maps", "blogwatcher"],
|
||||
clear_skills=False,
|
||||
)
|
||||
)
|
||||
updated = get_job(job["id"])
|
||||
assert updated["skills"] == ["find-nearby", "blogwatcher"]
|
||||
assert updated["skills"] == ["maps", "blogwatcher"]
|
||||
assert updated["name"] == "Edited Job"
|
||||
assert updated["prompt"] == "Revised prompt"
|
||||
assert updated["schedule_display"] == "every 120m"
|
||||
|
|
@ -95,7 +95,7 @@ class TestCronCommandLifecycle:
|
|||
deliver=None,
|
||||
repeat=None,
|
||||
skill=None,
|
||||
skills=["blogwatcher", "find-nearby"],
|
||||
skills=["blogwatcher", "maps"],
|
||||
)
|
||||
)
|
||||
out = capsys.readouterr().out
|
||||
|
|
@ -103,5 +103,5 @@ class TestCronCommandLifecycle:
|
|||
|
||||
jobs = list_jobs()
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0]["skills"] == ["blogwatcher", "find-nearby"]
|
||||
assert jobs[0]["skills"] == ["blogwatcher", "maps"]
|
||||
assert jobs[0]["name"] == "Skill combo"
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ class TestGeminiModelCatalog:
|
|||
models = _PROVIDER_MODELS["gemini"]
|
||||
assert "gemini-2.5-pro" in models
|
||||
assert "gemini-2.5-flash" in models
|
||||
assert "gemma-4-31b-it" in models
|
||||
assert "gemma-4-31b-it" not in models
|
||||
|
||||
def test_provider_models_has_3x(self):
|
||||
models = _PROVIDER_MODELS["gemini"]
|
||||
|
|
@ -207,6 +207,37 @@ class TestGeminiAgentInit:
|
|||
assert agent.api_mode == "chat_completions"
|
||||
assert agent.provider == "gemini"
|
||||
|
||||
def test_gemini_uses_bearer_auth(self, monkeypatch):
    """Gemini OpenAI-compatible endpoint should receive the real API key."""
    real_key = "AIzaSy_REAL_KEY"
    monkeypatch.setenv("GOOGLE_API_KEY", "AIzaSy_REAL_KEY")

    with patch("run_agent.OpenAI") as mock_openai:
        mock_openai.return_value = MagicMock()
        from run_agent import AIAgent

        AIAgent(
            model="gemini-2.5-flash",
            provider="gemini",
            api_key=real_key,
            base_url="https://generativelanguage.googleapis.com/v1beta/openai",
        )

        # Inspect the keyword arguments the OpenAI client was built with.
        client_kwargs = mock_openai.call_args.kwargs
        assert client_kwargs.get("api_key") == real_key
        sent_headers = client_kwargs.get("default_headers", {})
        assert "x-goog-api-key" not in sent_headers
|
||||
|
||||
def test_gemini_resolve_provider_client_auth(self, monkeypatch):
    """resolve_provider_client('gemini') should pass the real API key through."""
    real_key = "AIzaSy_TEST_KEY"
    monkeypatch.setenv("GEMINI_API_KEY", "AIzaSy_TEST_KEY")

    with patch("agent.auxiliary_client.OpenAI") as mock_openai:
        mock_openai.return_value = MagicMock()
        from agent.auxiliary_client import resolve_provider_client

        resolve_provider_client("gemini")

        # Inspect the keyword arguments the OpenAI client was built with.
        client_kwargs = mock_openai.call_args.kwargs
        assert client_kwargs.get("api_key") == real_key
        sent_headers = client_kwargs.get("default_headers", {})
        assert "x-goog-api-key" not in sent_headers
|
||||
|
||||
|
||||
# ── models.dev Integration ──
|
||||
|
||||
|
|
@ -261,9 +292,32 @@ class TestGeminiModelsDev:
|
|||
result = list_agentic_models("gemini")
|
||||
assert "gemini-3-flash-preview" in result
|
||||
assert "gemini-2.5-pro" in result
|
||||
assert "gemma-4-31b-it" in result
|
||||
assert "gemma-4-31b-it" not in result
|
||||
# Filtered out:
|
||||
assert "gemini-embedding-001" not in result # no tool_call
|
||||
assert "gemini-2.5-flash-preview-tts" not in result # no tool_call
|
||||
assert "gemini-live-2.5-flash" not in result # noise: live-
|
||||
assert "gemini-2.5-flash-preview-04-17" not in result # noise: dated preview
|
||||
|
||||
def test_list_provider_models_hides_low_tpm_google_gemmas(self):
    """list_provider_models('gemini') keeps gemini-2.5-pro and drops the rest of the mock catalog."""
    catalog = {
        "gemini-2.5-pro": {},
        "gemma-4-31b-it": {},
        "gemma-3-27b-it": {},
        "gemini-1.5-pro": {},
        "gemini-2.0-flash": {},
    }
    mock_data = {"google": {"models": catalog}}

    with patch("agent.models_dev.fetch_models_dev", return_value=mock_data):
        from agent.models_dev import list_provider_models

        listed = list_provider_models("gemini")

    assert "gemini-2.5-pro" in listed
    assert "gemma-4-31b-it" not in listed
    assert "gemma-3-27b-it" not in listed
    assert "gemini-1.5-pro" not in listed
    assert "gemini-2.0-flash" not in listed
|
||||
|
|
|
|||
|
|
@ -450,9 +450,9 @@ class TestValidateApiNotFound:
|
|||
assert result["recognized"] is True
|
||||
|
||||
def test_dissimilar_model_shows_suggestions_not_autocorrect(self):
|
||||
"""Models too different for auto-correction still get suggestions."""
|
||||
"""Models too different for auto-correction are rejected with suggestions."""
|
||||
result = _validate("anthropic/claude-nonexistent")
|
||||
assert result["accepted"] is True
|
||||
assert result["accepted"] is False
|
||||
assert result.get("corrected_model") is None
|
||||
assert "not found" in result["message"]
|
||||
|
||||
|
|
@ -532,11 +532,11 @@ class TestValidateCodexAutoCorrection:
|
|||
assert result["message"] is None
|
||||
|
||||
def test_very_different_name_falls_to_suggestions(self):
|
||||
"""Names too different for auto-correction get the suggestion list."""
|
||||
"""Names too different for auto-correction are rejected with a suggestion list."""
|
||||
codex_models = ["gpt-5.4-mini", "gpt-5.4", "gpt-5.3-codex"]
|
||||
with patch("hermes_cli.models.provider_model_ids", return_value=codex_models):
|
||||
result = validate_requested_model("totally-wrong", "openai-codex")
|
||||
assert result["accepted"] is True
|
||||
assert result["accepted"] is False
|
||||
assert result["recognized"] is False
|
||||
assert result.get("corrected_model") is None
|
||||
assert "not found" in result["message"]
|
||||
|
|
|
|||
29
tests/hermes_cli/test_setup_agent_settings.py
Normal file
29
tests/hermes_cli/test_setup_agent_settings.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""Tests for agent-settings copy in the interactive setup wizard."""
|
||||
|
||||
from hermes_cli.setup import setup_agent_settings
|
||||
|
||||
|
||||
def test_setup_agent_settings_uses_displayed_max_iterations_value(tmp_path, monkeypatch, capsys):
    """The helper text should match the value shown in the prompt."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    config = {
        "agent": {"max_turns": 90},
        "display": {"tool_progress": "all"},
        "compression": {"threshold": 0.50},
        "session_reset": {"mode": "both", "idle_minutes": 1440, "at_hour": 4},
    }

    # Scripted answers handed out, in order, to each prompt() call.
    scripted_answers = iter(["60", "all", "0.5"])

    def _fake_env_value(key):
        return "60" if key == "HERMES_MAX_ITERATIONS" else ""

    def _fake_prompt(*args, **kwargs):
        return next(scripted_answers)

    def _fake_choice(*args, **kwargs):
        return 4

    def _discard(*args, **kwargs):
        return None

    monkeypatch.setattr("hermes_cli.setup.get_env_value", _fake_env_value)
    monkeypatch.setattr("hermes_cli.setup.prompt", _fake_prompt)
    monkeypatch.setattr("hermes_cli.setup.prompt_choice", _fake_choice)
    monkeypatch.setattr("hermes_cli.setup.save_env_value", _discard)
    monkeypatch.setattr("hermes_cli.setup.save_config", _discard)

    setup_agent_settings(config)

    printed = capsys.readouterr().out
    assert "Press Enter to keep 60." in printed
    assert "Default is 90" not in printed
|
||||
325
tests/hermes_cli/test_update_hangup_protection.py
Normal file
325
tests/hermes_cli/test_update_hangup_protection.py
Normal file
|
|
@ -0,0 +1,325 @@
|
|||
"""Tests for SIGHUP protection and stdout mirroring in ``hermes update``.
|
||||
|
||||
Covers ``_UpdateOutputStream``, ``_install_hangup_protection``, and
|
||||
``_finalize_update_output`` in ``hermes_cli/main.py``. These exist so
|
||||
that ``hermes update`` survives a terminal disconnect mid-install
|
||||
(SSH drop, shell close) without leaving the venv half-installed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.main import (
|
||||
_UpdateOutputStream,
|
||||
_finalize_update_output,
|
||||
_install_hangup_protection,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# _UpdateOutputStream
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpdateOutputStream:
    """Behaviour of the ``_UpdateOutputStream`` mirror wrapper."""

    def test_write_mirrors_to_both_original_and_log(self):
        """A write lands in both the original stream and the log."""
        terminal = io.StringIO()
        log_sink = io.StringIO()
        mirror = _UpdateOutputStream(terminal, log_sink)

        mirror.write("hello world\n")

        assert terminal.getvalue() == "hello world\n"
        assert log_sink.getvalue() == "hello world\n"

    def test_write_continues_after_broken_original(self):
        """When the terminal disconnects, original.write raises BrokenPipeError.

        The wrapper must catch it, flip the broken flag, and keep writing to
        the log from then on.
        """

        class _BrokenStream:
            def write(self, data):
                raise BrokenPipeError("terminal gone")

            def flush(self):
                raise BrokenPipeError("terminal gone")

        log_sink = io.StringIO()
        mirror = _UpdateOutputStream(_BrokenStream(), log_sink)

        # First write triggers the broken-pipe path.
        mirror.write("first line\n")
        # Subsequent writes take the fast broken path (no exception).
        mirror.write("second line\n")

        assert log_sink.getvalue() == "first line\nsecond line\n"
        assert mirror._original_broken is True

    def test_write_tolerates_oserror_and_valueerror(self):
        """OSError (EIO) and ValueError (closed file) should also be absorbed."""

        class _RaisingStream:
            def __init__(self, exc):
                self._exc = exc

            def write(self, data):
                raise self._exc

            def flush(self):
                raise self._exc

        log_sink = io.StringIO()
        failures = (OSError("EIO"), ValueError("closed file"))

        for failure in failures:
            mirror = _UpdateOutputStream(_RaisingStream(failure), log_sink)
            mirror.write("x\n")
            assert mirror._original_broken is True

    def test_log_failure_does_not_abort_write(self):
        """Even if the log file write raises, the original write must still happen."""

        class _BrokenLog:
            def write(self, data):
                raise OSError("disk full")

            def flush(self):
                raise OSError("disk full")

        terminal = io.StringIO()
        mirror = _UpdateOutputStream(terminal, _BrokenLog())

        mirror.write("data\n")

        assert terminal.getvalue() == "data\n"

    def test_flush_tolerates_broken_original(self):
        """flush() absorbs a BrokenPipeError from the original and marks it broken."""

        class _BrokenStream:
            def write(self, data):
                return len(data)

            def flush(self):
                raise BrokenPipeError("gone")

        log_sink = io.StringIO()
        mirror = _UpdateOutputStream(_BrokenStream(), log_sink)

        mirror.flush()  # must not raise

        assert mirror._original_broken is True

    def test_isatty_delegates_to_original(self):
        """isatty() reports the original stream's answer while it is healthy."""

        class _TtyStream:
            def isatty(self):
                return True

            def write(self, data):
                return len(data)

            def flush(self):
                return None

        mirror = _UpdateOutputStream(_TtyStream(), io.StringIO())

        assert mirror.isatty() is True

    def test_isatty_returns_false_after_broken(self):
        """Once a write has failed, isatty() is False even if the original says True."""

        class _BrokenStream:
            def isatty(self):
                return True

            def write(self, data):
                raise BrokenPipeError()

            def flush(self):
                return None

        mirror = _UpdateOutputStream(_BrokenStream(), io.StringIO())
        mirror.write("x")  # marks broken

        assert mirror.isatty() is False

    def test_getattr_delegates_unknown_attrs(self):
        """Attributes the wrapper does not define resolve on the original stream."""

        class _StreamWithEncoding:
            encoding = "utf-8"

            def write(self, data):
                return len(data)

            def flush(self):
                return None

        mirror = _UpdateOutputStream(_StreamWithEncoding(), io.StringIO())

        assert mirror.encoding == "utf-8"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# _install_hangup_protection
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInstallHangupProtection:
    """Tests for ``_install_hangup_protection`` in gateway and non-gateway mode."""

    @pytest.fixture(autouse=True)
    def _preserve_sighup(self):
        """Restore the pre-test SIGHUP disposition after every test.

        Non-gateway installs set SIGHUP to SIG_IGN (asserted below), so
        without this the ignore would leak into unrelated tests.
        """
        if not hasattr(signal, "SIGHUP"):
            yield
            return
        saved = signal.getsignal(signal.SIGHUP)
        yield
        # getsignal() yields None for handlers not installed from Python, and
        # signal.signal() rejects None — nothing can be restored in that case.
        if saved is not None:
            signal.signal(signal.SIGHUP, saved)

    @staticmethod
    def _use_tmp_home(tmp_path, monkeypatch):
        """Point HERMES_HOME at *tmp_path* and clear any cached home path."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        import hermes_cli.config as _cfg

        if hasattr(_cfg, "_HERMES_HOME_CACHE"):
            # Go through monkeypatch so the cached value is put back at teardown.
            monkeypatch.setattr(_cfg, "_HERMES_HOME_CACHE", None)

    def test_gateway_mode_is_noop(self):
        """In gateway mode the process is already detached — don't touch stdio or signals."""
        prev_out, prev_err = sys.stdout, sys.stderr
        prev_sighup = signal.getsignal(signal.SIGHUP) if hasattr(signal, "SIGHUP") else None

        state = _install_hangup_protection(gateway_mode=True)

        try:
            assert sys.stdout is prev_out
            assert sys.stderr is prev_err
            assert state["log_file"] is None
            assert state["installed"] is False
            if hasattr(signal, "SIGHUP"):
                assert signal.getsignal(signal.SIGHUP) == prev_sighup
        finally:
            _finalize_update_output(state)

    @pytest.mark.skipif(
        not hasattr(signal, "SIGHUP"), reason="SIGHUP not available on this platform"
    )
    def test_installs_sighup_ignore(self, tmp_path, monkeypatch):
        """SIGHUP should be set to SIG_IGN so SSH disconnect doesn't kill the update."""
        self._use_tmp_home(tmp_path, monkeypatch)

        state = _install_hangup_protection(gateway_mode=False)

        try:
            assert signal.getsignal(signal.SIGHUP) == signal.SIG_IGN
        finally:
            # The autouse fixture restores the previous SIGHUP handler.
            _finalize_update_output(state)

    def test_wraps_stdout_and_stderr_with_mirror(self, tmp_path, monkeypatch):
        """Non-gateway install wraps both streams and mirrors output to update.log."""
        self._use_tmp_home(tmp_path, monkeypatch)

        prev_out, prev_err = sys.stdout, sys.stderr
        state = _install_hangup_protection(gateway_mode=False)

        try:
            # On Windows (no SIGHUP) we still wrap stdio and create the log.
            assert state["installed"] is True
            assert isinstance(sys.stdout, _UpdateOutputStream)
            assert isinstance(sys.stderr, _UpdateOutputStream)
            assert state["log_file"] is not None

            sys.stdout.write("checking mirror\n")
            sys.stdout.flush()

            log_path = tmp_path / "logs" / "update.log"
            assert log_path.exists()
            contents = log_path.read_text(encoding="utf-8")
            assert "checking mirror" in contents
            assert "hermes update started" in contents
        finally:
            _finalize_update_output(state)

        # Sanity-check restoration
        assert sys.stdout is prev_out
        assert sys.stderr is prev_err

    def test_logs_dir_created_if_missing(self, tmp_path, monkeypatch):
        """The logs/ directory and update.log are created on demand."""
        self._use_tmp_home(tmp_path, monkeypatch)

        # No logs/ dir yet.
        assert not (tmp_path / "logs").exists()

        state = _install_hangup_protection(gateway_mode=False)
        try:
            assert (tmp_path / "logs").is_dir()
            assert (tmp_path / "logs" / "update.log").exists()
        finally:
            _finalize_update_output(state)

    def test_non_fatal_if_log_setup_fails(self, monkeypatch):
        """If get_hermes_home() raises, stdio must be left untouched but SIGHUP still handled."""
        prev_out, prev_err = sys.stdout, sys.stderr

        def _boom():
            raise RuntimeError("no home for you")

        # Patch the import inside _install_hangup_protection.
        monkeypatch.setattr(
            "hermes_cli.config.get_hermes_home", _boom, raising=True
        )

        state = _install_hangup_protection(gateway_mode=False)

        try:
            assert sys.stdout is prev_out
            assert sys.stderr is prev_err
            assert state["installed"] is False
            # SIGHUP must still be installed even when log setup fails.
            if hasattr(signal, "SIGHUP"):
                assert signal.getsignal(signal.SIGHUP) == signal.SIG_IGN
        finally:
            # The autouse fixture restores the previous SIGHUP handler.
            _finalize_update_output(state)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# _finalize_update_output
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFinalizeUpdateOutput:
    """Tests for ``_finalize_update_output``."""

    def test_none_state_is_noop(self):
        """Passing None is accepted and does nothing observable."""
        _finalize_update_output(None)  # must not raise

    def test_restores_streams_and_closes_log(self, tmp_path, monkeypatch):
        """Finalize puts sys.stdout back and closes the log file handle."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        import hermes_cli.config as _cfg

        if hasattr(_cfg, "_HERMES_HOME_CACHE"):
            _cfg._HERMES_HOME_CACHE = None  # type: ignore[attr-defined]

        prev_out, prev_err = sys.stdout, sys.stderr
        prev_sighup = (
            signal.getsignal(signal.SIGHUP) if hasattr(signal, "SIGHUP") else None
        )
        log_file = None

        try:
            state = _install_hangup_protection(gateway_mode=False)
            log_file = state["log_file"]

            assert sys.stdout is not prev_out
            assert log_file is not None

            _finalize_update_output(state)

            assert sys.stdout is prev_out
            # The log file handle should be closed.
            assert log_file.closed is True
        finally:
            # Safety net for a failed assertion: never leave stdio wrapped,
            # the log open, or SIGHUP ignored for the tests that follow.
            sys.stdout, sys.stderr = prev_out, prev_err
            if log_file is not None and not log_file.closed:
                log_file.close()
            # getsignal() may have returned None, which signal.signal() rejects.
            if prev_sighup is not None:
                signal.signal(signal.SIGHUP, prev_sighup)

    def test_skipped_install_leaves_stdio_alone(self):
        """When install failed (state['installed']=False) finalize should not
        touch sys.stdout / sys.stderr (they were never wrapped)."""
        # Build a synthetic state that mimics a failed install.
        sentinel_out = object()
        state = {
            "prev_stdout": sentinel_out,
            "prev_stderr": sentinel_out,
            "log_file": None,
            "installed": False,
        }
        before_out, before_err = sys.stdout, sys.stderr

        _finalize_update_output(state)

        assert sys.stdout is before_out
        assert sys.stderr is before_err
|
||||
|
|
@ -460,10 +460,3 @@ class TestPrefetchCacheAccessors:
|
|||
assert mgr.pop_context_result("cli:test") == payload
|
||||
assert mgr.pop_context_result("cli:test") == {}
|
||||
|
||||
def test_set_and_pop_dialectic_result(self):
|
||||
mgr = _make_manager(write_frequency="turn")
|
||||
|
||||
mgr.set_dialectic_result("cli:test", "Resume with toolset cleanup")
|
||||
|
||||
assert mgr.pop_dialectic_result("cli:test") == "Resume with toolset cleanup"
|
||||
assert mgr.pop_dialectic_result("cli:test") == ""
|
||||
|
|
|
|||
|
|
@ -26,6 +26,9 @@ class TestCmdStatus:
|
|||
write_frequency = "async"
|
||||
session_strategy = "per-session"
|
||||
context_tokens = 800
|
||||
dialectic_reasoning_level = "low"
|
||||
reasoning_level_cap = "high"
|
||||
reasoning_heuristic = True
|
||||
|
||||
def resolve_session_name(self):
    """Return the fixed session name used by this stub config."""
    session_name = "hermes"
    return session_name
|
||||
|
|
|
|||
|
|
@ -568,15 +568,15 @@ class TestToolsModeInitBehavior:
|
|||
|
||||
with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
|
||||
patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
|
||||
patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=mock_manager), \
|
||||
patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=mock_manager) as mock_manager_cls, \
|
||||
patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
|
||||
provider.initialize(session_id="test-session-001", **init_kwargs)
|
||||
|
||||
return provider, cfg
|
||||
return provider, cfg, mock_manager_cls
|
||||
|
||||
def test_tools_lazy_default(self):
    """tools + initOnSessionStart=false → session NOT initialized after initialize()."""
    # The helper returns (provider, cfg, manager class mock); only the
    # provider matters here.
    provider, _, _ = self._make_provider_with_config(
        recall_mode="tools", init_on_session_start=False,
    )
    assert provider._session_initialized is False
|
||||
|
|
@ -585,7 +585,7 @@ class TestToolsModeInitBehavior:
|
|||
|
||||
def test_tools_eager_init(self):
    """tools + initOnSessionStart=true → session IS initialized after initialize()."""
    # The helper returns (provider, cfg, manager class mock); only the
    # provider matters here.
    provider, _, _ = self._make_provider_with_config(
        recall_mode="tools", init_on_session_start=True,
    )
    assert provider._session_initialized is True
|
||||
|
|
@ -593,33 +593,34 @@ class TestToolsModeInitBehavior:
|
|||
|
||||
def test_tools_eager_prefetch_still_empty(self):
    """tools mode with eager init still returns empty from prefetch() (no auto-injection)."""
    # The helper returns (provider, cfg, manager class mock); only the
    # provider matters here.
    provider, _, _ = self._make_provider_with_config(
        recall_mode="tools", init_on_session_start=True,
    )
    assert provider.prefetch("test query") == ""
|
||||
|
||||
def test_tools_lazy_prefetch_empty(self):
    """tools mode with lazy init also returns empty from prefetch()."""
    # The helper returns (provider, cfg, manager class mock); only the
    # provider matters here.
    provider, _, _ = self._make_provider_with_config(
        recall_mode="tools", init_on_session_start=False,
    )
    assert provider.prefetch("test query") == ""
|
||||
|
||||
def test_explicit_peer_name_not_overridden_by_user_id(self):
    """Explicit peerName in config must not be replaced by gateway user_id."""
    # The helper returns (provider, cfg, manager class mock); only the
    # config is inspected here.
    _, cfg, _ = self._make_provider_with_config(
        recall_mode="tools", init_on_session_start=True,
        peer_name="Kathie", user_id="8439114563",
    )
    assert cfg.peer_name == "Kathie"
|
||||
|
||||
def test_user_id_used_when_no_peer_name(self):
    """Gateway user_id is passed separately from config peer_name."""
    _, cfg, mock_manager_cls = self._make_provider_with_config(
        recall_mode="tools", init_on_session_start=True,
        peer_name=None, user_id="8439114563",
    )
    # The config itself is left untouched ...
    assert cfg.peer_name is None
    # ... and the user id reaches the session manager as its own keyword.
    assert mock_manager_cls.call_args.kwargs["runtime_user_peer_name"] == "8439114563"
|
||||
|
||||
|
||||
class TestPerSessionMigrateGuard:
|
||||
|
|
@ -815,6 +816,27 @@ class TestDialecticInputGuard:
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _settle_prewarm(provider):
    """Let the session-start prewarm finish, then wipe its traces.

    Joins any prefetch thread, clears the pending result and the dialectic
    bookkeeping, and resets the manager mocks' call records, so tests for
    cadence, first-turn and trivial-prompt behaviour start from a
    'nothing has fired yet' baseline.
    """
    prewarm_thread = provider._prefetch_thread
    if prewarm_thread:
        prewarm_thread.join(timeout=3.0)

    with provider._prefetch_lock:
        provider._prefetch_result = ""
        provider._prefetch_result_fired_at = -999

    provider._prefetch_thread = None
    provider._prefetch_thread_started_at = 0.0
    provider._last_dialectic_turn = -999
    provider._dialectic_empty_streak = 0

    if getattr(provider, "_manager", None) is None:
        return

    try:
        provider._manager.dialectic_query.reset_mock()
        provider._manager.prefetch_context.reset_mock()
    except AttributeError:
        # Best effort: a manager without mock attributes has nothing to reset.
        pass
|
||||
|
||||
|
||||
class TestDialecticCadenceDefaults:
|
||||
"""Regression tests for dialectic_cadence default value."""
|
||||
|
||||
|
|
@ -840,12 +862,15 @@ class TestDialecticCadenceDefaults:
|
|||
patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
|
||||
provider.initialize(session_id="test-session-001")
|
||||
|
||||
_settle_prewarm(provider)
|
||||
return provider
|
||||
|
||||
def test_unset_falls_back_to_1(self):
    """Unset dialecticCadence falls back to 1 (every turn) for backwards
    compatibility with existing configs that predate the setting. The
    setup wizard writes 2 explicitly on new configs."""
    provider = self._make_provider()
    assert provider._dialectic_cadence == 1
|
||||
|
||||
def test_config_override(self):
|
||||
"""dialecticCadence from config overrides the default."""
|
||||
|
|
@ -908,6 +933,7 @@ class TestDialecticDepth:
|
|||
patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
|
||||
provider.initialize(session_id="test-session-001")
|
||||
|
||||
_settle_prewarm(provider)
|
||||
return provider
|
||||
|
||||
def test_default_depth_is_1(self):
|
||||
|
|
@ -1027,46 +1053,6 @@ class TestDialecticDepth:
|
|||
assert provider._manager.dialectic_query.call_count == 2
|
||||
assert "Synthesis" in result
|
||||
|
||||
def test_first_turn_runs_dialectic_synchronously(self):
|
||||
"""First turn should fire the dialectic synchronously (cold start)."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
provider = self._make_provider(cfg_extra={"dialectic_depth": 1})
|
||||
provider._manager = MagicMock()
|
||||
provider._manager.dialectic_query.return_value = "cold start synthesis"
|
||||
provider._manager.get_prefetch_context.return_value = None
|
||||
provider._manager.pop_context_result.return_value = None
|
||||
provider._session_key = "test"
|
||||
provider._base_context_cache = "" # cold start
|
||||
provider._last_dialectic_turn = -999 # never fired
|
||||
|
||||
result = provider.prefetch("hello world")
|
||||
assert "cold start synthesis" in result
|
||||
assert provider._manager.dialectic_query.call_count == 1
|
||||
# After first-turn sync, _last_dialectic_turn should be updated
|
||||
assert provider._last_dialectic_turn != -999
|
||||
|
||||
def test_first_turn_dialectic_does_not_double_fire(self):
|
||||
"""After first-turn sync dialectic, queue_prefetch should skip (cadence)."""
|
||||
from unittest.mock import MagicMock
|
||||
provider = self._make_provider(cfg_extra={"dialectic_depth": 1})
|
||||
provider._manager = MagicMock()
|
||||
provider._manager.dialectic_query.return_value = "cold start synthesis"
|
||||
provider._manager.get_prefetch_context.return_value = None
|
||||
provider._manager.pop_context_result.return_value = None
|
||||
provider._session_key = "test"
|
||||
provider._base_context_cache = ""
|
||||
provider._last_dialectic_turn = -999
|
||||
provider._turn_count = 0
|
||||
|
||||
# First turn fires sync dialectic
|
||||
provider.prefetch("hello")
|
||||
assert provider._manager.dialectic_query.call_count == 1
|
||||
|
||||
# Now queue_prefetch on same turn should skip (cadence: 0 - 0 < 3)
|
||||
provider._manager.dialectic_query.reset_mock()
|
||||
provider.queue_prefetch("hello")
|
||||
assert provider._manager.dialectic_query.call_count == 0
|
||||
|
||||
def test_run_dialectic_depth_bails_early_on_strong_signal(self):
|
||||
"""Depth 2 skips pass 1 when pass 0 returns strong signal."""
|
||||
from unittest.mock import MagicMock
|
||||
|
|
@ -1083,6 +1069,584 @@ class TestDialecticDepth:
|
|||
assert provider._manager.dialectic_query.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trivial-prompt heuristic + dialectic cadence silent-failure guards
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTrivialPromptHeuristic:
    """Trivial prompts ('ok', 'y', slash commands) must short-circuit injection."""

    @staticmethod
    def _make_provider():
        """Build an initialized hybrid-mode provider backed by a mock manager."""
        from unittest.mock import patch, MagicMock
        from plugins.memory.honcho.client import HonchoClientConfig

        cfg = HonchoClientConfig(api_key="test-key", enabled=True, recall_mode="hybrid")
        provider = HonchoMemoryProvider()

        fake_session = MagicMock()
        fake_session.messages = []
        fake_manager = MagicMock()
        fake_manager.get_or_create.return_value = fake_session

        with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
             patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
             patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=fake_manager), \
             patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
            provider.initialize(session_id="test-session-trivial")
            _settle_prewarm(provider)
            return provider

    def test_classifier_catches_common_trivial_forms(self):
        """Acknowledgements, slash commands and blank input classify as trivial."""
        trivial_inputs = (
            "ok", "OK", " ok ", "y", "yes", "sure", "thanks", "lgtm", "/help", "", " ",
        )
        for t in trivial_inputs:
            assert HonchoMemoryProvider._is_trivial_prompt(t), f"expected trivial: {t!r}"

    def test_classifier_lets_substantive_prompts_through(self):
        """Prompts carrying real content are not classified as trivial."""
        substantive_inputs = (
            "hello world", "what's my name", "explain this", "ok so what's next",
        )
        for t in substantive_inputs:
            assert not HonchoMemoryProvider._is_trivial_prompt(t), f"expected non-trivial: {t!r}"

    def test_prefetch_skips_on_trivial_prompt(self):
        """prefetch() returns '' for trivial prompts and fires no dialectic."""
        provider = self._make_provider()
        provider._session_key = "test"
        provider._base_context_cache = "cached base"
        provider._last_dialectic_turn = 0
        provider._turn_count = 5

        for trivial in ("ok", "/help"):
            assert provider.prefetch(trivial) == ""

        # Dialectic should not have fired
        assert provider._manager.dialectic_query.call_count == 0

    def test_queue_prefetch_skips_on_trivial_prompt(self):
        """queue_prefetch() neither refreshes context nor fires a dialectic."""
        provider = self._make_provider()
        provider._session_key = "test"
        provider._turn_count = 10
        provider._last_dialectic_turn = -999  # would otherwise fire

        # initialize() pre-warms; clear call counts before the assertion.
        manager = provider._manager
        manager.prefetch_context.reset_mock()
        manager.dialectic_query.reset_mock()

        provider.queue_prefetch("y")

        # Trivial prompts short-circuit both context refresh and dialectic fire.
        assert provider._manager.prefetch_context.call_count == 0
        assert provider._manager.dialectic_query.call_count == 0
|
||||
|
||||
|
||||
class TestDialecticCadenceAdvancesOnSuccess:
    """Cadence tracker advances only when the dialectic call returns a
    non-empty result. Empty results (transient API error, sparse representation)
    must retry on the next eligible turn instead of waiting the full cadence."""

    @staticmethod
    def _make_provider():
        """Build an initialized hybrid-mode, depth-1 provider with a mock manager."""
        from unittest.mock import patch, MagicMock
        from plugins.memory.honcho.client import HonchoClientConfig

        cfg = HonchoClientConfig(
            api_key="test-key", enabled=True, recall_mode="hybrid", dialectic_depth=1,
        )
        provider = HonchoMemoryProvider()
        mock_manager = MagicMock()
        mock_session = MagicMock()
        mock_session.messages = []
        mock_manager.get_or_create.return_value = mock_session

        with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
             patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
             patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=mock_manager), \
             patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
            provider.initialize(session_id="test-session-retry")
            _settle_prewarm(provider)
            return provider

    def test_empty_dialectic_result_does_not_advance_cadence(self):
        """An empty dialectic result leaves ``_last_dialectic_turn`` untouched."""
        provider = self._make_provider()
        provider._session_key = "test"
        provider._manager.dialectic_query.return_value = ""  # silent failure
        provider._turn_count = 5
        provider._last_dialectic_turn = 0  # 5 turns since the last fire, so it fires

        provider.queue_prefetch("hello")
        # wait for the background thread to settle
        if provider._prefetch_thread:
            provider._prefetch_thread.join(timeout=2.0)

        # Dialectic call was attempted
        assert provider._manager.dialectic_query.call_count == 1
        # But cadence tracker did NOT advance — next turn should retry
        assert provider._last_dialectic_turn == 0

    def test_non_empty_dialectic_result_advances_cadence(self):
        """A non-empty result moves ``_last_dialectic_turn`` to the current turn."""
        provider = self._make_provider()
        provider._session_key = "test"
        provider._manager.dialectic_query.return_value = "real synthesis output"
        provider._turn_count = 5
        provider._last_dialectic_turn = 0

        provider.queue_prefetch("hello")
        if provider._prefetch_thread:
            provider._prefetch_thread.join(timeout=2.0)

        assert provider._last_dialectic_turn == 5

    def test_in_flight_thread_is_not_stacked(self):
        """queue_prefetch() does not fire while a fresh prefetch thread is alive."""
        import threading as _threading
        import time as _time

        provider = self._make_provider()
        provider._session_key = "test"
        provider._turn_count = 10
        provider._last_dialectic_turn = 0

        # Simulate a prior thread still running (fresh, not stale)
        hold = _threading.Event()

        def _block():
            hold.wait(timeout=5.0)

        fresh = _threading.Thread(target=_block, daemon=True)
        fresh.start()
        provider._prefetch_thread = fresh
        provider._prefetch_thread_started_at = _time.monotonic()  # fresh start

        try:
            provider.queue_prefetch("hello")
            # Should have short-circuited — no new dialectic call
            assert provider._manager.dialectic_query.call_count == 0
        finally:
            # Release the blocker even when the assertion fails, so the
            # helper thread never lingers for its full wait timeout.
            hold.set()
            fresh.join(timeout=2.0)
|
||||
|
||||
|
||||
class TestSessionStartDialecticPrewarm:
    """Session-start prewarm fires a depth-aware dialectic whose result is
    consumed by turn 1 — no duplicate .chat() and no dead-cache orphaning."""

    @staticmethod
    def _make_provider(cfg_extra=None, dialectic_result="prewarm synthesis"):
        """Initialize a provider whose mock manager answers every dialectic
        query with *dialectic_result*; the prewarm is NOT settled here."""
        from unittest.mock import patch, MagicMock
        from plugins.memory.honcho.client import HonchoClientConfig

        settings = dict(api_key="test-key", enabled=True, recall_mode="hybrid")
        if cfg_extra:
            settings.update(cfg_extra)
        cfg = HonchoClientConfig(**settings)
        provider = HonchoMemoryProvider()

        fake_manager = MagicMock()
        fake_manager.get_or_create.return_value = MagicMock(messages=[])
        fake_manager.get_prefetch_context.return_value = None
        fake_manager.pop_context_result.return_value = None
        fake_manager.dialectic_query.return_value = dialectic_result

        with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
             patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
             patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=fake_manager), \
             patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
            provider.initialize(session_id="test-prewarm")
            return provider

    def test_prewarm_populates_prefetch_result(self):
        """The prewarm's output lands in ``_prefetch_result`` at turn 0."""
        provider = self._make_provider()

        # Wait for prewarm thread to land
        if provider._prefetch_thread:
            provider._prefetch_thread.join(timeout=3.0)

        with provider._prefetch_lock:
            assert provider._prefetch_result == "prewarm synthesis"
        assert provider._last_dialectic_turn == 0

    def test_turn1_consumes_prewarm_without_duplicate_dialectic(self):
        """With prewarm result already in _prefetch_result, turn 1 prefetch
        should NOT fire another dialectic."""
        provider = self._make_provider()
        if provider._prefetch_thread:
            provider._prefetch_thread.join(timeout=3.0)

        provider._manager.dialectic_query.reset_mock()
        provider._session_key = "test-prewarm"
        provider._base_context_cache = ""
        provider._turn_count = 1

        injected = provider.prefetch("hello world")

        assert "prewarm synthesis" in injected
        # The sync first-turn path must NOT have fired another .chat()
        assert provider._manager.dialectic_query.call_count == 0

    def test_turn1_falls_back_to_sync_when_prewarm_missing(self):
        """If the prewarm produced nothing (empty graph, API blip), turn 1
        still fires its own sync dialectic."""
        provider = self._make_provider(dialectic_result="")  # prewarm returns empty
        if provider._prefetch_thread:
            provider._prefetch_thread.join(timeout=3.0)

        with provider._prefetch_lock:
            assert provider._prefetch_result == ""  # prewarm landed nothing

        # Switch dialectic_query to return something on the sync first-turn call
        query_mock = provider._manager.dialectic_query
        query_mock.return_value = "sync recovery"
        query_mock.reset_mock()
        provider._session_key = "test-prewarm"
        provider._base_context_cache = ""
        provider._turn_count = 1

        injected = provider.prefetch("hello world")

        assert "sync recovery" in injected
        assert provider._manager.dialectic_query.call_count == 1
|
||||
|
||||
|
||||
class TestDialecticLiveness:
    """Liveness + observability: stale-thread recovery, stale-result discard,
    empty-streak backoff, and the snapshot method used for diagnostics."""

    @staticmethod
    def _make_provider(cfg_extra=None):
        """Build an initialized, settled provider (timeout=2.0) whose mock
        manager returns an empty dialectic result by default."""
        from unittest.mock import patch, MagicMock
        from plugins.memory.honcho.client import HonchoClientConfig

        defaults = dict(api_key="test-key", enabled=True, recall_mode="hybrid", timeout=2.0)
        if cfg_extra:
            defaults.update(cfg_extra)
        cfg = HonchoClientConfig(**defaults)
        provider = HonchoMemoryProvider()
        mock_manager = MagicMock()
        mock_manager.get_or_create.return_value = MagicMock(messages=[])
        mock_manager.get_prefetch_context.return_value = None
        mock_manager.pop_context_result.return_value = None
        mock_manager.dialectic_query.return_value = ""  # default: silent

        with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
             patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
             patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=mock_manager), \
             patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
            provider.initialize(session_id="test-liveness")
            _settle_prewarm(provider)
            return provider

    def test_stale_thread_is_treated_as_dead(self):
        """A thread older than timeout × multiplier no longer blocks new fires."""
        import threading as _threading

        p = self._make_provider()
        p._session_key = "test"
        p._turn_count = 10
        p._last_dialectic_turn = 0
        p._manager.dialectic_query.return_value = "fresh synthesis"

        # Plant an alive thread with an old timestamp (stale)
        hold = _threading.Event()
        stuck = _threading.Thread(target=lambda: hold.wait(timeout=10.0), daemon=True)
        stuck.start()
        p._prefetch_thread = stuck
        # timeout=2.0, multiplier=2.0, so anything older than 4s is stale.
        # 0.0 stands in for a start time far in the past (presumably compared
        # against time.monotonic() by the provider — confirm).
        p._prefetch_thread_started_at = 0.0

        try:
            p.queue_prefetch("hello")
            # New thread should have been spawned since stuck one is stale
            assert p._prefetch_thread is not stuck, "stale thread must be recycled"
            if p._prefetch_thread:
                p._prefetch_thread.join(timeout=2.0)
            assert p._manager.dialectic_query.call_count == 1
        finally:
            # Always release the planted thread, even on assertion failure,
            # so it does not sit in its 10s wait.
            hold.set()
            stuck.join(timeout=2.0)

    def test_stale_pending_result_is_discarded_on_read(self):
        """A pending dialectic result from many turns ago is discarded
        instead of injected against a fresh conversational pivot."""
        p = self._make_provider(cfg_extra={"raw": {"dialecticCadence": 2}})
        p._session_key = "test"
        p._base_context_cache = "base ctx"
        with p._prefetch_lock:
            p._prefetch_result = "ancient synthesis"
            p._prefetch_result_fired_at = 1
        # cadence=2, multiplier=2 → stale after 4 turns since fire
        p._turn_count = 10
        p._last_dialectic_turn = 1  # prevents sync first-turn path

        result = p.prefetch("what's new")

        assert "ancient synthesis" not in result, "stale pending must be discarded"
        # Cache slot cleared
        with p._prefetch_lock:
            assert p._prefetch_result == ""
            assert p._prefetch_result_fired_at == -999

    def test_fresh_pending_result_is_kept(self):
        """A pending result within the staleness window is injected normally."""
        p = self._make_provider(cfg_extra={"raw": {"dialecticCadence": 3}})
        p._session_key = "test"
        p._base_context_cache = ""
        with p._prefetch_lock:
            p._prefetch_result = "recent synthesis"
            p._prefetch_result_fired_at = 8
        p._turn_count = 9  # 1 turn since fire, well within cadence × 2 = 6
        p._last_dialectic_turn = 8

        result = p.prefetch("what's new")

        assert "recent synthesis" in result

    def test_empty_streak_widens_effective_cadence(self):
        """After N empty returns, the gate waits cadence + N turns."""
        p = self._make_provider(cfg_extra={"raw": {"dialecticCadence": 1}})
        p._dialectic_empty_streak = 3
        # cadence=1, streak=3 → effective = 4
        assert p._effective_cadence() == 4

    def test_backoff_is_capped(self):
        """Effective cadence is capped at cadence × _BACKOFF_MAX."""
        p = self._make_provider(cfg_extra={"raw": {"dialecticCadence": 2}})
        p._dialectic_empty_streak = 100
        # cadence=2, ceiling = 2 × 8 = 16
        assert p._effective_cadence() == 16

    def test_success_resets_empty_streak(self):
        """A non-empty result zeroes the streak so healthy operation restores
        the base cadence immediately."""
        p = self._make_provider(cfg_extra={"raw": {"dialecticCadence": 1}})
        p._session_key = "test"
        p._dialectic_empty_streak = 5
        p._turn_count = 10
        p._last_dialectic_turn = 0
        p._manager.dialectic_query.return_value = "real output"

        p.queue_prefetch("hello")
        if p._prefetch_thread:
            p._prefetch_thread.join(timeout=2.0)

        assert p._dialectic_empty_streak == 0
        assert p._last_dialectic_turn == 10

    def test_empty_result_increments_streak(self):
        """An empty result bumps the streak and leaves the cadence tracker alone."""
        p = self._make_provider(cfg_extra={"raw": {"dialecticCadence": 1}})
        p._session_key = "test"
        p._turn_count = 5
        p._last_dialectic_turn = 0
        p._manager.dialectic_query.return_value = ""  # empty

        p.queue_prefetch("hello")
        if p._prefetch_thread:
            p._prefetch_thread.join(timeout=2.0)

        assert p._dialectic_empty_streak == 1
        assert p._last_dialectic_turn == 0  # cadence not advanced

    def test_liveness_snapshot_shape(self):
        """liveness_snapshot() exposes every diagnostic key."""
        p = self._make_provider()
        snap = p.liveness_snapshot()
        for key in (
            "turn_count", "last_dialectic_turn", "pending_result_fired_at",
            "empty_streak", "effective_cadence", "thread_alive", "thread_age_seconds",
        ):
            assert key in snap
|
||||
|
||||
|
||||
class TestDialecticLifecycleSmoke:
|
||||
"""End-to-end smoke walking a multi-turn session through prewarm,
|
||||
turn 1 consume, trivial skip, cadence fire, empty-result retry,
|
||||
heuristic bump, and session-end flush."""
|
||||
|
||||
@staticmethod
def _make_provider(cfg_extra=None):
    """Assemble an un-initialized provider, its mock manager and its config.

    Returns a ``(provider, mock_manager, cfg)`` tuple. ``initialize()`` is
    not called here; the patches below are only active until the return.
    """
    from unittest.mock import patch, MagicMock
    from plugins.memory.honcho.client import HonchoClientConfig

    settings = dict(
        api_key="test-key", enabled=True, recall_mode="hybrid",
        dialectic_reasoning_level="low", reasoning_heuristic=True,
        reasoning_level_cap="high", dialectic_depth=1,
    )
    if cfg_extra:
        settings.update(cfg_extra)
    cfg = HonchoClientConfig(**settings)
    provider = HonchoMemoryProvider()

    fake_session = MagicMock()
    fake_session.messages = []
    fake_manager = MagicMock()
    fake_manager.get_or_create.return_value = fake_session
    fake_manager.get_prefetch_context.return_value = None
    fake_manager.pop_context_result.return_value = None

    with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
         patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
         patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=fake_manager), \
         patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
        return provider, fake_manager, cfg
|
||||
|
||||
def _await_thread(self, provider):
|
||||
if provider._prefetch_thread:
|
||||
provider._prefetch_thread.join(timeout=3.0)
|
||||
|
||||
def test_full_multi_turn_session(self):
|
||||
"""Walks init → turns 1..8 → session end. Asserts at every step that
|
||||
the plugin did exactly what it should and nothing more.
|
||||
|
||||
Uses dialecticCadence=3 so we can exercise skip-turns between fires
|
||||
and the silent-failure retry path without their gates tripping each
|
||||
other. Trivial + slash skips apply independent of cadence.
|
||||
"""
|
||||
from unittest.mock import patch, MagicMock
|
||||
provider, mgr, cfg = self._make_provider(
|
||||
cfg_extra={"raw": {"dialecticCadence": 3}}
|
||||
)
|
||||
|
||||
# Program the dialectic responses in the exact order they'll be requested.
|
||||
# An extra or missing call fails the test — strong smoke signal.
|
||||
responses = iter([
|
||||
"prewarm: user is eri, works on hermes", # session-start prewarm
|
||||
"cadence fire: long query synthesis", # turn 4 queue_prefetch
|
||||
"", # turn 7 fire: silent failure
|
||||
"retry success: fresh synthesis", # turn 8 queue_prefetch retry
|
||||
])
|
||||
mgr.dialectic_query.side_effect = lambda *a, **kw: next(responses)
|
||||
|
||||
# ---- init: prewarm fires ----
|
||||
with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=cfg), \
|
||||
patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
|
||||
patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=mgr), \
|
||||
patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
|
||||
provider.initialize(session_id="smoke-test")
|
||||
|
||||
self._await_thread(provider)
|
||||
with provider._prefetch_lock:
|
||||
assert provider._prefetch_result.startswith("prewarm"), \
|
||||
"session-start prewarm must land in _prefetch_result"
|
||||
assert provider._last_dialectic_turn == 0, "prewarm marks turn 0"
|
||||
assert mgr.dialectic_query.call_count == 1
|
||||
|
||||
# ---- turn 1: consume prewarm, no duplicate dialectic ----
|
||||
provider.on_turn_start(1, "hey")
|
||||
inject1 = provider.prefetch("hey")
|
||||
assert "prewarm" in inject1, "turn 1 must surface prewarm"
|
||||
provider.sync_turn("hey", "hi there")
|
||||
provider.queue_prefetch("hey") # cadence gate: (1-0)<3 → skip
|
||||
self._await_thread(provider)
|
||||
assert mgr.dialectic_query.call_count == 1, \
|
||||
"turn 1 must not fire — prewarm covered it and cadence skips"
|
||||
|
||||
# ---- turn 2: trivial 'ok' → skip everything ----
|
||||
mgr.prefetch_context.reset_mock()
|
||||
provider.on_turn_start(2, "ok")
|
||||
assert provider.prefetch("ok") == "", "trivial prompt must short-circuit injection"
|
||||
provider.sync_turn("ok", "cool")
|
||||
provider.queue_prefetch("ok")
|
||||
self._await_thread(provider)
|
||||
assert mgr.dialectic_query.call_count == 1, "trivial must not fire dialectic"
|
||||
assert mgr.prefetch_context.call_count == 0, "trivial must not fire context refresh"
|
||||
|
||||
# ---- turn 3: slash '/help' → also skip ----
|
||||
provider.on_turn_start(3, "/help")
|
||||
assert provider.prefetch("/help") == ""
|
||||
provider.queue_prefetch("/help")
|
||||
assert mgr.dialectic_query.call_count == 1
|
||||
|
||||
# ---- turn 4: long query → cadence fires + heuristic bumps ----
|
||||
long_q = "walk me through " + ("x " * 100) # ~200 chars → heuristic +1
|
||||
provider.on_turn_start(4, long_q)
|
||||
provider.prefetch(long_q)
|
||||
provider.sync_turn(long_q, "sure")
|
||||
provider.queue_prefetch(long_q) # (4-0)≥3 → fires
|
||||
self._await_thread(provider)
|
||||
assert mgr.dialectic_query.call_count == 2, "turn 4 cadence fire"
|
||||
_, kwargs = mgr.dialectic_query.call_args
|
||||
assert kwargs.get("reasoning_level") in ("medium", "high"), \
|
||||
f"long query must bump reasoning level above 'low'; got {kwargs.get('reasoning_level')}"
|
||||
assert provider._last_dialectic_turn == 4, "cadence tracker advances on success"
|
||||
|
||||
# ---- turns 5–6: cadence cooldown, no fires ----
|
||||
for t in (5, 6):
|
||||
provider.on_turn_start(t, "tell me more")
|
||||
provider.queue_prefetch("tell me more")
|
||||
self._await_thread(provider)
|
||||
assert mgr.dialectic_query.call_count == 2, "turns 5–6 blocked by cadence window"
|
||||
|
||||
# ---- turn 7: fires but silent failure (empty dialectic) ----
|
||||
provider.on_turn_start(7, "and then what")
|
||||
provider.queue_prefetch("and then what") # (7-4)≥3 → fires
|
||||
self._await_thread(provider)
|
||||
assert mgr.dialectic_query.call_count == 3, "turn 7 fires"
|
||||
assert provider._last_dialectic_turn == 4, \
|
||||
"silent failure must NOT burn the cadence window"
|
||||
|
||||
# ---- turn 8: retries because cadence didn't advance ----
|
||||
provider.on_turn_start(8, "try again")
|
||||
provider.queue_prefetch("try again") # (8-4)≥3 → fires again
|
||||
self._await_thread(provider)
|
||||
assert mgr.dialectic_query.call_count == 4, \
|
||||
"turn 8 retries because turn 7's empty result didn't advance cadence"
|
||||
assert provider._last_dialectic_turn == 8, "retry success advances"
|
||||
|
||||
# ---- session end: flush messages ----
|
||||
provider.on_session_end([])
|
||||
mgr.flush_all.assert_called()
|
||||
|
||||
|
||||
class TestReasoningHeuristic:
    """Query-length heuristic for the auto-injected reasoning level.

    Longer queries raise the level above the configured base, but never
    beyond ``reasoning_level_cap``.
    """

    @staticmethod
    def _make_provider(cfg_extra=None):
        """Return an initialized provider whose session-start prewarm has settled.

        ``cfg_extra`` overrides the default HonchoClientConfig keyword
        arguments when given.
        """
        from unittest.mock import patch, MagicMock
        from plugins.memory.honcho.client import HonchoClientConfig

        config_kwargs = {
            "api_key": "test-key",
            "enabled": True,
            "recall_mode": "hybrid",
            "dialectic_reasoning_level": "low",
            "reasoning_heuristic": True,
            "reasoning_level_cap": "high",
        }
        if cfg_extra:
            config_kwargs.update(cfg_extra)
        honcho_cfg = HonchoClientConfig(**config_kwargs)

        heuristic_provider = HonchoMemoryProvider()
        manager = MagicMock()
        manager.get_or_create.return_value = MagicMock(messages=[])

        with patch("plugins.memory.honcho.client.HonchoClientConfig.from_global_config", return_value=honcho_cfg), \
             patch("plugins.memory.honcho.client.get_honcho_client", return_value=MagicMock()), \
             patch("plugins.memory.honcho.session.HonchoSessionManager", return_value=manager), \
             patch("hermes_constants.get_hermes_home", return_value=MagicMock()):
            heuristic_provider.initialize(session_id="test-heuristic")
            _settle_prewarm(heuristic_provider)
        return heuristic_provider

    def test_short_query_stays_at_base(self):
        """A three-character query leaves the base level untouched."""
        provider = self._make_provider()
        level = provider._apply_reasoning_heuristic("low", "hey")
        assert level == "low"

    def test_medium_query_bumps_one_level(self):
        """A 150-character query lifts 'low' to 'medium'."""
        provider = self._make_provider()
        query = "x" * 150
        level = provider._apply_reasoning_heuristic("low", query)
        assert level == "medium"

    def test_long_query_bumps_two_levels(self):
        """A 500-character query lifts 'low' to 'high'."""
        provider = self._make_provider()
        query = "x" * 500
        level = provider._apply_reasoning_heuristic("low", query)
        assert level == "high"

    def test_bump_respects_cap(self):
        """With the cap at 'medium', a long query cannot climb past it."""
        provider = self._make_provider(cfg_extra={"reasoning_level_cap": "medium"})
        query = "x" * 500  # would hit 'high' without the cap
        level = provider._apply_reasoning_heuristic("low", query)
        assert level == "medium"

    def test_max_never_auto_selected_with_default_cap(self):
        """Base 'high' plus a long query stays 'high' under the default cap."""
        provider = self._make_provider(cfg_extra={"dialectic_reasoning_level": "high"})
        query = "x" * 500  # base=high, bump would push to 'max'
        level = provider._apply_reasoning_heuristic("high", query)
        assert level == "high"

    def test_heuristic_disabled_returns_base(self):
        """Turning the heuristic off returns the base level unchanged."""
        provider = self._make_provider(cfg_extra={"reasoning_heuristic": False})
        query = "x" * 500
        level = provider._apply_reasoning_heuristic("low", query)
        assert level == "low"

    def test_resolve_pass_level_applies_heuristic_at_base_mapping(self):
        """Depth=1, pass 0 maps to 'base' → heuristic applies."""
        provider = self._make_provider()
        query = "x" * 150
        resolved = provider._resolve_pass_level(0, query=query)
        assert resolved == "medium"

    def test_resolve_pass_level_does_not_touch_explicit_per_pass(self):
        """dialecticDepthLevels wins absolutely — no heuristic scaling."""
        provider = self._make_provider(cfg_extra={"dialectic_depth_levels": ["minimal"]})
        query = "x" * 500  # heuristic would otherwise bump to 'high'
        resolved = provider._resolve_pass_level(0, query=query)
        assert resolved == "minimal"

    def test_resolve_pass_level_does_not_touch_lighter_passes(self):
        """Depth 3 pass 0 is hardcoded 'minimal' — heuristic must not bump it."""
        provider = self._make_provider(cfg_extra={"dialectic_depth": 3})
        query = "x" * 500
        first_pass = provider._resolve_pass_level(0, query=query)
        assert first_pass == "minimal"
        # But the 'base' pass (idx 1 for depth 3) does get heuristic
        base_pass = provider._resolve_pass_level(1, query=query)
        assert base_pass == "high"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_peer_card None guard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -23,6 +23,10 @@ def _make_agent(monkeypatch):
|
|||
|
||||
class _Stub:
|
||||
_interrupt_requested = False
|
||||
_interrupt_message = None
|
||||
# Bind to this thread's ident so interrupt() targets a real tid.
|
||||
_execution_thread_id = threading.current_thread().ident
|
||||
_interrupt_thread_signal_pending = False
|
||||
log_prefix = ""
|
||||
quiet_mode = True
|
||||
verbose_logging = False
|
||||
|
|
@ -40,6 +44,15 @@ def _make_agent(monkeypatch):
|
|||
_current_tool = None
|
||||
_last_activity = 0
|
||||
_print_fn = print
|
||||
# Worker-thread tracking state mirrored from AIAgent.__init__ so the
|
||||
# real interrupt() method can fan out to concurrent-tool workers.
|
||||
_active_children: list = []
|
||||
|
||||
def __init__(self):
|
||||
# Instance-level (not class-level) so each test gets a fresh set.
|
||||
self._tool_worker_threads: set = set()
|
||||
self._tool_worker_threads_lock = threading.Lock()
|
||||
self._active_children_lock = threading.Lock()
|
||||
|
||||
def _touch_activity(self, desc):
|
||||
self._last_activity = time.time()
|
||||
|
|
@ -60,8 +73,10 @@ def _make_agent(monkeypatch):
|
|||
return False
|
||||
|
||||
stub = _Stub()
|
||||
# Bind the real methods
|
||||
# Bind the real methods under test
|
||||
stub._execute_tool_calls_concurrent = _ra.AIAgent._execute_tool_calls_concurrent.__get__(stub)
|
||||
stub.interrupt = _ra.AIAgent.interrupt.__get__(stub)
|
||||
stub.clear_interrupt = _ra.AIAgent.clear_interrupt.__get__(stub)
|
||||
stub._invoke_tool = MagicMock(side_effect=lambda *a, **kw: '{"ok": true}')
|
||||
return stub
|
||||
|
||||
|
|
@ -137,3 +152,109 @@ def test_concurrent_preflight_interrupt_skips_all(monkeypatch):
|
|||
assert "skipped due to user interrupt" in messages[1]["content"]
|
||||
# _invoke_tool should never have been called
|
||||
agent._invoke_tool.assert_not_called()
|
||||
|
||||
|
||||
def test_running_concurrent_worker_sees_is_interrupted(monkeypatch):
    """Regression guard for the "interrupt-doesn't-reach-hung-tool" class of
    bug Physikal reported in April 2026.

    Before this fix, `AIAgent.interrupt()` called `_set_interrupt(True,
    _execution_thread_id)` — which only flagged the agent's *main* thread.
    Tools running inside `_execute_tool_calls_concurrent` execute on
    ThreadPoolExecutor worker threads whose tids are NOT the agent's, so
    `is_interrupted()` (which checks the *current* thread's tid) returned
    False inside those tools no matter how many times the gateway called
    `.interrupt()`.  Hung ssh / long curl / big make-build tools would run
    to their own timeout.

    This test runs a fake tool in the concurrent path that polls
    `is_interrupted()` like a real terminal command does, then calls
    `agent.interrupt()` from another thread, and asserts the poll sees True
    within one second.

    The interrupt is cleared again in a ``finally`` so the flag raised by
    `agent.interrupt()` does not outlive this test.
    """
    from tools.interrupt import is_interrupted

    agent = _make_agent(monkeypatch)

    # Counter plus observation hooks so we can prove the worker saw the flip.
    observed = {"saw_true": False, "poll_count": 0, "worker_tid": None}
    worker_started = threading.Event()

    def polling_tool(name, args, task_id, call_id=None):
        """Fake tool: poll is_interrupted() every 50 ms for at most 5 s."""
        observed["worker_tid"] = threading.current_thread().ident
        worker_started.set()
        deadline = time.monotonic() + 5.0
        while time.monotonic() < deadline:
            observed["poll_count"] += 1
            if is_interrupted():
                observed["saw_true"] = True
                return '{"interrupted": true}'
            time.sleep(0.05)
        return '{"timed_out": true}'

    agent._invoke_tool = MagicMock(side_effect=polling_tool)

    tc1 = _FakeToolCall("hung_fake_tool_1", call_id="tc1")
    tc2 = _FakeToolCall("hung_fake_tool_2", call_id="tc2")
    msg = _FakeAssistantMsg([tc1, tc2])
    messages = []

    def _interrupt_after_start():
        # Wait until at least one worker is running so its tid is tracked.
        worker_started.wait(timeout=2.0)
        time.sleep(0.2)  # let the other worker enter too
        agent.interrupt("stop requested by test")

    t = threading.Thread(target=_interrupt_after_start)
    t.start()
    try:
        start = time.monotonic()
        agent._execute_tool_calls_concurrent(msg, messages, "test_task")
        elapsed = time.monotonic() - start
        t.join(timeout=2.0)

        # The worker must have actually polled is_interrupted — otherwise the
        # test isn't exercising what it claims to.
        assert observed["poll_count"] > 0, (
            "polling_tool never ran — test scaffold issue"
        )
        # The worker must see the interrupt within ~1 s of agent.interrupt()
        # being called.  Before the fix this loop ran until its 5 s own-timeout.
        assert observed["saw_true"], (
            f"is_interrupted() never returned True inside the concurrent worker "
            f"after agent.interrupt() — interrupt-propagation hole regressed. "
            f"worker_tid={observed['worker_tid']!r} poll_count={observed['poll_count']}"
        )
        assert elapsed < 3.0, (
            f"concurrent execution took {elapsed:.2f}s after interrupt — the fan-out "
            f"to worker tids didn't shortcut the tool's poll loop as expected"
        )
        # Also verify cleanup: no stale worker tids should remain after all
        # tools finished.
        assert agent._tool_worker_threads == set(), (
            f"worker tids leaked after run: {agent._tool_worker_threads}"
        )
    finally:
        # Make sure the interrupter has fired (or given up) before clearing,
        # even when the execution or an assertion above raised. Its worst
        # case is a 2.0 s wait plus a 0.2 s sleep.
        t.join(timeout=3.0)
        # NOTE(review): the stub's _execution_thread_id is presumably this
        # test thread, so interrupt() leaves this thread flagged — confirm.
        # Clearing here keeps later tests on the same thread unaffected.
        agent.clear_interrupt()
|
||||
|
||||
|
||||
def test_clear_interrupt_clears_worker_tids(monkeypatch):
    """After clear_interrupt(), stale worker-tid bits must be cleared so the
    next turn's tools — which may be scheduled onto recycled tids — don't
    see a false interrupt.

    The interrupt bit is set on the *current* thread, so it is reset in a
    ``finally`` to keep a failing assertion from leaving this thread flagged
    for whatever test runs next.
    """
    from tools.interrupt import is_interrupted, set_interrupt

    agent = _make_agent(monkeypatch)
    # Simulate a worker having registered but not yet exited cleanly (e.g. a
    # hypothetical bug in the tear-down).  Put a fake tid in the set and
    # flag it interrupted.
    # NOTE(review): this tid presumably equals the stub's
    # _execution_thread_id, so the final assertion may also pass through the
    # main-thread clearing path — confirm against clear_interrupt().
    fake_tid = threading.current_thread().ident  # use real tid so is_interrupted can see it
    with agent._tool_worker_threads_lock:
        agent._tool_worker_threads.add(fake_tid)
    set_interrupt(True, fake_tid)
    try:
        assert is_interrupted() is True  # sanity

        agent.clear_interrupt()

        assert is_interrupted() is False, (
            "clear_interrupt() did not clear the interrupt bit for a tracked "
            "worker tid — stale interrupt can leak into the next turn"
        )
    finally:
        # Best-effort reset; a no-op when clear_interrupt() already did it.
        set_interrupt(False, fake_tid)
|
||||
|
||||
|
|
|
|||
39
tests/run_agent/test_memory_provider_init.py
Normal file
39
tests/run_agent/test_memory_provider_init.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""Regression tests for memory provider selection during AIAgent init."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def test_blank_memory_provider_does_not_auto_enable_honcho():
    """An empty ``memory.provider`` keeps memory off.

    Even with a Honcho config that reports itself enabled and keyed, the
    agent must end up without a memory manager, and must not consult the
    Honcho global config, load a provider, or write the config file.
    """
    blank_provider_cfg = {"memory": {"provider": ""}, "agent": {}}
    stale_honcho = SimpleNamespace(enabled=True, api_key="stale-key", base_url=None)

    with (
        patch("hermes_cli.config.load_config", return_value=blank_provider_cfg),
        patch("hermes_cli.config.save_config") as mock_save,
        patch(
            "plugins.memory.honcho.client.HonchoClientConfig.from_global_config",
            return_value=stale_honcho,
        ) as mock_from_global,
        patch("plugins.memory.load_memory_provider") as mock_load_provider,
        patch("agent.model_metadata.get_model_context_length", return_value=204_800),
        patch("run_agent.get_tool_definitions", return_value=[]),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
    ):
        from run_agent import AIAgent

        built_agent = AIAgent(
            api_key="test-key-1234567890",
            base_url="https://openrouter.ai/api/v1",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=False,
        )

    assert built_agent._memory_manager is None
    mock_from_global.assert_not_called()
    mock_load_provider.assert_not_called()
    mock_save.assert_not_called()
|
||||
|
||||
|
|
@ -317,6 +317,60 @@ class TestStripThinkBlocks:
|
|||
result = agent._strip_think_blocks("<thought>orphaned reasoning without close")
|
||||
assert "<thought>" not in result
|
||||
|
||||
# ─── Unterminated-block coverage (#8878, #9568, #10408) ──────────────
|
||||
# Reasoning models served via NIM / MiniMax M2.7 frequently drop the
|
||||
# closing tag, leaking raw reasoning into assistant content. The open
|
||||
# tag appears at a block boundary (start of text or after a newline);
|
||||
# everything from that tag to end-of-string is stripped.
|
||||
|
||||
def test_unterminated_think_block_content_stripped(self, agent):
    """A <think> tag that never closes takes all following text with it."""
    cleaned = agent._strip_think_blocks("<think>orphaned reasoning without close")
    assert "orphaned reasoning" not in cleaned
    assert cleaned.strip() == ""
|
||||
|
||||
def test_unterminated_thought_block_content_stripped(self, agent):
    """A <thought> tag (Gemma style) with no close is removed entirely."""
    cleaned = agent._strip_think_blocks("<thought>orphaned reasoning without close")
    assert "orphaned reasoning" not in cleaned
    assert cleaned.strip() == ""
|
||||
|
||||
def test_unterminated_multiline_block_stripped(self, agent):
    """An unterminated block spanning several lines is removed as a whole."""
    raw_text = "<think>\nmulti\nline\nreasoning\nthat never closes"
    cleaned = agent._strip_think_blocks(raw_text)
    assert "multi" not in cleaned
    assert "never closes" not in cleaned
|
||||
|
||||
def test_unterminated_block_after_answer_preserves_prefix(self, agent):
    """Text before a line-leading unterminated tag survives; the rest goes."""
    raw_text = "Answer is 42.\n<think>actually let me reconsider"
    cleaned = agent._strip_think_blocks(raw_text)
    assert "Answer is 42." in cleaned
    assert "reconsider" not in cleaned
|
||||
|
||||
def test_inline_think_mention_in_prose_not_over_stripped(self, agent):
    """A `<think>` mentioned mid-sentence must not eat the rest of the text.

    The tag is not at a block boundary, so the words on both sides of it
    have to remain in the output.
    """
    prose = "Use the <think> tag like this in your prose."
    cleaned = agent._strip_think_blocks(prose)
    # Both the trailing and the leading words must survive.
    assert "prose" in cleaned
    assert "Use the" in cleaned
|
||||
|
||||
def test_mixed_case_closed_pair_stripped(self, agent):
    """Closed pairs are matched regardless of tag casing.

    <THINK>…</THINK> and <Thinking>…</Thinking> both lose their inner text
    while the content after the closing tag is kept.
    """
    upper_cleaned = agent._strip_think_blocks("<THINK>upper</THINK>final")
    assert "upper" not in upper_cleaned
    assert "final" in upper_cleaned

    mixed_cleaned = agent._strip_think_blocks("<Thinking>mixed</Thinking>final")
    assert "mixed" not in mixed_cleaned
    assert "final" in mixed_cleaned
|
||||
|
||||
|
||||
class TestExtractReasoning:
|
||||
def test_reasoning_field(self, agent):
|
||||
|
|
@ -1088,6 +1142,41 @@ class TestBuildAssistantMessage:
|
|||
result = agent._build_assistant_message(msg, "tool_calls")
|
||||
assert "extra_content" not in result["tool_calls"][0]
|
||||
|
||||
def test_think_blocks_stripped_from_content(self, agent):
    """Inline <think> blocks do not survive into stored content (#8878, #9568).

    The built message must keep only the visible answer in ``content``
    and carry the reasoning text separately under ``reasoning``.
    """
    raw_msg = _mock_assistant_msg(
        content="<think>internal reasoning</think>The actual answer."
    )
    built = agent._build_assistant_message(raw_msg, "stop")
    stored_content = built["content"]
    assert "<think>" not in stored_content
    assert "internal reasoning" not in stored_content
    assert "The actual answer." in stored_content
    # The reasoning text itself is kept, just not inside content.
    assert built["reasoning"] == "internal reasoning"
|
||||
|
||||
def test_think_blocks_stripped_preserves_normal_content(self, agent):
    """Content with no reasoning tags is stored exactly as received."""
    raw_msg = _mock_assistant_msg(content="No thinking here.")
    built = agent._build_assistant_message(raw_msg, "stop")
    assert built["content"] == "No thinking here."
|
||||
|
||||
def test_unterminated_think_block_stripped(self, agent):
    """A <think> block with no closing tag leaves empty stored content."""
    raw_msg = _mock_assistant_msg(
        content="<think>reasoning that never closes on this NIM endpoint"
    )
    built = agent._build_assistant_message(raw_msg, "stop")
    stored_content = built["content"]
    assert "<think>" not in stored_content
    assert "reasoning that never closes" not in stored_content
    assert stored_content == ""
|
||||
|
||||
|
||||
class TestFormatToolsForSystemMessage:
|
||||
def test_no_tools_returns_empty_array(self, agent):
|
||||
|
|
@ -1196,6 +1285,7 @@ class TestExecuteToolCalls:
|
|||
tc = _mock_tool_call(name="web_search", arguments='{"q":"test"}', call_id="c1")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc])
|
||||
messages = []
|
||||
agent.platform = "cli"
|
||||
agent.tool_progress_callback = None
|
||||
|
||||
with patch("run_agent.handle_function_call", return_value="search result"), \
|
||||
|
|
@ -1207,6 +1297,21 @@ class TestExecuteToolCalls:
|
|||
assert len(messages) == 1
|
||||
assert messages[0]["role"] == "tool"
|
||||
|
||||
def test_quiet_tool_output_suppressed_without_progress_callback_for_non_cli_agent(self, agent):
    """With no platform and no progress callback, tool execution prints nothing.

    The tool still runs and its result is appended as one ``tool`` message.
    """
    search_call = _mock_tool_call(name="web_search", arguments='{"q":"test"}', call_id="c1")
    assistant_msg = _mock_assistant_msg(content="", tool_calls=[search_call])
    history = []
    agent.platform = None
    agent.tool_progress_callback = None

    with patch("run_agent.handle_function_call", return_value="search result"), \
         patch.object(agent, "_safe_print") as printed:
        agent._execute_tool_calls(assistant_msg, history, "task-1")

    printed.assert_not_called()
    assert len(history) == 1
    assert history[0]["role"] == "tool"
|
||||
|
||||
def test_vprint_suppressed_in_parseable_quiet_mode(self, agent):
|
||||
agent.suppress_status_output = True
|
||||
|
||||
|
|
@ -1787,6 +1892,30 @@ class TestRunConversation:
|
|||
assert all("message_count" in c and "messages" not in c for c in pre_request_calls)
|
||||
assert all("usage" in c and "response" not in c for c in post_request_calls)
|
||||
|
||||
def test_content_with_tool_calls_stays_silent_for_non_cli_quiet_mode(self, agent):
    """Assistant text that accompanies tool calls is not printed off-CLI.

    The first response carries content plus a tool call, the second is
    the final answer; with ``platform`` unset nothing may reach
    ``_safe_print``.
    """
    self._setup_agent(agent)
    agent.platform = None
    search_call = _mock_tool_call(name="web_search", arguments="{}", call_id="c1")
    tool_turn = _mock_response(
        content="I'll search for that.",
        finish_reason="tool_calls",
        tool_calls=[search_call],
    )
    final_turn = _mock_response(content="Done searching", finish_reason="stop")
    agent.client.chat.completions.create.side_effect = [tool_turn, final_turn]

    with (
        patch("run_agent.handle_function_call", return_value="search result"),
        patch.object(agent, "_safe_print") as printed,
        patch.object(agent, "_persist_session"),
        patch.object(agent, "_save_trajectory"),
        patch.object(agent, "_cleanup_task_resources"),
    ):
        outcome = agent.run_conversation("search something")

    assert outcome["final_response"] == "Done searching"
    printed.assert_not_called()
|
||||
|
||||
def test_interrupt_breaks_loop(self, agent):
|
||||
self._setup_agent(agent)
|
||||
|
||||
|
|
|
|||
228
tests/run_agent/test_steer.py
Normal file
228
tests/run_agent/test_steer.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
"""Tests for AIAgent.steer() — mid-run user message injection.
|
||||
|
||||
/steer lets the user add a note to the agent's next tool result without
|
||||
interrupting the current tool call. The agent sees the note inline with
|
||||
tool output on its next iteration, preserving message-role alternation
|
||||
and prompt-cache integrity.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
def _bare_agent() -> AIAgent:
    """Return an AIAgent allocated without running ``__init__``.

    Only the two attributes the steer tests touch are installed: an empty
    pending-steer slot and the lock stored next to it.
    """
    bare = object.__new__(AIAgent)
    bare._pending_steer = None
    bare._pending_steer_lock = threading.Lock()
    return bare
|
||||
|
||||
|
||||
class TestSteerAcceptance:
    """What ``steer()`` accepts, rejects, and stores in ``_pending_steer``."""

    def test_accepts_non_empty_text(self):
        """Ordinary text is accepted and stored verbatim."""
        bare = _bare_agent()
        accepted = bare.steer("go ahead and check the logs")
        assert accepted is True
        assert bare._pending_steer == "go ahead and check the logs"

    def test_rejects_empty_string(self):
        """An empty string is refused and nothing is stored."""
        bare = _bare_agent()
        accepted = bare.steer("")
        assert accepted is False
        assert bare._pending_steer is None

    def test_rejects_whitespace_only(self):
        """Whitespace-only input is refused and nothing is stored."""
        bare = _bare_agent()
        accepted = bare.steer("   \n\t  ")
        assert accepted is False
        assert bare._pending_steer is None

    def test_rejects_none(self):
        """``None`` is refused and nothing is stored."""
        bare = _bare_agent()
        accepted = bare.steer(None)  # type: ignore[arg-type]
        assert accepted is False
        assert bare._pending_steer is None

    def test_strips_surrounding_whitespace(self):
        """Leading and trailing whitespace is removed before storing."""
        bare = _bare_agent()
        accepted = bare.steer("  hello world  \n")
        assert accepted is True
        assert bare._pending_steer == "hello world"

    def test_concatenates_multiple_steers_with_newlines(self):
        """Successive steers accumulate, joined by newlines in call order."""
        bare = _bare_agent()
        for note in ("first note", "second note", "third note"):
            bare.steer(note)
        assert bare._pending_steer == "first note\nsecond note\nthird note"
|
||||
|
||||
|
||||
class TestSteerDrain:
    """``_drain_pending_steer()`` hands back the pending text and empties the slot."""

    def test_drain_returns_and_clears(self):
        """Draining returns the stored text and resets the slot to None."""
        bare = _bare_agent()
        bare.steer("hello")
        drained = bare._drain_pending_steer()
        assert drained == "hello"
        assert bare._pending_steer is None

    def test_drain_on_empty_returns_none(self):
        """Draining with nothing pending yields None."""
        bare = _bare_agent()
        drained = bare._drain_pending_steer()
        assert drained is None
|
||||
|
||||
|
||||
class TestSteerInjection:
|
||||
def test_appends_to_last_tool_result(self):
|
||||
agent = _bare_agent()
|
||||
agent.steer("please also check auth.log")
|
||||
messages = [
|
||||
{"role": "user", "content": "what's in /var/log?"},
|
||||
{"role": "assistant", "tool_calls": [{"id": "a"}, {"id": "b"}]},
|
||||
{"role": "tool", "content": "ls output A", "tool_call_id": "a"},
|
||||
{"role": "tool", "content": "ls output B", "tool_call_id": "b"},
|
||||
]
|
||||
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=2)
|
||||
# The LAST tool result is modified; earlier ones are untouched.
|
||||
assert messages[2]["content"] == "ls output A"
|
||||
assert "ls output B" in messages[3]["content"]
|
||||
assert "[USER STEER" in messages[3]["content"]
|
||||
assert "please also check auth.log" in messages[3]["content"]
|
||||
# And pending_steer is consumed.
|
||||
assert agent._pending_steer is None
|
||||
|
||||
def test_no_op_when_no_steer_pending(self):
|
||||
agent = _bare_agent()
|
||||
messages = [
|
||||
{"role": "assistant", "tool_calls": [{"id": "a"}]},
|
||||
{"role": "tool", "content": "output", "tool_call_id": "a"},
|
||||
]
|
||||
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
|
||||
assert messages[-1]["content"] == "output" # unchanged
|
||||
|
||||
def test_no_op_when_num_tool_msgs_zero(self):
|
||||
agent = _bare_agent()
|
||||
agent.steer("steer")
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=0)
|
||||
# Steer should remain pending (nothing to drain into)
|
||||
assert agent._pending_steer == "steer"
|
||||
|
||||
def test_marker_is_unambiguous_about_origin(self):
|
||||
"""The injection marker must make clear the text is from the user
|
||||
and not tool output — this is the cache-safe way to signal
|
||||
provenance without violating message-role alternation.
|
||||
"""
|
||||
agent = _bare_agent()
|
||||
agent.steer("stop after next step")
|
||||
messages = [{"role": "tool", "content": "x", "tool_call_id": "1"}]
|
||||
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
|
||||
content = messages[-1]["content"]
|
||||
assert "USER STEER" in content
|
||||
assert "not tool output" in content.lower() or "injected mid-run" in content.lower()
|
||||
|
||||
def test_multimodal_content_list_preserved(self):
|
||||
"""Anthropic-style list content should be preserved, with the steer
|
||||
appended as a text block."""
|
||||
agent = _bare_agent()
|
||||
agent.steer("extra note")
|
||||
original_blocks = [{"type": "text", "text": "existing output"}]
|
||||
messages = [
|
||||
{"role": "tool", "content": list(original_blocks), "tool_call_id": "1"}
|
||||
]
|
||||
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
|
||||
new_content = messages[-1]["content"]
|
||||
assert isinstance(new_content, list)
|
||||
assert len(new_content) == 2
|
||||
assert new_content[0] == {"type": "text", "text": "existing output"}
|
||||
assert new_content[1]["type"] == "text"
|
||||
assert "extra note" in new_content[1]["text"]
|
||||
|
||||
def test_restashed_when_no_tool_result_in_batch(self):
    """A batch without any tool-role message puts the steer back.

    This covers the case where every tool call was skipped after an
    interrupt: the text must return to the pending slot so the caller's
    fallback path can still deliver it.
    """
    steered = _bare_agent()
    steered.steer("ping")
    user_turn = {"role": "user", "content": "x"}
    assistant_turn = {"role": "assistant", "content": "y"}
    history = [user_turn, assistant_turn]

    # Report two tool messages although the tail holds none, which mimics
    # the interrupt-cancelled batch.
    steered._apply_pending_steer_to_tool_results(history, num_tool_msgs=2)

    # The conversation itself is left alone ...
    assert history[-1]["content"] == "y"
    # ... and the steer is available again for the fallback.
    assert steered._pending_steer == "ping"
|
||||
|
||||
|
||||
class TestSteerThreadSafety:
    """Concurrency checks for ``steer()``."""

    def test_concurrent_steer_calls_preserve_all_text(self):
        """Steering from many threads at once must not lose any note."""
        agent = _bare_agent()
        total = 200

        workers = [
            threading.Thread(target=agent.steer, args=(f"note-{i}",))
            for i in range(total)
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        drained = agent._drain_pending_steer()
        assert drained is not None
        # One line per call: the count and the set of notes must both match,
        # so nothing was dropped or merged under contention.
        notes = drained.split("\n")
        assert len(notes) == total
        assert set(notes) == {f"note-{i}" for i in range(total)}
|
||||
|
||||
|
||||
class TestSteerClearedOnInterrupt:
    """Interaction between a pending steer and ``clear_interrupt()``."""

    def test_clear_interrupt_drops_pending_steer(self):
        """A hard interrupt wins over a pending steer.

        The next tool iteration never runs after an interrupt, so delivering
        the steer later would surprise the user; it is discarded instead.
        """
        agent = _bare_agent()
        # Only the attributes that clear_interrupt() touches are populated.
        interrupt_state = {
            "_interrupt_requested": True,
            "_interrupt_message": None,
            "_interrupt_thread_signal_pending": False,
            "_execution_thread_id": None,
            "_tool_worker_threads": None,
            "_tool_worker_threads_lock": None,
        }
        for attr, value in interrupt_state.items():
            setattr(agent, attr, value)

        agent.steer("will be dropped")
        assert agent._pending_steer == "will be dropped"

        agent.clear_interrupt()
        assert agent._pending_steer is None
|
||||
|
||||
|
||||
class TestSteerCommandRegistry:
    """Registration checks for the ``/steer`` slash command."""

    def test_steer_in_command_registry(self):
        """The /steer slash command must be registered so it reaches all
        platforms (CLI, gateway, TUI autocomplete, Telegram/Slack menus).
        """
        # Only resolve_command is needed here; the bypass set is covered by
        # test_steer_in_bypass_set below.
        from hermes_cli.commands import resolve_command

        cmd = resolve_command("steer")
        assert cmd is not None
        assert cmd.name == "steer"
        assert cmd.category == "Session"
        assert cmd.args_hint == "<prompt>"

    def test_steer_in_bypass_set(self):
        """When the agent is running, /steer MUST bypass the Level-1
        base-adapter queue so it reaches the gateway runner's /steer
        handler. Otherwise it would be queued as user text and only
        delivered at turn end — defeating the whole point.
        """
        from hermes_cli.commands import (
            ACTIVE_SESSION_BYPASS_COMMANDS,
            should_bypass_active_session,
        )

        assert "steer" in ACTIVE_SESSION_BYPASS_COMMANDS
        assert should_bypass_active_session("steer") is True
|
||||
|
||||
|
||||
# Allow running this test module directly as a script; pytest is invoked on
# this file only, in verbose mode.
if __name__ == "__main__":  # pragma: no cover
    pytest.main([__file__, "-v"])
|
||||
|
|
@ -141,6 +141,50 @@ class TestStreamingAccumulator:
|
|||
assert tc[0].function.name == "terminal"
|
||||
assert tc[0].function.arguments == '{"command": "ls"}'
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
@patch("run_agent.AIAgent._close_request_openai_client")
def test_tool_name_not_duplicated_when_resent_per_chunk(self, mock_close, mock_create):
    """A tool name repeated in every chunk must not be concatenated.

    MiniMax M2.7 via NVIDIA NIM resends the full function name with each
    delta. Bug #8259: accumulating with ``+=`` produced
    "read_fileread_file". Assigning the name instead (as the OpenAI Node
    SDK and LiteLLM do) avoids the duplication.
    """
    from run_agent import AIAgent

    # First delta carries the name only; the next two repeat the name and
    # add one argument fragment each.
    name_only = _make_tool_call_delta(index=0, tc_id="call_nim", name="read_file")
    with_arguments = [
        _make_tool_call_delta(
            index=0, tc_id="call_nim", name="read_file", arguments=fragment
        )
        for fragment in ('{"path":', ' "x.py"}')
    ]
    chunks = [
        _make_stream_chunk(tool_calls=[delta])
        for delta in (name_only, *with_arguments)
    ]
    chunks.append(_make_stream_chunk(finish_reason="tool_calls"))

    fake_client = MagicMock()
    fake_client.chat.completions.create.return_value = iter(chunks)
    mock_create.return_value = fake_client

    agent_options = {
        "model": "test/model",
        "quiet_mode": True,
        "skip_context_files": True,
        "skip_memory": True,
    }
    agent = AIAgent(**agent_options)
    agent.api_mode = "chat_completions"
    agent._interrupt_requested = False

    response = agent._interruptible_streaming_api_call({})

    tool_calls = response.choices[0].message.tool_calls
    assert tool_calls is not None
    assert len(tool_calls) == 1
    only_call = tool_calls[0]
    assert only_call.function.name == "read_file"
    assert only_call.function.arguments == '{"path": "x.py"}'
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_tool_call_extra_content_preserved(self, mock_close, mock_create):
|
||||
|
|
@ -952,3 +996,138 @@ class TestAnthropicStreamCallbacks:
|
|||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert touch_calls.count("receiving stream response") == len(events)
|
||||
|
||||
|
||||
class TestPartialToolCallWarning:
    """Regression: when a stream dies mid tool-call argument generation after
    text was already delivered, the partial-stream stub at run_agent.py
    line ~6107 used to silently set ``tool_calls=None`` and return
    ``finish_reason=stop``, losing the attempted action with zero user-facing
    signal. Live-observed Apr 2026 with MiniMax M2.7 on a 6-minute audit
    task — agent streamed commentary, emitted a write_file tool call,
    MiniMax stalled for 240 s mid-arguments, stale-stream detector killed
    the connection, the stub returned, session ended with no file written
    and no error shown.

    Fix: when the stream accumulator captured any tool-call names before the
    error, the stub now appends a user-visible warning to content AND fires
    it as a stream delta so the user sees it immediately.

    Both tests set ``HERMES_STREAM_RETRIES=0`` for the duration of the call
    via ``patch.dict``, which restores the environment afterwards.
    """

    @patch("run_agent.AIAgent._create_request_openai_client")
    @patch("run_agent.AIAgent._close_request_openai_client")
    def test_partial_tool_call_surfaces_warning(self, mock_close, mock_create):
        """Stream with text + partial tool-call name + mid-stream error
        produces a stub whose content contains the user-visible warning
        and whose tool_calls is None."""
        import os

        from run_agent import AIAgent

        class _StallError(RuntimeError):
            pass

        def _stalling_stream():
            # Text first, then a tool call whose arguments never complete.
            yield _make_stream_chunk(content="Let me write the audit: ")
            yield _make_stream_chunk(tool_calls=[
                _make_tool_call_delta(index=0, tc_id="call_1", name="write_file"),
            ])
            yield _make_stream_chunk(tool_calls=[
                _make_tool_call_delta(index=0, arguments='{"path": "/tmp/x", '),
            ])
            raise _StallError("simulated upstream stall")

        mock_client = MagicMock()
        # side_effect (not return_value) so every create() call gets a fresh
        # generator.
        mock_client.chat.completions.create.side_effect = lambda *a, **kw: _stalling_stream()
        mock_create.return_value = mock_client

        agent = AIAgent(
            api_key="test-key",
            base_url="https://openrouter.ai/api/v1",
            model="test/model",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
        )
        agent.api_mode = "chat_completions"
        agent._interrupt_requested = False

        fired_deltas: list = []
        agent._fire_stream_delta = lambda text: fired_deltas.append(text)
        agent._current_streamed_assistant_text = "Let me write the audit: "

        # patch.dict snapshots os.environ and restores it on exit, replacing
        # the manual save / set / restore dance.
        with patch.dict(os.environ, {"HERMES_STREAM_RETRIES": "0"}):
            response = agent._interruptible_streaming_api_call({})

        content = response.choices[0].message.content or ""
        assert "Let me write the audit:" in content, (
            f"Partial text not preserved in stub: {content!r}"
        )
        assert "Stream stalled mid tool-call" in content, (
            "Stub content is missing the dropped-tool-call warning; users "
            f"get silent failure. Got content={content!r}"
        )
        assert "write_file" in content, (
            f"Warning should name the dropped tool. Got: {content!r}"
        )
        assert response.choices[0].message.tool_calls is None
        assert any("Stream stalled mid tool-call" in d for d in fired_deltas), (
            "Warning was not surfaced as a live stream delta. "
            f"fired_deltas={fired_deltas}"
        )

    @patch("run_agent.AIAgent._create_request_openai_client")
    @patch("run_agent.AIAgent._close_request_openai_client")
    def test_partial_text_only_no_warning(self, mock_close, mock_create):
        """Text-only partial stream (no tool call mid-flight) keeps the
        pre-fix behaviour: bare recovered text, no warning noise."""
        import os

        from run_agent import AIAgent

        class _StallError(RuntimeError):
            pass

        def _stalling_stream():
            yield _make_stream_chunk(content="Here's my answer so far")
            raise _StallError("simulated upstream stall")

        mock_client = MagicMock()
        mock_client.chat.completions.create.side_effect = lambda *a, **kw: _stalling_stream()
        mock_create.return_value = mock_client

        agent = AIAgent(
            api_key="test-key",
            base_url="https://openrouter.ai/api/v1",
            model="test/model",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
        )
        agent.api_mode = "chat_completions"
        agent._interrupt_requested = False
        agent._current_streamed_assistant_text = "Here's my answer so far"

        # Retries are disabled only for the duration of the call.
        with patch.dict(os.environ, {"HERMES_STREAM_RETRIES": "0"}):
            response = agent._interruptible_streaming_api_call({})

        content = response.choices[0].message.content or ""
        assert content == "Here's my answer so far", (
            f"Pre-fix behaviour regressed for text-only partial streams: {content!r}"
        )
        assert "Stream stalled" not in content, (
            f"Unexpected warning on text-only partial stream: {content!r}"
        )
|
||||
|
||||
|
|
|
|||
|
|
@ -479,6 +479,141 @@ class TestFTS5Search:
|
|||
assert s('my-app.config.ts') == '"my-app.config.ts"'
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# CJK (Chinese/Japanese/Korean) LIKE fallback
|
||||
# =========================================================================
|
||||
|
||||
class TestCJKSearchFallback:
    """Regression tests for CJK message search (see #11511).

    The default SQLite FTS5 tokenizer keeps a contiguous CJK run as one
    token, so a substring query such as "记忆断裂" finds nothing even though
    the text is stored. SessionDB therefore retries with LIKE substring
    matching when FTS5 returns no rows and the query contains CJK characters.
    """

    def test_cjk_detection_covers_all_ranges(self):
        """``_contains_cjk`` recognises Chinese, Japanese and Korean text."""
        from hermes_state import SessionDB

        detect = SessionDB._contains_cjk
        cjk_samples = (
            "记忆断裂",  # Chinese (CJK Unified Ideographs)
            "こんにちは",  # Japanese Hiragana
            "カタカナ",  # Japanese Katakana
            # Korean Hangul syllables, early and late in the block; this
            # guards against the \ud7a0-\ud7af range typo seen in a duplicate PR.
            "안녕하세요",
            "기억",
            "日本語mixedwithenglish",  # CJK mixed with Latin text
        )
        for sample in cjk_samples:
            assert detect(sample) is True
        for sample in ("hello world", ""):
            assert detect(sample) is False

    def test_chinese_multichar_query_returns_results(self, db):
        """The headline bug: a multi-character Chinese query must match."""
        db.create_session(session_id="s1", source="cli")
        db.append_message(
            "s1",
            role="user",
            content="昨天和其他Agent的聊天记录,记忆断裂问题复现了",
        )

        hits = db.search_messages("记忆断裂")

        assert len(hits) == 1
        assert hits[0]["session_id"] == "s1"

    def test_chinese_bigram_query(self, db):
        """A two-character Chinese query inside a longer run is found."""
        db.create_session(session_id="s1", source="telegram")
        db.append_message("s1", role="user", content="今天讨论A2A通信协议的实现")

        hits = db.search_messages("通信")

        assert len(hits) == 1

    def test_korean_query_returns_results(self, db):
        """Korean text is searchable; guards against Hangul range typos."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="안녕하세요 반갑습니다")

        hits = db.search_messages("안녕")

        assert len(hits) == 1

    def test_japanese_query_returns_results(self, db):
        """Both the kana part and the kanji part of a message are found."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="こんにちは世界")

        for query in ("こんにちは", "世界"):
            assert len(db.search_messages(query)) == 1

    def test_cjk_fallback_preserves_source_filter(self, db):
        """The fallback honours ``source_filter``.

        Guards against the SQL-builder bug where filter clauses ended up
        after LIMIT/OFFSET (seen in one of the duplicate PRs).
        """
        for session_id, source in (("s1", "cli"), ("s2", "telegram")):
            db.create_session(session_id=session_id, source=source)
        db.append_message("s1", role="user", content="记忆断裂在CLI")
        db.append_message("s2", role="user", content="记忆断裂在Telegram")

        hits = db.search_messages("记忆断裂", source_filter=["telegram"])

        assert len(hits) == 1
        assert hits[0]["source"] == "telegram"

    def test_cjk_fallback_preserves_exclude_sources(self, db):
        """The fallback honours ``exclude_sources``."""
        for session_id, source in (("s1", "cli"), ("s2", "tool")):
            db.create_session(session_id=session_id, source=source)
        db.append_message("s1", role="user", content="记忆断裂在CLI")
        db.append_message("s2", role="assistant", content="记忆断裂在tool")

        hits = db.search_messages("记忆断裂", exclude_sources=["tool"])

        found_sources = {hit["source"] for hit in hits}
        assert "tool" not in found_sources
        assert "cli" in found_sources

    def test_cjk_fallback_preserves_role_filter(self, db):
        """The fallback honours ``role_filter``."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="用户说的记忆断裂")
        db.append_message("s1", role="assistant", content="助手说的记忆断裂")

        hits = db.search_messages("记忆断裂", role_filter=["assistant"])

        assert len(hits) == 1
        assert hits[0]["role"] == "assistant"

    def test_cjk_snippet_is_centered_on_match(self, db):
        """The snippet contains the search term, not just the first N chars."""
        db.create_session(session_id="s1", source="cli")
        # Padding on both sides pushes the match into the middle of the text.
        long_prefix = "这是一段很长的前缀用来把匹配位置推到文档中间" * 3
        long_suffix = "这是一段很长的后缀内容填充剩余空间" * 3
        body = f"{long_prefix}记忆断裂{long_suffix}"
        db.append_message("s1", role="user", content=body)

        hits = db.search_messages("记忆断裂")

        assert len(hits) == 1
        # The centered substr() snippet must include the matched term.
        assert "记忆断裂" in hits[0]["snippet"]

    def test_english_query_still_uses_fts5_fast_path(self, db):
        """English queries keep matching (fast-path regression guard)."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="Deploy docker containers")

        hits = db.search_messages("docker")

        # Which path ran is not asserted here (there is no instrumentation);
        # the minimum guarantee is that a non-CJK query still matches.
        assert len(hits) == 1

    def test_cjk_query_with_no_matches_returns_empty(self, db):
        """A CJK query with nothing to match yields an empty list."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="unrelated English content")

        hits = db.search_messages("记忆断裂")

        assert hits == []

    def test_mixed_cjk_english_query(self, db):
        """Mixed queries still fall back to LIKE when FTS5 misses."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="讨论Agent通信协议")

        # "Agent通信" combines Latin and CJK text; the LIKE fallback matches
        # it as a plain substring.
        hits = db.search_messages("Agent通信")

        assert len(hits) == 1
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Session search and listing
|
||||
# =========================================================================
|
||||
|
|
|
|||
|
|
@ -363,6 +363,28 @@ def test_image_attach_appends_local_image(monkeypatch):
|
|||
assert len(server._sessions["sid"]["attached_images"]) == 1
|
||||
|
||||
|
||||
def test_commands_catalog_surfaces_quick_commands(monkeypatch):
    """Configured quick commands show up in the ``commands.catalog`` reply."""

    def _fake_cfg():
        return {"quick_commands": {
            "build": {"type": "exec", "command": "npm run build"},
            "git": {"type": "alias", "target": "/shell git"},
            "notes": {"type": "exec", "command": "cat NOTES.md", "description": "Open design notes"},
        }}

    monkeypatch.setattr(server, "_load_cfg", _fake_cfg)

    request = {"id": "1", "method": "commands.catalog", "params": {}}
    resp = server.handle_request(request)

    # Flat listing: one description per slash command.
    descriptions = dict(resp["result"]["pairs"])
    assert "npm run build" in descriptions["/build"]
    assert descriptions["/git"].startswith("alias →")
    assert descriptions["/notes"] == "Open design notes"

    # The same commands are grouped under the "User commands" category.
    user_category = next(
        category
        for category in resp["result"]["categories"]
        if category["name"] == "User commands"
    )
    assert set(dict(user_category["pairs"])) == {"/build", "/git", "/notes"}

    canon = resp["result"]["canon"]
    assert canon["/build"] == "/build"
    assert canon["/notes"] == "/notes"
|
||||
|
||||
|
||||
def test_command_dispatch_exec_nonzero_surfaces_error(monkeypatch):
|
||||
monkeypatch.setattr(server, "_load_cfg", lambda: {"quick_commands": {"boom": {"type": "exec", "command": "boom"}}})
|
||||
monkeypatch.setattr(
|
||||
|
|
@ -438,3 +460,651 @@ def test_rollback_restore_resolves_number_and_file_path():
|
|||
assert resp["result"]["success"] is True
|
||||
assert calls["args"][1] == "bbb222"
|
||||
assert calls["args"][2] == "src/app.tsx"
|
||||
|
||||
|
||||
# ── session.steer ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_session_steer_calls_agent_steer_when_agent_supports_it():
    """``session.steer`` forwards the text to ``agent.steer`` and reports a
    queued status, without going through ``agent.interrupt``."""
    recorded = {}

    class _Agent:
        def steer(self, text):
            recorded["steer_text"] = text
            return True

        def interrupt(self, *args, **kwargs):
            recorded["interrupt_called"] = True

    request = {
        "id": "1",
        "method": "session.steer",
        "params": {"session_id": "sid", "text": "also check auth.log"},
    }
    server._sessions["sid"] = _session(agent=_Agent())
    try:
        resp = server.handle_request(request)
    finally:
        server._sessions.pop("sid", None)

    assert "result" in resp, resp
    result = resp["result"]
    assert result["status"] == "queued"
    assert result["text"] == "also check auth.log"
    assert recorded["steer_text"] == "also check auth.log"
    # Steering must never be implemented as an interrupt.
    assert "interrupt_called" not in recorded
|
||||
|
||||
|
||||
def test_session_steer_rejects_empty_text():
    """Whitespace-only steer text is refused with error code 4002."""
    steerable = types.SimpleNamespace(steer=lambda t: True)
    request = {
        "id": "1",
        "method": "session.steer",
        "params": {"session_id": "sid", "text": " "},
    }
    server._sessions["sid"] = _session(agent=steerable)
    try:
        resp = server.handle_request(request)
    finally:
        server._sessions.pop("sid", None)

    assert "error" in resp, resp
    assert resp["error"]["code"] == 4002
|
||||
|
||||
|
||||
def test_session_steer_errors_when_agent_has_no_steer_method():
    """An agent object without ``steer()`` yields error code 4010."""
    agent_without_steer = types.SimpleNamespace()
    request = {
        "id": "1",
        "method": "session.steer",
        "params": {"session_id": "sid", "text": "hi"},
    }
    server._sessions["sid"] = _session(agent=agent_without_steer)
    try:
        resp = server.handle_request(request)
    finally:
        server._sessions.pop("sid", None)

    assert "error" in resp, resp
    assert resp["error"]["code"] == 4010
|
||||
|
||||
|
||||
def test_session_info_includes_mcp_servers(monkeypatch):
    """``_session_info`` reports the MCP status list under ``mcp_servers``."""
    fake_status = [
        {"name": "github", "transport": "http", "tools": 12, "connected": True},
        {"name": "filesystem", "transport": "stdio", "tools": 4, "connected": True},
        {"name": "broken", "transport": "stdio", "tools": 0, "connected": False},
    ]

    # Swap in a stand-in for tools.mcp_tool so no real MCP lookup happens.
    stub_module = types.ModuleType("tools.mcp_tool")
    stub_module.get_mcp_status = lambda: fake_status
    monkeypatch.setitem(sys.modules, "tools.mcp_tool", stub_module)

    minimal_agent = types.SimpleNamespace(tools=[], model="")
    info = server._session_info(minimal_agent)

    assert info["mcp_servers"] == fake_status
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# History-mutating commands must reject while session.running is True.
|
||||
# Without these guards, prompt.submit's post-run history write either
|
||||
# clobbers the mutation (version matches) or silently drops the agent's
|
||||
# output (version mismatch) — both produce UI<->backend state desync.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_session_undo_rejects_while_running():
    """Fix for TUI silent-drop #1: /undo is refused mid-turn.

    Mutating history while the agent is running would either clobber the
    undo or make prompt.submit silently drop the agent's response.
    """
    turns = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    server._sessions["sid"] = _session(running=True, history=turns)
    request = {"id": "1", "method": "session.undo", "params": {"session_id": "sid"}}
    try:
        resp = server.handle_request(request)

        assert resp.get("error"), "session.undo should reject while running"
        error = resp["error"]
        assert error["code"] == 4009
        assert "session busy" in error["message"]
        # Both turns are still there.
        assert len(server._sessions["sid"]["history"]) == 2
    finally:
        server._sessions.pop("sid", None)
|
||||
|
||||
|
||||
def test_session_undo_allowed_when_idle():
    """Regression guard: /undo keeps working when no turn is running."""
    turns = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    server._sessions["sid"] = _session(running=False, history=turns)
    request = {"id": "1", "method": "session.undo", "params": {"session_id": "sid"}}
    try:
        resp = server.handle_request(request)

        assert resp.get("result"), f"got error: {resp.get('error')}"
        # The user/assistant pair is removed, leaving an empty history.
        assert resp["result"]["removed"] == 2
        assert server._sessions["sid"]["history"] == []
    finally:
        server._sessions.pop("sid", None)
|
||||
|
||||
|
||||
def test_session_compress_rejects_while_running(monkeypatch):
    """``session.compress`` is refused with code 4009 during a running turn."""
    server._sessions["sid"] = _session(running=True)
    request = {"id": "1", "method": "session.compress", "params": {"session_id": "sid"}}
    try:
        resp = server.handle_request(request)

        assert resp.get("error")
        assert resp["error"]["code"] == 4009
    finally:
        server._sessions.pop("sid", None)
|
||||
|
||||
|
||||
def test_rollback_restore_rejects_full_history_while_running(monkeypatch):
    """A rollback without a file path is refused mid-turn.

    Only the full-history form is exercised here; per the original note,
    file-scoped rollback remains allowed.
    """
    server._sessions["sid"] = _session(running=True)
    request = {
        "id": "1",
        "method": "rollback.restore",
        "params": {"session_id": "sid", "hash": "abc"},
    }
    try:
        resp = server.handle_request(request)

        assert resp.get("error"), "full-history rollback should reject while running"
        assert resp["error"]["code"] == 4009
    finally:
        server._sessions.pop("sid", None)
|
||||
|
||||
|
||||
def test_prompt_submit_history_version_mismatch_surfaces_warning(monkeypatch):
    """Fix for TUI silent-drop #2: a history race is reported, not hidden.

    When history was mutated externally during the turn, prompt.submit's
    defensive backstop must attach a 'warning' to message.complete instead
    of silently dropping the agent's output.
    """
    # Filled in once the session exists, so the agent can reach it mid-run.
    session_ref = {"s": None}

    class _RacyAgent:
        def run_conversation(self, prompt, conversation_history=None, stream_callback=None):
            # Bump history_version from inside the run to imitate an
            # external mutation that slipped past the guards.
            session = session_ref["s"]
            with session["history_lock"]:
                session["history_version"] += 1
            return {"final_response": "agent reply", "messages": [{"role": "assistant", "content": "agent reply"}]}

    class _ImmediateThread:
        """Thread stand-in that runs the target synchronously on start()."""

        def __init__(self, target=None, daemon=None):
            self._target = target

        def start(self):
            self._target()

    server._sessions["sid"] = _session(agent=_RacyAgent())
    session_ref["s"] = server._sessions["sid"]
    emits: list[tuple] = []
    try:
        monkeypatch.setattr(server.threading, "Thread", _ImmediateThread)
        monkeypatch.setattr(server, "_get_usage", lambda _a: {})
        monkeypatch.setattr(server, "render_message", lambda _t, _c: "")
        monkeypatch.setattr(server, "_emit", lambda *a: emits.append(a))

        request = {"id": "1", "method": "prompt.submit", "params": {"session_id": "sid", "text": "hi"}}
        resp = server.handle_request(request)
        assert resp.get("result"), f"got error: {resp.get('error')}"

        # Version mismatch: the agent's output must not have been stored.
        assert server._sessions["sid"]["history"] == []

        # Exactly one message.complete, and it carries the warning so the
        # UI / operator knows the output was not persisted.
        complete_calls = [call for call in emits if call[0] == "message.complete"]
        assert len(complete_calls) == 1
        _, _, payload = complete_calls[0]
        assert "warning" in payload, (
            "message.complete must include a 'warning' field on "
            "history_version mismatch — otherwise the UI silently "
            "shows output that was never persisted"
        )
        warning_text = payload["warning"].lower()
        assert any(word in warning_text for word in ("not saved", "changed"))
    finally:
        server._sessions.pop("sid", None)
|
||||
|
||||
|
||||
def test_prompt_submit_history_version_match_persists_normally(monkeypatch):
    """Regression guard: the backstop leaves the happy path alone."""

    class _Agent:
        def run_conversation(self, prompt, conversation_history=None, stream_callback=None):
            return {"final_response": "reply", "messages": [{"role": "assistant", "content": "reply"}]}

    class _ImmediateThread:
        """Thread stand-in that runs the target synchronously on start()."""

        def __init__(self, target=None, daemon=None):
            self._target = target

        def start(self):
            self._target()

    server._sessions["sid"] = _session(agent=_Agent())
    emits: list[tuple] = []
    try:
        monkeypatch.setattr(server.threading, "Thread", _ImmediateThread)
        monkeypatch.setattr(server, "_get_usage", lambda _a: {})
        monkeypatch.setattr(server, "render_message", lambda _t, _c: "")
        monkeypatch.setattr(server, "_emit", lambda *a: emits.append(a))

        request = {"id": "1", "method": "prompt.submit", "params": {"session_id": "sid", "text": "hi"}}
        resp = server.handle_request(request)
        assert resp.get("result")

        # The reply was stored and the version advanced once.
        stored = server._sessions["sid"]
        assert stored["history"] == [{"role": "assistant", "content": "reply"}]
        assert stored["history_version"] == 1

        # One message.complete, with no warning attached.
        complete_calls = [call for call in emits if call[0] == "message.complete"]
        assert len(complete_calls) == 1
        _, _, payload = complete_calls[0]
        assert "warning" not in payload
    finally:
        server._sessions.pop("sid", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session.interrupt must only cancel pending prompts owned by the calling
|
||||
# session — it must not blast-resolve clarify/sudo/secret prompts on
|
||||
# unrelated sessions sharing the same tui_gateway process. Without
|
||||
# session scoping the other sessions' prompts silently resolve to empty
|
||||
# strings, unblocking their agent threads as if the user cancelled.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_interrupt_only_clears_own_session_pending():
    """Interrupting session A leaves session B's pending prompts alone."""
    import types

    sessions = {}
    for sid in ("sid_a", "sid_b"):
        entry = _session()
        entry["agent"] = types.SimpleNamespace(interrupt=lambda: None)
        sessions[sid] = entry
        server._sessions[sid] = entry

    try:
        # One outstanding prompt per session, shaped like the entries that
        # exist while a clarify/sudo/secret request is waiting.
        ev_a = threading.Event()
        ev_b = threading.Event()
        server._pending["rid-a"] = ("sid_a", ev_a)
        server._pending["rid-b"] = ("sid_b", ev_b)
        server._answers.clear()

        request = {"id": "1", "method": "session.interrupt", "params": {"session_id": "sid_a"}}
        resp = server.handle_request(request)
        assert resp.get("result"), f"got error: {resp.get('error')}"

        # A's prompt is released with an empty answer.
        assert ev_a.is_set(), "sid_a pending Event should be set after interrupt"
        assert server._answers.get("rid-a") == ""

        # B's prompt is untouched: no cross-session release.
        assert not ev_b.is_set(), (
            "CRITICAL: session.interrupt on sid_a released a pending prompt "
            "belonging to sid_b — other sessions' clarify/sudo/secret "
            "prompts are being silently cancelled"
        )
        assert "rid-b" not in server._answers
    finally:
        for sid in ("sid_a", "sid_b"):
            server._sessions.pop(sid, None)
        for rid in ("rid-a", "rid-b"):
            server._pending.pop(rid, None)
        for rid in ("rid-a", "rid-b"):
            server._answers.pop(rid, None)
|
||||
|
||||
|
||||
def test_interrupt_clears_multiple_own_pending():
    """Interrupting a session releases every prompt that session owns.

    One session may hold several pending prompts at once (uncommon, but
    possible via nested tool calls); each must be answered with "".
    """
    import types

    fake = _session()
    fake["agent"] = types.SimpleNamespace(interrupt=lambda: None)
    server._sessions["sid"] = fake

    request_ids = ("r1", "r2")
    try:
        events = {rid: threading.Event() for rid in request_ids}
        for rid in request_ids:
            server._pending[rid] = ("sid", events[rid])

        request = {
            "id": "1",
            "method": "session.interrupt",
            "params": {"session_id": "sid"},
        }
        resp = server.handle_request(request)

        assert resp.get("result")
        assert all(events[rid].is_set() for rid in request_ids)
        assert all(server._answers.get(rid) == "" for rid in request_ids)
    finally:
        server._sessions.pop("sid", None)
        for rid in request_ids:
            server._pending.pop(rid, None)
            server._answers.pop(rid, None)
|
||||
|
||||
|
||||
def test_clear_pending_without_sid_clears_all():
    """_clear_pending(None) must release prompts owned by any session.

    This is the shutdown path: three prompts registered under three
    different session ids must all have their events set.
    """
    owners = {"a": "sid_x", "b": "sid_y", "c": "sid_z"}
    events = [threading.Event() for _ in owners]
    for (rid, sid), event in zip(owners.items(), events):
        server._pending[rid] = (sid, event)

    try:
        server._clear_pending(None)
        assert all(event.is_set() for event in events)
    finally:
        for rid in owners:
            server._pending.pop(rid, None)
            server._answers.pop(rid, None)
|
||||
|
||||
|
||||
def test_respond_unpacks_sid_tuple_correctly():
    """clarify.respond still works with ``(sid, Event)`` pending entries.

    The response must set the stored event and record the supplied answer
    under the request id.
    """
    waiter = threading.Event()
    server._pending["rid-x"] = ("sid_x", waiter)

    request = {
        "id": "1",
        "method": "clarify.respond",
        "params": {"request_id": "rid-x", "answer": "the answer"},
    }
    try:
        resp = server.handle_request(request)
        assert resp.get("result")
        assert waiter.is_set()
        assert server._answers.get("rid-x") == "the answer"
    finally:
        for registry in (server._pending, server._answers):
            registry.pop("rid-x", None)
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /model switch and other agent-mutating commands must reject while the
|
||||
# session is running. agent.switch_model() mutates self.model, self.provider,
|
||||
# self.base_url, self.client etc. in place — the worker thread running
|
||||
# agent.run_conversation is reading those on every iteration. Same class of
|
||||
# bug as the session.undo / session.compress mid-run silent-drop; same fix
|
||||
# pattern: reject with 4009 while running.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_config_set_model_rejects_while_running(monkeypatch):
    """config.set key=model must fail with 4009 during an in-flight turn.

    The model-switch helper is replaced with a recorder so the test can
    prove it was never invoked while the session is marked running.
    """
    switch_calls = []

    def _fake_apply(sid, session, raw):
        switch_calls.append(raw)
        return {"value": raw, "warning": ""}

    monkeypatch.setattr(server, "_apply_model_switch", _fake_apply)
    server._sessions["sid"] = _session(running=True)

    request = {
        "id": "1",
        "method": "config.set",
        "params": {
            "session_id": "sid",
            "key": "model",
            "value": "anthropic/claude-sonnet-4.6",
        },
    }
    try:
        resp = server.handle_request(request)
        assert resp.get("error")
        assert resp["error"]["code"] == 4009
        assert "session busy" in resp["error"]["message"]
        assert not switch_calls, (
            "_apply_model_switch was called mid-turn — would race with "
            "the worker thread reading agent.model / agent.client"
        )
    finally:
        server._sessions.pop("sid", None)
|
||||
|
||||
|
||||
def test_config_set_model_allowed_when_idle(monkeypatch):
    """Regression guard: an idle session can still switch models.

    config.set key=model must call the (stubbed) switch helper and return
    the value it reports.
    """
    switch_calls = []

    def _fake_apply(sid, session, raw):
        switch_calls.append(raw)
        return {"value": "newmodel", "warning": ""}

    monkeypatch.setattr(server, "_apply_model_switch", _fake_apply)
    server._sessions["sid"] = _session(running=False)

    request = {
        "id": "1",
        "method": "config.set",
        "params": {"session_id": "sid", "key": "model", "value": "newmodel"},
    }
    try:
        resp = server.handle_request(request)
        assert resp.get("result")
        assert resp["result"]["value"] == "newmodel"
        assert switch_calls
    finally:
        server._sessions.pop("sid", None)
|
||||
|
||||
|
||||
def test_mirror_slash_side_effects_rejects_mutating_commands_while_running(monkeypatch):
    """Mutating slash commands must return a busy warning mid-turn.

    /model, /personality, /prompt and /compress mutate live agent state
    that run_conversation is reading, the same race as config.set. Each
    must yield a "session busy" warning naming the command, and neither the
    model-switch nor the compress helper may run.
    """
    import types

    fired = set()

    def _fake_apply_model(sid, session, arg):
        fired.add("model")
        return {"value": arg, "warning": ""}

    def _fake_compress(session, focus):
        fired.add("compress")
        return (0, {})

    monkeypatch.setattr(server, "_apply_model_switch", _fake_apply_model)
    monkeypatch.setattr(server, "_compress_session_history", _fake_compress)

    session = _session(running=True)
    session["agent"] = types.SimpleNamespace(model="x")

    busy_cases = (
        ("/model new/model", "model"),
        ("/personality default", "personality"),
        ("/prompt", "prompt"),
        ("/compress", "compress"),
    )
    for cmd, expected_name in busy_cases:
        warning = server._mirror_slash_side_effects("sid", session, cmd)
        assert "session busy" in warning, (
            f"{cmd} should have returned busy warning, got: {warning!r}"
        )
        assert f"/{expected_name}" in warning

    # Neither mutating helper may have been reached.
    assert "model" not in fired, "model switch fired despite running session"
    assert "compress" not in fired, "compress fired despite running session"
|
||||
|
||||
|
||||
def test_mirror_slash_side_effects_allowed_when_idle(monkeypatch):
    """Regression guard: on an idle session /model still runs its side effect."""
    import types

    switch_calls = []

    def _fake_apply_model(sid, session, arg):
        switch_calls.append(arg)
        return {"value": arg, "warning": ""}

    monkeypatch.setattr(server, "_apply_model_switch", _fake_apply_model)

    idle = _session(running=False)
    idle["agent"] = types.SimpleNamespace(model="x")

    warning = server._mirror_slash_side_effects("sid", idle, "/model foo")

    # The switch went through, so no busy warning is expected.
    assert "session busy" not in warning
    assert switch_calls
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session.create / session.close race: fast /new churn must not orphan the
|
||||
# slash_worker subprocess or the global approval-notify registration.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_session_create_close_race_does_not_orphan_worker(monkeypatch):
    """Regression guard for the session.create / session.close race.

    If session.close runs while session.create's _build thread is still
    constructing the agent, the build thread must detect the orphan and
    clean up the slash_worker + notify registration it is about to
    install. Without that cleanup the resources leak: the subprocess stays
    alive until atexit and the notify callback lingers in the global
    registry.
    """
    import threading
    import time
    # Imported locally, as in the sibling tests, so the _get_db stub below
    # does not depend on a module-level import.
    import types

    closed_workers: list[str] = []
    unregistered_keys: list[str] = []

    class _FakeWorker:
        def __init__(self, key, model):
            self.key = key
            self._closed = False

        def close(self):
            self._closed = True
            closed_workers.append(self.key)

    class _FakeAgent:
        def __init__(self):
            self.model = "x"
            self.provider = "openrouter"
            self.base_url = ""
            self.api_key = ""

    # Make _build block until we release it, simulating slow agent init.
    release_build = threading.Event()

    def _slow_make_agent(sid, key):
        release_build.wait(timeout=3.0)
        return _FakeAgent()

    # Stub everything _build touches.
    monkeypatch.setattr(server, "_make_agent", _slow_make_agent)
    monkeypatch.setattr(server, "_SlashWorker", _FakeWorker)
    monkeypatch.setattr(server, "_get_db", lambda: types.SimpleNamespace(create_session=lambda *a, **kw: None))
    monkeypatch.setattr(server, "_session_info", lambda _a: {"model": "x"})
    monkeypatch.setattr(server, "_probe_credentials", lambda _a: None)
    monkeypatch.setattr(server, "_wire_callbacks", lambda _sid: None)
    monkeypatch.setattr(server, "_emit", lambda *a, **kw: None)

    # Shim register/unregister to observe leaks.
    import tools.approval as _approval
    monkeypatch.setattr(_approval, "register_gateway_notify",
                        lambda key, cb: None)
    monkeypatch.setattr(_approval, "unregister_gateway_notify",
                        lambda key: unregistered_keys.append(key))
    monkeypatch.setattr(_approval, "load_permanent_allowlist", lambda: None)

    sid = None
    try:
        # session.create spawns the _build thread and returns synchronously.
        resp = server.handle_request({
            "id": "1", "method": "session.create", "params": {"cols": 80},
        })
        assert resp.get("result"), f"got error: {resp.get('error')}"
        sid = resp["result"]["session_id"]

        # The build thread is blocked in _slow_make_agent. Close the session
        # NOW: this pops _sessions[sid] before _build can install the
        # worker/notify.
        close_resp = server.handle_request({
            "id": "2", "method": "session.close", "params": {"session_id": sid},
        })
        assert close_resp.get("result", {}).get("closed") is True

        # session.close saw slash_worker=None (not yet installed), so it
        # closed nothing. Release the build thread; it should detect the
        # orphan, close the worker it just allocated and unregister the
        # notify.
        release_build.set()

        # Poll (up to ~2s) for the build thread to run its cleanup.
        for _ in range(100):
            if closed_workers:
                break
            time.sleep(0.02)

        assert len(closed_workers) == 1, (
            f"orphan worker was not cleaned up — closed_workers={closed_workers}"
        )
        # Notify may be unregistered by both session.close (unconditional)
        # and the orphan-cleanup path; the key guarantee is that at least
        # one unregister call happened (a duplicate is a no-op).
        assert len(unregistered_keys) >= 1, (
            f"orphan notify registration was not unregistered — "
            f"unregistered_keys={unregistered_keys}"
        )
    finally:
        # On a failed assertion, still unblock the build thread and drop
        # any session entry this test may have left behind.
        release_build.set()
        if sid is not None:
            server._sessions.pop(sid, None)
|
||||
|
||||
|
||||
def test_session_create_no_race_keeps_worker_alive(monkeypatch):
    """Regression guard: without a close race the build must not self-clean.

    When session.close does NOT race, the build thread must install the
    worker + notify normally and leave them alone (no over-eager cleanup).
    """
    # Imported locally, as in the sibling tests, so the _get_db stub below
    # does not depend on a module-level import.
    import types

    closed_workers: list[str] = []
    unregistered_keys: list[str] = []

    class _FakeWorker:
        def __init__(self, key, model):
            self.key = key

        def close(self):
            closed_workers.append(self.key)

    class _FakeAgent:
        def __init__(self):
            self.model = "x"
            self.provider = "openrouter"
            self.base_url = ""
            self.api_key = ""

    monkeypatch.setattr(server, "_make_agent", lambda sid, key: _FakeAgent())
    monkeypatch.setattr(server, "_SlashWorker", _FakeWorker)
    monkeypatch.setattr(server, "_get_db", lambda: types.SimpleNamespace(create_session=lambda *a, **kw: None))
    monkeypatch.setattr(server, "_session_info", lambda _a: {"model": "x"})
    monkeypatch.setattr(server, "_probe_credentials", lambda _a: None)
    monkeypatch.setattr(server, "_wire_callbacks", lambda _sid: None)
    monkeypatch.setattr(server, "_emit", lambda *a, **kw: None)

    import tools.approval as _approval
    monkeypatch.setattr(_approval, "register_gateway_notify", lambda key, cb: None)
    monkeypatch.setattr(_approval, "unregister_gateway_notify",
                        lambda key: unregistered_keys.append(key))
    monkeypatch.setattr(_approval, "load_permanent_allowlist", lambda: None)

    sid = None
    try:
        resp = server.handle_request({
            "id": "1", "method": "session.create", "params": {"cols": 80},
        })
        sid = resp["result"]["session_id"]

        # Wait for the build to finish (ready event inside session dict).
        # Fail loudly on timeout so the "nothing was cleaned up" checks
        # below cannot pass merely because the build never ran.
        session = server._sessions[sid]
        assert session["agent_ready"].wait(timeout=2.0), (
            "build thread did not signal agent_ready within 2s"
        )

        # Build finished without a close race, so the orphan check should
        # have cleaned up nothing.
        assert closed_workers == [], (
            f"build thread closed its own worker despite no race: {closed_workers}"
        )
        assert unregistered_keys == [], (
            f"build thread unregistered its own notify despite no race: {unregistered_keys}"
        )

        # Session should have the live worker installed.
        assert session.get("slash_worker") is not None
    finally:
        # Always drop the session so a failure cannot leak it into later tests.
        if sid is not None:
            server._sessions.pop(sid, None)
|
||||
|
|
|
|||
408
tests/tools/test_browser_cdp_tool.py
Normal file
408
tests/tools/test_browser_cdp_tool.py
Normal file
|
|
@ -0,0 +1,408 @@
|
|||
"""Unit tests for browser_cdp tool.
|
||||
|
||||
Uses a tiny in-process ``websockets`` server to simulate a CDP endpoint —
|
||||
gives real protocol coverage (connect, send, recv, close) without needing
|
||||
a real Chrome instance.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
import websockets
|
||||
from websockets.asyncio.server import serve
|
||||
|
||||
from tools import browser_cdp_tool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-process CDP mock server
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _CDPServer:
    """Minimal in-process CDP-over-WebSocket endpoint for tests.

    No greeting is sent on connect. Every inbound JSON message is recorded
    and answered: the handler registered for its ``method`` supplies the
    ``result``; an unregistered method gets a ``-32601`` error, and a
    handler that raises (or returns an ``Exception``) gets a ``-1`` error.
    """

    def __init__(self) -> None:
        self._handlers: Dict[str, Any] = {}
        # Raw messages received from clients, in arrival order.
        self._responses: List[Dict[str, Any]] = []
        self._loop: asyncio.AbstractEventLoop | None = None
        self._server: Any = None
        self._thread: threading.Thread | None = None
        self._host = "127.0.0.1"
        self._port = 0

    # --- handler registration --------------------------------------------

    def on(self, method: str, handler):
        """Register a handler ``handler(params, session_id) -> dict or Exception``."""
        self._handlers[method] = handler

    # --- request handling ------------------------------------------------

    def _reply_for(self, msg):
        """Record *msg* and build the reply dict to send back for it."""
        call_id = msg.get("id")
        method = msg.get("method", "")
        params = msg.get("params", {}) or {}
        session_id = msg.get("sessionId")
        self._responses.append(msg)

        fn = self._handlers.get(method)
        if fn is None:
            reply = {
                "id": call_id,
                "error": {
                    "code": -32601,
                    "message": f"No handler for {method}",
                },
            }
        else:
            failure = None
            outcome = None
            try:
                outcome = fn(params, session_id)
            except Exception as exc:
                failure = exc
            else:
                # A handler may also signal failure by returning an Exception.
                if isinstance(outcome, Exception):
                    failure = outcome
            if failure is None:
                reply = {"id": call_id, "result": outcome}
            else:
                reply = {
                    "id": call_id,
                    "error": {"code": -1, "message": str(failure)},
                }

        # Echo the session id back when the request carried one.
        if session_id:
            reply["sessionId"] = session_id
        return reply

    async def _handle_connection(self, ws):
        """Answer each message on one client connection until it closes."""
        try:
            async for raw in ws:
                reply = self._reply_for(json.loads(raw))
                await ws.send(json.dumps(reply))
        except websockets.exceptions.ConnectionClosed:
            pass

    # --- lifecycle -------------------------------------------------------

    async def _serve(self, ready) -> None:
        """Bind to an ephemeral port, signal *ready*, and run until closed."""
        self._server = await serve(self._handle_connection, self._host, 0)
        sock = next(iter(self._server.sockets))
        self._port = sock.getsockname()[1]
        ready.set()
        await self._server.wait_closed()

    def _run_loop(self, ready) -> None:
        """Thread target: own a fresh event loop for the server's lifetime."""
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        try:
            self._loop.run_until_complete(self._serve(ready))
        finally:
            self._loop.close()

    def start(self) -> str:
        """Start the server on a daemon thread and return its ws:// URL.

        Raises:
            RuntimeError: if the server is not listening within 5 seconds.
        """
        ready = threading.Event()
        self._thread = threading.Thread(
            target=self._run_loop, args=(ready,), daemon=True
        )
        self._thread.start()
        if not ready.wait(timeout=5.0):
            raise RuntimeError("CDP mock server failed to start within 5s")
        return f"ws://{self._host}:{self._port}/devtools/browser/mock"

    def stop(self) -> None:
        """Ask the server to close and wait up to 3s for its thread."""
        if self._loop and self._server:
            # close() must be invoked on the server's own loop thread.
            self._loop.call_soon_threadsafe(self._server.close)
        if self._thread:
            self._thread.join(timeout=3.0)

    def received(self) -> List[Dict[str, Any]]:
        """Return a copy of every message received so far."""
        return list(self._responses)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def cdp_server(monkeypatch):
    """Yield a running CDP mock that the tool's endpoint resolver points at.

    ``browser_cdp_tool._resolve_cdp_endpoint`` is patched to return the
    mock's URL; the mock is stopped on teardown.
    """
    mock = _CDPServer()
    endpoint = mock.start()
    monkeypatch.setattr(
        browser_cdp_tool, "_resolve_cdp_endpoint", lambda: endpoint
    )
    try:
        yield mock
    finally:
        mock.stop()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_missing_method_returns_error():
    """An empty method yields an error mentioning "method" plus the docs URL."""
    raw = browser_cdp_tool.browser_cdp(method="")
    payload = json.loads(raw)
    assert "error" in payload
    assert "method" in payload["error"].lower()
    assert payload.get("cdp_docs") == browser_cdp_tool.CDP_DOCS_URL
|
||||
|
||||
|
||||
def test_non_string_method_returns_error():
    """A non-string method yields an error mentioning "method"."""
    raw = browser_cdp_tool.browser_cdp(method=123)  # type: ignore[arg-type]
    payload = json.loads(raw)
    assert "error" in payload
    assert "method" in payload["error"].lower()
|
||||
|
||||
|
||||
def test_non_dict_params_returns_error(monkeypatch):
    """Non-dict ``params`` yields an error mentioning "object" or "dict"."""
    monkeypatch.setattr(
        browser_cdp_tool, "_resolve_cdp_endpoint", lambda: "ws://localhost:9999"
    )
    raw = browser_cdp_tool.browser_cdp(method="Target.getTargets", params="not-a-dict")  # type: ignore[arg-type]
    payload = json.loads(raw)
    assert "error" in payload
    message = payload["error"].lower()
    assert any(word in message for word in ("object", "dict"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoint resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_no_endpoint_returns_helpful_error(monkeypatch):
    """With no endpoint the error names "/browser connect" and the docs URL."""
    monkeypatch.setattr(browser_cdp_tool, "_resolve_cdp_endpoint", lambda: "")
    raw = browser_cdp_tool.browser_cdp(method="Target.getTargets")
    payload = json.loads(raw)
    assert "error" in payload
    assert "/browser connect" in payload["error"]
    assert payload.get("cdp_docs") == browser_cdp_tool.CDP_DOCS_URL
|
||||
|
||||
|
||||
def test_non_ws_endpoint_returns_error(monkeypatch):
    """An http:// endpoint is rejected with an error mentioning "WebSocket"."""
    monkeypatch.setattr(
        browser_cdp_tool, "_resolve_cdp_endpoint", lambda: "http://localhost:9222"
    )
    raw = browser_cdp_tool.browser_cdp(method="Target.getTargets")
    payload = json.loads(raw)
    assert "error" in payload
    assert "WebSocket" in payload["error"]
|
||||
|
||||
|
||||
def test_websockets_missing_returns_error(monkeypatch):
    """With ``_WS_AVAILABLE`` false the error mentions "websockets"."""
    monkeypatch.setattr(browser_cdp_tool, "_WS_AVAILABLE", False)
    raw = browser_cdp_tool.browser_cdp(method="Target.getTargets")
    payload = json.loads(raw)
    assert "error" in payload
    assert "websockets" in payload["error"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Happy-path: browser-level call
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_browser_level_success(cdp_server):
    """A browser-level call succeeds and sends exactly one session-less message."""
    target_infos = [
        {"targetId": "A", "type": "page", "title": "Tab 1", "url": "about:blank"},
        {"targetId": "B", "type": "page", "title": "Tab 2", "url": "https://a.test"},
    ]
    cdp_server.on(
        "Target.getTargets",
        lambda params, sid: {"targetInfos": target_infos},
    )

    raw = browser_cdp_tool.browser_cdp(method="Target.getTargets")
    payload = json.loads(raw)

    assert payload["success"] is True
    assert payload["method"] == "Target.getTargets"
    assert "target_id" not in payload
    assert len(payload["result"]["targetInfos"]) == 2

    # The server must have seen exactly one call (no extra traffic).
    calls = cdp_server.received()
    assert len(calls) == 1
    only_call = calls[0]
    assert only_call["method"] == "Target.getTargets"
    assert "sessionId" not in only_call
|
||||
|
||||
|
||||
def test_empty_params_sends_empty_object(cdp_server):
    """Omitting ``params`` sends ``{}`` on the wire."""
    cdp_server.on("Browser.getVersion", lambda params, sid: {"product": "Mock/1.0"})
    raw = browser_cdp_tool.browser_cdp(method="Browser.getVersion")
    json.loads(raw)  # the tool output must itself be valid JSON
    first_call = cdp_server.received()[0]
    assert first_call["params"] == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Happy-path: target-attached call
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_target_attach_then_call(cdp_server):
    """With ``target_id`` the tool attaches first, then calls on that session."""

    def _attach(params, sid):
        return {"sessionId": f"sess-{params['targetId']}"}

    def _evaluate(params, sid):
        return {"result": {"type": "string", "value": f"evaluated[{sid}]"}}

    cdp_server.on("Target.attachToTarget", _attach)
    cdp_server.on("Runtime.evaluate", _evaluate)

    raw = browser_cdp_tool.browser_cdp(
        method="Runtime.evaluate",
        params={"expression": "document.title", "returnByValue": True},
        target_id="tab-A",
    )
    payload = json.loads(raw)

    assert payload["success"] is True
    assert payload["target_id"] == "tab-A"
    assert payload["result"]["result"]["value"] == "evaluated[sess-tab-A]"

    calls = cdp_server.received()
    attach_call, method_call = calls[0], calls[1]
    # First call: attach.
    assert attach_call["method"] == "Target.attachToTarget"
    assert attach_call["params"] == {"targetId": "tab-A", "flatten": True}
    # Second call: the requested method, dispatched on the new session.
    assert method_call["method"] == "Runtime.evaluate"
    assert method_call["sessionId"] == "sess-tab-A"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CDP error responses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_cdp_method_error_returns_tool_error(cdp_server):
    """A CDP-level error from the server surfaces as a tool error."""
    # No handler is registered, so the mock answers with a CDP error.
    raw = browser_cdp_tool.browser_cdp(method="NonExistent.method")
    payload = json.loads(raw)
    assert "error" in payload
    assert "CDP error" in payload["error"]
    assert payload.get("method") == "NonExistent.method"
|
||||
|
||||
|
||||
def test_attach_failure_returns_tool_error(cdp_server):
    """A failed attach is reported with "Target.attachToTarget" in the error."""
    # Target.attachToTarget has no handler, so the mock errors on attach.
    raw = browser_cdp_tool.browser_cdp(
        method="Runtime.evaluate",
        params={"expression": "1+1"},
        target_id="missing",
    )
    payload = json.loads(raw)
    assert "error" in payload
    assert "Target.attachToTarget" in payload["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Timeouts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_timeout_when_server_never_replies(cdp_server):
    """A call whose reply never arrives in time returns a timeout error.

    The handler runs on the mock server's event-loop thread, so it blocks
    on a releasable event rather than an unconditional sleep. Releasing it
    once the client has given up lets the fixture's ``stop()`` shut the
    server down promptly instead of waiting out its join timeout behind a
    still-sleeping handler.
    """
    release = threading.Event()

    def slow(params, sid):
        # Stall far longer than the client's 0.5s timeout unless released.
        release.wait(10)
        return {}

    cdp_server.on("Page.slowMethod", slow)
    try:
        result = json.loads(
            browser_cdp_tool.browser_cdp(
                method="Page.slowMethod", timeout=0.5
            )
        )
    finally:
        # Unblock the server loop so teardown is not stalled.
        release.set()
    assert "error" in result
    assert "tim" in result["error"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Timeout clamping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_timeout_clamped_above_max(cdp_server):
    """An oversized timeout is accepted and the call still succeeds."""
    cdp_server.on("Browser.getVersion", lambda p, s: {"product": "ok"})
    # timeout=10_000 should be clamped to 300 but still succeed.
    raw = browser_cdp_tool.browser_cdp(method="Browser.getVersion", timeout=10_000)
    payload = json.loads(raw)
    assert payload["success"] is True
|
||||
|
||||
|
||||
def test_invalid_timeout_falls_back_to_default(cdp_server):
    """A non-numeric timeout does not break the call."""
    cdp_server.on("Browser.getVersion", lambda p, s: {"product": "ok"})
    raw = browser_cdp_tool.browser_cdp(method="Browser.getVersion", timeout="nope")  # type: ignore[arg-type]
    payload = json.loads(raw)
    assert payload["success"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_registered_in_browser_toolset():
    """``browser_cdp`` is registered in the ``browser`` toolset with its schema."""
    from tools.registry import registry

    entry = registry.get_entry("browser_cdp")
    assert entry is not None
    assert entry.toolset == "browser"

    schema = entry.schema
    assert schema["name"] == "browser_cdp"
    assert schema["parameters"]["required"] == ["method"]

    description = schema["description"]
    assert "Chrome DevTools Protocol" in description
    assert browser_cdp_tool.CDP_DOCS_URL in description
|
||||
|
||||
|
||||
def test_dispatch_through_registry(cdp_server):
    """Dispatching ``browser_cdp`` via the registry reaches the mock server."""
    from tools.registry import registry

    cdp_server.on("Target.getTargets", lambda p, s: {"targetInfos": []})

    tool_args = {"method": "Target.getTargets"}
    raw = registry.dispatch("browser_cdp", tool_args, task_id="t1")
    payload = json.loads(raw)

    assert payload["success"] is True
    assert payload["method"] == "Target.getTargets"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_fn gating
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_check_fn_false_when_no_cdp_url(monkeypatch):
    """The gate stays closed when no CDP URL is set, even if the browser
    toolset is otherwise configured."""
    import tools.browser_tool as bt

    stubs = {
        "check_browser_requirements": lambda: True,
        "_get_cdp_override": lambda: "",
    }
    for attr, stub in stubs.items():
        monkeypatch.setattr(bt, attr, stub)

    assert browser_cdp_tool._browser_cdp_check() is False
|
||||
|
||||
|
||||
def test_check_fn_true_when_cdp_url_set(monkeypatch):
    """The gate opens once browser requirements pass and a CDP URL resolves."""
    import tools.browser_tool as bt

    stubs = {
        "check_browser_requirements": lambda: True,
        "_get_cdp_override": lambda: "ws://localhost:9222/devtools/browser/x",
    }
    for attr, stub in stubs.items():
        monkeypatch.setattr(bt, attr, stub)

    assert browser_cdp_tool._browser_cdp_check() is True
|
||||
|
||||
|
||||
def test_check_fn_false_when_browser_requirements_fail(monkeypatch):
    """Even with a CDP URL the gate stays closed if the overall browser
    toolset is unavailable (e.g. agent-browser not installed)."""
    import tools.browser_tool as bt

    stubs = {
        "check_browser_requirements": lambda: False,
        "_get_cdp_override": lambda: "ws://localhost:9222/devtools/browser/x",
    }
    for attr, stub in stubs.items():
        monkeypatch.setattr(bt, attr, stub)

    assert browser_cdp_tool._browser_cdp_check() is False
|
||||
455
tests/tools/test_code_execution_modes.py
Normal file
455
tests/tools/test_code_execution_modes.py
Normal file
|
|
@ -0,0 +1,455 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Tests for execute_code's strict / project execution modes.
|
||||
|
||||
The mode switch controls two things:
|
||||
- working directory: staging tmpdir (strict) vs session CWD (project)
|
||||
- interpreter: sys.executable (strict) vs active venv's python (project)
|
||||
|
||||
Security-critical invariants — env scrubbing, tool whitelist, resource caps —
|
||||
must apply identically in both modes. These tests guard all three layers.
|
||||
|
||||
Mode is sourced exclusively from ``code_execution.mode`` in config.yaml —
|
||||
there is no env-var override. Tests patch ``_load_config`` directly.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Set TERMINAL_ENV to "local" at import time, before the
# tools.code_execution_tool import below is executed.
os.environ["TERMINAL_ENV"] = "local"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _force_local_terminal(monkeypatch):
    """Pin ``TERMINAL_ENV=local`` for every test in this module.

    Mirrors test_code_execution.py: guarantees the local backend under xdist.
    """
    env_name, env_value = "TERMINAL_ENV", "local"
    monkeypatch.setenv(env_name, env_value)
|
||||
|
||||
|
||||
from tools.code_execution_tool import (
|
||||
SANDBOX_ALLOWED_TOOLS,
|
||||
DEFAULT_EXECUTION_MODE,
|
||||
EXECUTION_MODES,
|
||||
_get_execution_mode,
|
||||
_is_usable_python,
|
||||
_resolve_child_cwd,
|
||||
_resolve_child_python,
|
||||
build_execute_code_schema,
|
||||
execute_code,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
def _mock_mode(mode):
    """Pin ``code_execution.mode`` to *mode* for the duration of the block.

    Patches ``tools.code_execution_tool._load_config`` to return
    ``{"mode": mode}``.
    """
    config_patch = patch(
        "tools.code_execution_tool._load_config",
        return_value={"mode": mode},
    )
    with config_patch:
        yield
|
||||
|
||||
|
||||
def _mock_handle_function_call(function_name, function_args, task_id=None, user_task=None):
    """Minimal stand-in tool dispatcher shared by the tests.

    Returns a canned JSON string for ``terminal`` and ``read_file``; any
    other tool name produces a JSON error object. All other arguments are
    ignored.
    """
    if function_name == "terminal":
        payload = {"output": "mock", "exit_code": 0}
    elif function_name == "read_file":
        payload = {"content": "line1\n", "total_lines": 1}
    else:
        payload = {"error": f"Unknown tool: {function_name}"}
    return json.dumps(payload)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetExecutionMode(unittest.TestCase):
    """_get_execution_mode reads config.yaml only (no env var surface)."""

    _CONFIG_LOADER = "tools.code_execution_tool._load_config"

    def _mode_for(self, config):
        """Return the mode resolved while ``_load_config`` yields *config*."""
        with patch(self._CONFIG_LOADER, return_value=config):
            return _get_execution_mode()

    def test_default_is_project(self):
        self.assertEqual(DEFAULT_EXECUTION_MODE, "project")

    def test_config_project(self):
        self.assertEqual(self._mode_for({"mode": "project"}), "project")

    def test_config_strict(self):
        self.assertEqual(self._mode_for({"mode": "strict"}), "strict")

    def test_config_case_insensitive(self):
        self.assertEqual(self._mode_for({"mode": "STRICT"}), "strict")

    def test_config_strips_whitespace(self):
        self.assertEqual(self._mode_for({"mode": "  project  "}), "project")

    def test_empty_config_falls_back_to_default(self):
        self.assertEqual(self._mode_for({}), DEFAULT_EXECUTION_MODE)

    def test_bogus_config_falls_back_to_default(self):
        self.assertEqual(self._mode_for({"mode": "banana"}), DEFAULT_EXECUTION_MODE)

    def test_none_config_falls_back_to_default(self):
        # str(None).lower() = "none" → not in EXECUTION_MODES → default
        self.assertEqual(self._mode_for({"mode": None}), DEFAULT_EXECUTION_MODE)

    def test_execution_modes_tuple(self):
        """Canonical set of modes — tests + config layer rely on this shape."""
        self.assertEqual(set(EXECUTION_MODES), {"project", "strict"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Interpreter resolver
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResolveChildPython(unittest.TestCase):
    """_resolve_child_python — picks the right interpreter per mode."""

    @staticmethod
    def _env_without(*names):
        """Return a copy of ``os.environ`` with the given variable names removed."""
        return {k: v for k, v in os.environ.items() if k not in names}

    @staticmethod
    def _make_fake_env(root):
        """Create ``<root>/bin/python`` as a symlink to the running interpreter.

        Symlinking to the real python lets the resolver's usability check
        pass. Returns the symlink path as a string.
        """
        import pathlib

        python = pathlib.Path(root) / "bin" / "python"
        python.parent.mkdir()
        python.symlink_to(sys.executable)
        return str(python)

    def test_strict_always_sys_executable(self):
        """Strict mode never leaves sys.executable, even if venv is set."""
        with patch.dict(os.environ, {"VIRTUAL_ENV": "/some/venv"}):
            self.assertEqual(_resolve_child_python("strict"), sys.executable)

    def test_project_with_no_venv_falls_back(self):
        """Project mode without VIRTUAL_ENV or CONDA_PREFIX → sys.executable."""
        env = self._env_without("VIRTUAL_ENV", "CONDA_PREFIX")
        with patch.dict(os.environ, env, clear=True):
            self.assertEqual(_resolve_child_python("project"), sys.executable)

    def test_project_with_virtualenv_picks_venv_python(self):
        """Project mode + VIRTUAL_ENV pointing at a real venv → that python."""
        import tempfile

        with tempfile.TemporaryDirectory() as td:
            expected = self._make_fake_env(td)
            with patch.dict(os.environ, {"VIRTUAL_ENV": td}):
                # Clear cache — _is_usable_python memoizes on path
                _is_usable_python.cache_clear()
                self.assertEqual(_resolve_child_python("project"), expected)

    def test_project_with_broken_venv_falls_back(self):
        """VIRTUAL_ENV set but bin/python missing → sys.executable."""
        import tempfile

        with tempfile.TemporaryDirectory() as td:
            # No bin/python inside — broken venv. CONDA_PREFIX is stripped as
            # well so an ambient conda install in the developer's shell cannot
            # be picked up instead of the sys.executable fallback.
            env = self._env_without("CONDA_PREFIX")
            env["VIRTUAL_ENV"] = td
            with patch.dict(os.environ, env, clear=True):
                _is_usable_python.cache_clear()
                self.assertEqual(_resolve_child_python("project"), sys.executable)

    def test_project_prefers_virtualenv_over_conda(self):
        """If both VIRTUAL_ENV and CONDA_PREFIX are set, VIRTUAL_ENV wins."""
        import tempfile

        with tempfile.TemporaryDirectory() as ve_td, \
                tempfile.TemporaryDirectory() as conda_td:
            venv_python = self._make_fake_env(ve_td)
            self._make_fake_env(conda_td)
            overrides = {"VIRTUAL_ENV": ve_td, "CONDA_PREFIX": conda_td}
            with patch.dict(os.environ, overrides):
                _is_usable_python.cache_clear()
                self.assertEqual(_resolve_child_python("project"), venv_python)

    def test_is_usable_python_rejects_nonexistent(self):
        _is_usable_python.cache_clear()
        self.assertFalse(_is_usable_python("/does/not/exist/python"))

    def test_is_usable_python_accepts_real_python(self):
        _is_usable_python.cache_clear()
        self.assertTrue(_is_usable_python(sys.executable))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CWD resolver
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResolveChildCwd(unittest.TestCase):
    """_resolve_child_cwd — working directory chosen for the child per mode."""

    # Staging directory handed to the resolver in every case below.
    _STAGING = "/tmp/staging"

    def test_strict_uses_staging_dir(self):
        resolved = _resolve_child_cwd("strict", self._STAGING)
        self.assertEqual(resolved, "/tmp/staging")

    def test_project_without_terminal_cwd_uses_getcwd(self):
        scrubbed = dict(os.environ)
        scrubbed.pop("TERMINAL_CWD", None)
        with patch.dict(os.environ, scrubbed, clear=True):
            resolved = _resolve_child_cwd("project", self._STAGING)
            self.assertEqual(resolved, os.getcwd())

    def test_project_uses_terminal_cwd_when_set(self):
        import tempfile

        with tempfile.TemporaryDirectory() as td, \
                patch.dict(os.environ, {"TERMINAL_CWD": td}):
            resolved = _resolve_child_cwd("project", self._STAGING)
            self.assertEqual(resolved, td)

    def test_project_bogus_terminal_cwd_falls_back_to_getcwd(self):
        with patch.dict(os.environ, {"TERMINAL_CWD": "/does/not/exist/anywhere"}):
            resolved = _resolve_child_cwd("project", self._STAGING)
            self.assertEqual(resolved, os.getcwd())

    def test_project_expands_tilde(self):
        import pathlib

        expected_home = str(pathlib.Path.home())
        with patch.dict(os.environ, {"TERMINAL_CWD": "~"}):
            resolved = _resolve_child_cwd("project", self._STAGING)
            self.assertEqual(resolved, expected_home)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema description
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestModeAwareSchema(unittest.TestCase):
    """The execute_code tool description must reflect the active mode."""

    @staticmethod
    def _description(**kwargs):
        """Return the ``description`` field of the schema built with *kwargs*."""
        return build_execute_code_schema(**kwargs)["description"]

    def test_strict_description_mentions_temp_dir(self):
        self.assertIn("temp dir", self._description(mode="strict"))

    def test_project_description_mentions_session_and_venv(self):
        text = self._description(mode="project")
        for keyword in ("session", "venv"):
            self.assertIn(keyword, text)

    def test_neither_description_uses_sandbox_language(self):
        """REGRESSION GUARD for commit 39b83f34.

        Agents on local backends falsely believed they were sandboxed and
        refused networking tasks. Do not reintroduce any 'sandbox' /
        'isolated' / 'cloud' language in the tool description.
        """
        banned_words = ("sandbox", "isolated", "cloud")
        for mode in EXECUTION_MODES:
            lowered = self._description(mode=mode).lower()
            for forbidden in banned_words:
                self.assertNotIn(forbidden, lowered,
                                 f"mode={mode}: '{forbidden}' leaked into description")

    def test_descriptions_are_similar_length(self):
        """Both modes should have roughly the same-size description."""
        strict_len = len(self._description(mode="strict"))
        project_len = len(self._description(mode="project"))
        self.assertLess(abs(strict_len - project_len), 200)

    def test_default_mode_reads_config(self):
        """build_execute_code_schema() with mode=None reads config.yaml."""
        expectations = (("strict", "temp dir"), ("project", "session"))
        for configured_mode, marker in expectations:
            with _mock_mode(configured_mode):
                text = self._description()
            self.assertIn(marker, text)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: what actually happens when execute_code runs per mode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="execute_code is POSIX-only")
class TestExecuteCodeModeIntegration(unittest.TestCase):
    """End-to-end: verify the subprocess actually runs where we expect."""

    # Script shared by the two hermes_tools import regressions below.
    _IMPORT_PROBE = (
        "from hermes_tools import terminal\n"
        "r = terminal('echo x')\n"
        "print(r.get('output', 'MISSING'))\n"
    )

    def _run(self, code, mode, enabled_tools=None, extra_env=None):
        """Execute *code* under *mode* and return the decoded JSON result.

        ``extra_env`` is layered onto ``os.environ`` for the call, and tool
        calls are routed to ``_mock_handle_function_call``. When
        ``enabled_tools`` is falsy, every tool in ``SANDBOX_ALLOWED_TOOLS``
        is enabled.
        """
        tools = enabled_tools or list(SANDBOX_ALLOWED_TOOLS)
        with _mock_mode(mode), \
                patch.dict(os.environ, extra_env or {}), \
                patch("model_tools.handle_function_call",
                      side_effect=_mock_handle_function_call):
            raw = execute_code(
                code=code,
                task_id=f"test-{mode}",
                enabled_tools=tools,
            )
        return json.loads(raw)

    def test_strict_mode_runs_in_tmpdir(self):
        """Strict mode: script's os.getcwd() is the staging tmpdir."""
        outcome = self._run("import os; print(os.getcwd())", mode="strict")
        self.assertEqual(outcome["status"], "success")
        self.assertIn("hermes_sandbox_", outcome["output"])

    def test_project_mode_runs_in_session_cwd(self):
        """Project mode: script's os.getcwd() is the session's working dir."""
        import tempfile

        with tempfile.TemporaryDirectory() as td:
            outcome = self._run(
                "import os; print(os.getcwd())",
                mode="project",
                extra_env={"TERMINAL_CWD": td},
            )
            self.assertEqual(outcome["status"], "success")
            # Resolve symlinks (macOS /tmp → /private/tmp) on both sides
            reported = os.path.realpath(outcome["output"].strip())
            self.assertEqual(reported, os.path.realpath(td))

    def test_project_mode_interpreter_is_venv_python(self):
        """Project mode: sys.executable inside the child is the venv's python
        when VIRTUAL_ENV is set to a real venv."""
        # The hermes-agent venv is always active during tests, so this also
        # happens to equal sys.executable of the parent. What we're asserting
        # is: resolver picked a venv-bin/python path, not that it differs
        # from sys.executable.
        outcome = self._run("import sys; print(sys.executable)", mode="project")
        self.assertEqual(outcome["status"], "success")

        ve = os.environ.get("VIRTUAL_ENV", "").strip()
        if not ve:
            return
        # Either VIRTUAL_ENV-bin/python or sys.executable fallback, both OK.
        output = outcome["output"].strip()
        self.assertTrue(
            output.startswith(ve) or output == sys.executable,
            f"project-mode python should be under VIRTUAL_ENV={ve} or sys.executable={sys.executable}, got {output}",
        )

    def test_project_mode_can_still_import_hermes_tools(self):
        """Regression: hermes_tools still importable from non-tmpdir CWD.

        This is the PYTHONPATH fix — without it, switching to session CWD
        breaks `from hermes_tools import terminal`.
        """
        import tempfile

        with tempfile.TemporaryDirectory() as td:
            outcome = self._run(
                self._IMPORT_PROBE,
                mode="project",
                extra_env={"TERMINAL_CWD": td},
            )
            self.assertEqual(outcome["status"], "success")
            self.assertIn("mock", outcome["output"])

    def test_strict_mode_can_still_import_hermes_tools(self):
        """Regression: strict mode's tmpdir CWD still works for imports."""
        outcome = self._run(self._IMPORT_PROBE, mode="strict")
        self.assertEqual(outcome["status"], "success")
        self.assertIn("mock", outcome["output"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SECURITY-CRITICAL regression guards
|
||||
#
|
||||
# These MUST pass in both strict and project mode. The whole tiered-mode
|
||||
# proposition rests on the claim that switching from strict to project only
|
||||
# changes CWD + interpreter, not the security posture.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="execute_code is POSIX-only")
class TestSecurityInvariantsAcrossModes(unittest.TestCase):
    """Credential scrubbing and the tool whitelist must hold in both modes."""

    # Script reporting whether non-whitelisted tools are exposed to the child.
    _WHITELIST_PROBE = (
        "import hermes_tools as ht\n"
        "print('execute_code_available:', hasattr(ht, 'execute_code'))\n"
        "print('delegate_task_available:', hasattr(ht, 'delegate_task'))\n"
    )

    def _run(self, code, mode):
        """Execute *code* under *mode* with all whitelisted tools; return parsed JSON."""
        with _mock_mode(mode), \
                patch("model_tools.handle_function_call",
                      side_effect=_mock_handle_function_call):
            raw = execute_code(
                code=code,
                task_id=f"test-sec-{mode}",
                enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
            )
        return json.loads(raw)

    def _assert_whitelist_holds(self, mode):
        """Assert the child cannot see execute_code or delegate_task in *mode*."""
        outcome = self._run(self._WHITELIST_PROBE, mode=mode)
        self.assertEqual(outcome["status"], "success")
        self.assertIn("execute_code_available: False", outcome["output"])
        self.assertIn("delegate_task_available: False", outcome["output"])

    def test_api_keys_scrubbed_in_strict_mode(self):
        code = (
            "import os\n"
            "print('KEY=' + os.environ.get('OPENAI_API_KEY', 'MISSING'))\n"
            "print('TOK=' + os.environ.get('ANTHROPIC_API_KEY', 'MISSING'))\n"
        )
        planted = {
            "OPENAI_API_KEY": "sk-should-not-leak",
            "ANTHROPIC_API_KEY": "ant-should-not-leak",
        }
        with patch.dict(os.environ, planted):
            outcome = self._run(code, mode="strict")
        self.assertEqual(outcome["status"], "success")
        printed = outcome["output"]
        self.assertIn("KEY=MISSING", printed)
        self.assertIn("TOK=MISSING", printed)
        self.assertNotIn("sk-should-not-leak", printed)
        self.assertNotIn("ant-should-not-leak", printed)

    def test_api_keys_scrubbed_in_project_mode(self):
        """CRITICAL: the project-mode default does NOT leak user credentials."""
        code = (
            "import os\n"
            "print('KEY=' + os.environ.get('OPENAI_API_KEY', 'MISSING'))\n"
            "print('TOK=' + os.environ.get('ANTHROPIC_API_KEY', 'MISSING'))\n"
            "print('SEC=' + os.environ.get('GITHUB_TOKEN', 'MISSING'))\n"
        )
        planted = {
            "OPENAI_API_KEY": "sk-should-not-leak",
            "ANTHROPIC_API_KEY": "ant-should-not-leak",
            "GITHUB_TOKEN": "ghp-should-not-leak",
        }
        with patch.dict(os.environ, planted):
            outcome = self._run(code, mode="project")
        self.assertEqual(outcome["status"], "success")
        printed = outcome["output"]
        for needle in ("KEY=MISSING", "TOK=MISSING", "SEC=MISSING"):
            self.assertIn(needle, printed)
        for leaked in ("sk-should-not-leak", "ant-should-not-leak", "ghp-should-not-leak"):
            self.assertNotIn(leaked, printed)

    def test_secret_substrings_scrubbed_in_project_mode(self):
        """SECRET/PASSWORD/CREDENTIAL/PASSWD/AUTH filters still apply."""
        code = (
            "import os\n"
            "for k in ('MY_SECRET', 'DB_PASSWORD', 'VAULT_CREDENTIAL', "
            "'LDAP_PASSWD', 'AUTH_TOKEN'):\n"
            " print(f'{k}=' + os.environ.get(k, 'MISSING'))\n"
        )
        planted = {
            "MY_SECRET": "secret-should-not-leak",
            "DB_PASSWORD": "password-should-not-leak",
            "VAULT_CREDENTIAL": "cred-should-not-leak",
            "LDAP_PASSWD": "passwd-should-not-leak",
            "AUTH_TOKEN": "auth-should-not-leak",
        }
        with patch.dict(os.environ, planted):
            outcome = self._run(code, mode="project")
        self.assertEqual(outcome["status"], "success")
        for leaked in ("secret-should-not-leak", "password-should-not-leak",
                       "cred-should-not-leak", "passwd-should-not-leak",
                       "auth-should-not-leak"):
            self.assertNotIn(leaked, outcome["output"])

    def test_tool_whitelist_enforced_in_strict_mode(self):
        """A script cannot RPC-call tools outside SANDBOX_ALLOWED_TOOLS."""
        # execute_code is NOT in SANDBOX_ALLOWED_TOOLS (no recursion)
        self.assertNotIn("execute_code", SANDBOX_ALLOWED_TOOLS)
        self._assert_whitelist_holds("strict")

    def test_tool_whitelist_enforced_in_project_mode(self):
        """CRITICAL: project mode does NOT widen the tool whitelist."""
        self._assert_whitelist_holds("project")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
256
tests/tools/test_cron_approval_mode.py
Normal file
256
tests/tools/test_cron_approval_mode.py
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
"""Tests for approvals.cron_mode — configurable approval behavior for cron jobs."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
import tools.approval as approval_module
|
||||
from tools.approval import (
|
||||
_get_cron_approval_mode,
|
||||
check_all_command_guards,
|
||||
check_dangerous_command,
|
||||
detect_dangerous_command,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _clear_approval_state():
    """Reset approval state before and after each test.

    Empties the permanent-approval collection and clears the ``default``
    and ``test-session`` sessions.
    """
    def _reset():
        approval_module._permanent_approved.clear()
        for session_key in ("default", "test-session"):
            approval_module.clear_session(session_key)

    _reset()
    yield
    _reset()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_cron_approval_mode() config parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCronApprovalModeParsing:
    """How _get_cron_approval_mode() interprets approvals.cron_mode."""

    @staticmethod
    def _mode_for(approvals):
        """Return the resolved mode when the loaded config's ``approvals`` is *approvals*."""
        from unittest.mock import patch as mock_patch

        config = {"approvals": approvals}
        with mock_patch("hermes_cli.config.load_config", return_value=config):
            return _get_cron_approval_mode()

    def test_default_is_deny(self):
        """When no config is set, cron_mode defaults to 'deny'."""
        assert self._mode_for({}) == "deny"

    def test_explicit_deny(self):
        assert self._mode_for({"cron_mode": "deny"}) == "deny"

    def test_explicit_approve(self):
        assert self._mode_for({"cron_mode": "approve"}) == "approve"

    def test_off_maps_to_approve(self):
        """'off' is an alias for 'approve' (matches --yolo semantics)."""
        assert self._mode_for({"cron_mode": "off"}) == "approve"

    def test_allow_maps_to_approve(self):
        assert self._mode_for({"cron_mode": "allow"}) == "approve"

    def test_yes_maps_to_approve(self):
        assert self._mode_for({"cron_mode": "yes"}) == "approve"

    def test_case_insensitive(self):
        assert self._mode_for({"cron_mode": "APPROVE"}) == "approve"

    def test_unknown_value_defaults_to_deny(self):
        assert self._mode_for({"cron_mode": "maybe"}) == "deny"

    def test_config_load_failure_defaults_to_deny(self):
        """If config loading fails entirely, default to deny (safe)."""
        from unittest.mock import patch as mock_patch

        failure = RuntimeError("config broken")
        with mock_patch("hermes_cli.config.load_config", side_effect=failure):
            assert _get_cron_approval_mode() == "deny"

    def test_yaml_boolean_false_maps_to_deny(self):
        """YAML 1.1 parses bare 'off' as False. Ensure it maps to deny."""
        # A boolean False is not one of the approve spellings, so deny is expected.
        assert self._mode_for({"cron_mode": False}) == "deny"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_dangerous_command() with cron session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCronDenyMode:
    """When HERMES_CRON_SESSION is set and cron_mode=deny, dangerous commands are blocked."""

    @staticmethod
    def _enter_cron_session(monkeypatch):
        """Mark the process as a cron session with no other approval context.

        Sets HERMES_CRON_SESSION and removes the interactive, gateway and
        yolo variables so only the cron path is exercised.
        """
        monkeypatch.setenv("HERMES_CRON_SESSION", "1")
        for name in ("HERMES_INTERACTIVE", "HERMES_GATEWAY_SESSION", "HERMES_YOLO_MODE"):
            monkeypatch.delenv(name, raising=False)

    @staticmethod
    def _check_in_deny_mode(command):
        """Run ``check_dangerous_command`` on the local backend with cron_mode pinned to deny."""
        from unittest.mock import patch as mock_patch

        with mock_patch("tools.approval._get_cron_approval_mode", return_value="deny"):
            return check_dangerous_command(command, "local")

    def test_dangerous_command_blocked_in_cron_deny_mode(self, monkeypatch):
        self._enter_cron_session(monkeypatch)
        result = self._check_in_deny_mode("rm -rf /tmp/stuff")
        assert not result["approved"]
        assert "BLOCKED" in result["message"]
        assert "cron_mode" in result["message"]

    def test_safe_command_allowed_in_cron_deny_mode(self, monkeypatch):
        """Non-dangerous commands still work even with cron_mode=deny."""
        self._enter_cron_session(monkeypatch)
        result = self._check_in_deny_mode("ls -la")
        assert result["approved"]

    def test_multiple_dangerous_patterns_blocked(self, monkeypatch):
        """All dangerous patterns are blocked, not just rm."""
        self._enter_cron_session(monkeypatch)

        dangerous_commands = [
            "rm -rf /",
            "chmod 777 /etc/passwd",
            "mkfs.ext4 /dev/sda1",
            "dd if=/dev/zero of=/dev/sda",
        ]

        from unittest.mock import patch as mock_patch

        checked = 0
        with mock_patch("tools.approval._get_cron_approval_mode", return_value="deny"):
            for cmd in dangerous_commands:
                is_dangerous, _, _ = detect_dangerous_command(cmd)
                if not is_dangerous:
                    # Commands the detector does not flag are outside this test's scope.
                    continue
                checked += 1
                result = check_dangerous_command(cmd, "local")
                assert not result["approved"], f"Should be blocked: {cmd}"
                assert "BLOCKED" in result["message"]

        # Without this guard the loop passes vacuously when the detector
        # flags none of the sample commands.
        assert checked > 0, "no sample command was detected as dangerous"

    def test_block_message_includes_description(self, monkeypatch):
        """The block message should mention what pattern was matched."""
        self._enter_cron_session(monkeypatch)
        result = self._check_in_deny_mode("rm -rf /tmp/stuff")
        assert not result["approved"]
        # Should contain the description of what was flagged
        message = result["message"].lower()
        assert "dangerous" in message or "delete" in message
|
||||
|
||||
|
||||
class TestCronApproveMode:
    """When HERMES_CRON_SESSION is set and cron_mode=approve, dangerous commands pass through."""

    def test_dangerous_command_allowed_in_cron_approve_mode(self, monkeypatch):
        from unittest.mock import patch as mock_patch

        monkeypatch.setenv("HERMES_CRON_SESSION", "1")
        for name in ("HERMES_INTERACTIVE", "HERMES_GATEWAY_SESSION", "HERMES_YOLO_MODE"):
            monkeypatch.delenv(name, raising=False)

        approve = mock_patch("tools.approval._get_cron_approval_mode", return_value="approve")
        with approve:
            result = check_dangerous_command("rm -rf /tmp/stuff", "local")
        assert result["approved"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_all_command_guards() with cron session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCronDenyModeAllGuards:
    """The combined guard function also respects cron_mode."""

    @staticmethod
    def _guard_result(monkeypatch, command, cron_mode):
        """Run ``check_all_command_guards`` for *command* in a cron session.

        The interactive, gateway, exec-ask and yolo variables are removed
        first, and ``_get_cron_approval_mode`` is patched to return *cron_mode*.
        """
        from unittest.mock import patch as mock_patch

        monkeypatch.setenv("HERMES_CRON_SESSION", "1")
        for name in ("HERMES_INTERACTIVE", "HERMES_GATEWAY_SESSION",
                     "HERMES_EXEC_ASK", "HERMES_YOLO_MODE"):
            monkeypatch.delenv(name, raising=False)

        with mock_patch("tools.approval._get_cron_approval_mode", return_value=cron_mode):
            return check_all_command_guards(command, "local")

    def test_dangerous_command_blocked_in_combined_guard(self, monkeypatch):
        result = self._guard_result(monkeypatch, "rm -rf /tmp/stuff", "deny")
        assert not result["approved"]
        assert "BLOCKED" in result["message"]

    def test_safe_command_allowed_in_combined_guard(self, monkeypatch):
        result = self._guard_result(monkeypatch, "echo hello", "deny")
        assert result["approved"]

    def test_combined_guard_approve_mode(self, monkeypatch):
        result = self._guard_result(monkeypatch, "rm -rf /tmp/stuff", "approve")
        assert result["approved"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases: cron mode interaction with other approval mechanisms
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCronModeInteractions:
    """Cron mode should NOT interfere with other approval bypass mechanisms."""

    @staticmethod
    def _unset(monkeypatch, *names):
        """Remove each named environment variable if it is present."""
        for name in names:
            monkeypatch.delenv(name, raising=False)

    @staticmethod
    def _check_under_deny(command, backend):
        """Run ``check_dangerous_command`` with cron_mode pinned to deny."""
        from unittest.mock import patch as mock_patch

        with mock_patch("tools.approval._get_cron_approval_mode", return_value="deny"):
            return check_dangerous_command(command, backend)

    def test_container_env_still_auto_approves(self, monkeypatch):
        """Docker/sandbox environments bypass approvals regardless of cron_mode."""
        monkeypatch.setenv("HERMES_CRON_SESSION", "1")
        self._unset(monkeypatch, "HERMES_INTERACTIVE", "HERMES_GATEWAY_SESSION",
                    "HERMES_YOLO_MODE")

        result = self._check_under_deny("rm -rf /", "docker")
        assert result["approved"]

    def test_yolo_overrides_cron_deny(self, monkeypatch):
        """--yolo still works even if cron_mode=deny."""
        monkeypatch.setenv("HERMES_CRON_SESSION", "1")
        monkeypatch.setenv("HERMES_YOLO_MODE", "1")
        self._unset(monkeypatch, "HERMES_INTERACTIVE", "HERMES_GATEWAY_SESSION")

        result = self._check_under_deny("rm -rf /", "local")
        assert result["approved"]

    def test_non_cron_non_interactive_still_auto_approves(self, monkeypatch):
        """Non-cron, non-interactive sessions (e.g. scripted usage) still auto-approve."""
        self._unset(monkeypatch, "HERMES_CRON_SESSION", "HERMES_INTERACTIVE",
                    "HERMES_GATEWAY_SESSION", "HERMES_YOLO_MODE")

        result = check_dangerous_command("rm -rf /tmp/stuff", "local")
        assert result["approved"]
|
||||
|
|
@ -192,23 +192,23 @@ class TestUnifiedCronjobTool:
|
|||
result = json.loads(
|
||||
cronjob(
|
||||
action="create",
|
||||
skills=["blogwatcher", "find-nearby"],
|
||||
skills=["blogwatcher", "maps"],
|
||||
prompt="Use both skills and combine the result.",
|
||||
schedule="every 1h",
|
||||
name="Combo job",
|
||||
)
|
||||
)
|
||||
assert result["success"] is True
|
||||
assert result["skills"] == ["blogwatcher", "find-nearby"]
|
||||
assert result["skills"] == ["blogwatcher", "maps"]
|
||||
|
||||
listing = json.loads(cronjob(action="list"))
|
||||
assert listing["jobs"][0]["skills"] == ["blogwatcher", "find-nearby"]
|
||||
assert listing["jobs"][0]["skills"] == ["blogwatcher", "maps"]
|
||||
|
||||
def test_multi_skill_default_name_prefers_prompt_when_present(self):
|
||||
result = json.loads(
|
||||
cronjob(
|
||||
action="create",
|
||||
skills=["blogwatcher", "find-nearby"],
|
||||
skills=["blogwatcher", "maps"],
|
||||
prompt="Use both skills and combine the result.",
|
||||
schedule="every 1h",
|
||||
)
|
||||
|
|
@ -220,7 +220,7 @@ class TestUnifiedCronjobTool:
|
|||
created = json.loads(
|
||||
cronjob(
|
||||
action="create",
|
||||
skills=["blogwatcher", "find-nearby"],
|
||||
skills=["blogwatcher", "maps"],
|
||||
prompt="Use both skills and combine the result.",
|
||||
schedule="every 1h",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -274,6 +274,7 @@ class TestDelegateTask(unittest.TestCase):
|
|||
model=None,
|
||||
max_iterations=10,
|
||||
parent_agent=parent,
|
||||
task_count=1,
|
||||
)
|
||||
|
||||
self.assertIs(mock_child._print_fn, sink)
|
||||
|
|
@ -294,6 +295,7 @@ class TestDelegateTask(unittest.TestCase):
|
|||
model=None,
|
||||
max_iterations=10,
|
||||
parent_agent=parent,
|
||||
task_count=1,
|
||||
)
|
||||
|
||||
self.assertTrue(callable(mock_child.thinking_callback))
|
||||
|
|
@ -363,6 +365,7 @@ class TestToolNamePreservation(unittest.TestCase):
|
|||
model=None,
|
||||
max_iterations=10,
|
||||
parent_agent=parent,
|
||||
task_count=1,
|
||||
)
|
||||
except NameError as exc:
|
||||
self.fail(
|
||||
|
|
@ -1000,6 +1003,7 @@ class TestChildCredentialPoolResolution(unittest.TestCase):
|
|||
model=None,
|
||||
max_iterations=10,
|
||||
parent_agent=parent,
|
||||
task_count=1,
|
||||
)
|
||||
|
||||
self.assertEqual(mock_child._credential_pool, mock_pool)
|
||||
|
|
@ -1225,6 +1229,7 @@ class TestDelegationReasoningEffort(unittest.TestCase):
|
|||
_build_child_agent(
|
||||
task_index=0, goal="test", context=None, toolsets=None,
|
||||
model=None, max_iterations=50, parent_agent=parent,
|
||||
task_count=1,
|
||||
)
|
||||
call_kwargs = MockAgent.call_args[1]
|
||||
self.assertEqual(call_kwargs["reasoning_config"], {"enabled": True, "effort": "xhigh"})
|
||||
|
|
@ -1241,6 +1246,7 @@ class TestDelegationReasoningEffort(unittest.TestCase):
|
|||
_build_child_agent(
|
||||
task_index=0, goal="test", context=None, toolsets=None,
|
||||
model=None, max_iterations=50, parent_agent=parent,
|
||||
task_count=1,
|
||||
)
|
||||
call_kwargs = MockAgent.call_args[1]
|
||||
self.assertEqual(call_kwargs["reasoning_config"], {"enabled": True, "effort": "low"})
|
||||
|
|
@ -1257,6 +1263,7 @@ class TestDelegationReasoningEffort(unittest.TestCase):
|
|||
_build_child_agent(
|
||||
task_index=0, goal="test", context=None, toolsets=None,
|
||||
model=None, max_iterations=50, parent_agent=parent,
|
||||
task_count=1,
|
||||
)
|
||||
call_kwargs = MockAgent.call_args[1]
|
||||
self.assertEqual(call_kwargs["reasoning_config"], {"enabled": False})
|
||||
|
|
@ -1273,6 +1280,7 @@ class TestDelegationReasoningEffort(unittest.TestCase):
|
|||
_build_child_agent(
|
||||
task_index=0, goal="test", context=None, toolsets=None,
|
||||
model=None, max_iterations=50, parent_agent=parent,
|
||||
task_count=1,
|
||||
)
|
||||
call_kwargs = MockAgent.call_args[1]
|
||||
self.assertEqual(call_kwargs["reasoning_config"], {"enabled": True, "effort": "medium"})
|
||||
|
|
|
|||
145
tests/tools/test_local_interrupt_cleanup.py
Normal file
145
tests/tools/test_local_interrupt_cleanup.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
"""Regression tests for _wait_for_process subprocess cleanup on exception exit.
|
||||
|
||||
When the poll loop exits via KeyboardInterrupt or SystemExit (SIGTERM via
|
||||
cli.py signal handler, SIGINT on the main thread in non-interactive -q mode,
|
||||
or explicit sys.exit from some caller), the child subprocess must be killed
|
||||
before the exception propagates — otherwise the local backend's use of
|
||||
os.setsid leaves an orphan with PPID=1.
|
||||
|
||||
The live repro that motivated this: hermes chat -q ... 'sleep 300', SIGTERM
|
||||
to the python process, sleep 300 survived with PPID=1 for the full 300 s
|
||||
because _wait_for_process never got to call _kill_process before python
|
||||
died. See commit message for full context.
|
||||
"""
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments.local import LocalEnvironment
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _isolate_hermes_home(tmp_path, monkeypatch):
    """Redirect HERMES_HOME to the per-test temp dir and pre-create ``logs``."""
    logs_dir = tmp_path / "logs"
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    logs_dir.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
def _pgid_still_alive(pgid: int) -> bool:
    """Return True if any process in the given process group is still alive.

    Sends signal 0 to the group, which performs the existence and permission
    checks without delivering a signal.

    Args:
        pgid: Process-group ID to probe.

    Returns:
        False when the OS reports that no such group exists, True otherwise.
    """
    try:
        os.killpg(pgid, 0)  # signal 0 = existence check
    except ProcessLookupError:
        return False
    except PermissionError:
        # EPERM: the group exists but we may not signal it — still alive.
        return True
    return True
|
||||
|
||||
|
||||
def _find_pid_by_cmdline(needle: str):
    """Return the PID of a process whose command line contains *needle*, or None.

    Prefers psutil, restricted to descendants of the current process.  When
    psutil is not installed, falls back to scanning ``ps`` output, which
    covers every process ``ps`` lists, not only our descendants.
    """
    try:
        import psutil  # optional — fall back if absent
    except ImportError:
        psutil = None

    if psutil is not None:
        for child in psutil.Process(os.getpid()).children(recursive=True):
            try:
                if needle in " ".join(child.cmdline()):
                    return child.pid
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue
        return None

    # Fall back to ps
    listing = subprocess.run(
        ["ps", "-eo", "pid,ppid,pgid,cmd"], capture_output=True, text=True,
    )
    for line in listing.stdout.splitlines():
        if needle in line and "grep" not in line:
            fields = line.split()
            if fields and fields[0].isdigit():
                return int(fields[0])
    return None


def test_wait_for_process_kills_subprocess_on_keyboardinterrupt():
    """When KeyboardInterrupt arrives mid-poll, the subprocess group must be
    killed before the exception is re-raised."""
    import ctypes

    env = LocalEnvironment(cwd="/tmp")
    try:
        result_holder = {}

        # Drive execute() on a separate thread so we can interrupt it via a
        # thread-targeted exception without killing our test process.
        def worker():
            # Spawn a subprocess that will definitely be alive long enough
            # to observe the cleanup, via env.execute(...) — the normal path
            # that goes through _wait_for_process.
            try:
                result_holder["result"] = env.execute("sleep 30", timeout=60)
            except BaseException as e:  # noqa: BLE001 — we want to observe it
                result_holder["exception"] = type(e).__name__

        t = threading.Thread(target=worker, daemon=True)
        t.start()

        # Wait until the subprocess actually exists.  LocalEnvironment.execute
        # does init_session() (one spawn) before the real command, so we need
        # to wait until a 'sleep 30' is visible.
        deadline = time.monotonic() + 5.0
        target_pid = None
        while time.monotonic() < deadline:
            target_pid = _find_pid_by_cmdline("sleep 30")
            if target_pid:
                break
            time.sleep(0.1)

        assert target_pid is not None, (
            "test setup: couldn't find 'sleep 30' subprocess after 5 s"
        )
        pgid = os.getpgid(target_pid)
        assert _pgid_still_alive(pgid), "sanity: subprocess should be alive"

        # Inject a KeyboardInterrupt into the worker thread with
        # ctypes.PyThreadState_SetAsyncExc, which takes the thread ident
        # (not the Thread object).
        tid = t.ident
        assert tid is not None
        ret = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_ulong(tid), ctypes.py_object(KeyboardInterrupt),
        )
        assert ret == 1, f"SetAsyncExc returned {ret}, expected 1"

        # Give the worker a moment to: hit the exception at the next poll,
        # run the except-block cleanup (_kill_process), and exit.
        t.join(timeout=5.0)
        assert not t.is_alive(), "worker didn't exit within 5 s of the interrupt"

        # The critical assertion: the subprocess GROUP must be dead.  Not
        # just the bash wrapper — the 'sleep 30' child too.
        # Give the SIGTERM+1s wait+SIGKILL escalation a moment to complete.
        deadline = time.monotonic() + 3.0
        while time.monotonic() < deadline:
            if not _pgid_still_alive(pgid):
                break
            time.sleep(0.1)
        assert not _pgid_still_alive(pgid), (
            f"subprocess group {pgid} is STILL ALIVE after worker received "
            "KeyboardInterrupt — orphan bug regressed. This is the "
            "sleep-300-survives-SIGTERM scenario from Physikal's Apr 2026 "
            "report. See tools/environments/base.py _wait_for_process "
            "except-block."
        )
        # And the worker should have observed the KeyboardInterrupt (i.e.
        # it re-raised cleanly, not silently swallowed).
        assert result_holder.get("exception") == "KeyboardInterrupt", (
            f"worker result: {result_holder!r} — expected KeyboardInterrupt "
            "propagation after cleanup"
        )
    finally:
        # Best-effort teardown: a cleanup failure must not mask the real
        # assertion outcome above.
        try:
            env.cleanup()
        except Exception:
            pass
|
||||
|
|
@ -296,6 +296,8 @@ class TestBuiltinDiscovery:
|
|||
"tools.code_execution_tool",
|
||||
"tools.cronjob_tools",
|
||||
"tools.delegate_tool",
|
||||
"tools.feishu_doc_tool",
|
||||
"tools.feishu_drive_tool",
|
||||
"tools.file_tools",
|
||||
"tools.homeassistant_tool",
|
||||
"tools.image_generation_tool",
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import io
|
|||
import json
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
|
@ -120,7 +121,9 @@ def test_block_and_respond(capture):
|
|||
|
||||
rid = next(iter(server._pending))
|
||||
server._answers[rid] = "my_answer"
|
||||
server._pending[rid].set()
|
||||
# _pending values are (sid, Event) tuples — unpack to set the Event
|
||||
_, ev = server._pending[rid]
|
||||
ev.set()
|
||||
|
||||
threading.Event().wait(0.1)
|
||||
assert result[0] == "my_answer"
|
||||
|
|
@ -128,7 +131,8 @@ def test_block_and_respond(capture):
|
|||
|
||||
def test_clear_pending(server):
    """The Event of a pending request is set once _clear_pending() has run."""
    ev = threading.Event()
    # _pending values are (sid, Event) tuples
    server._pending["r1"] = ("sid-x", ev)
    server._clear_pending()

    assert ev.is_set()
|
||||
|
|
@ -231,3 +235,279 @@ def test_cli_exec_blocked(server, argv):
|
|||
])
|
||||
def test_cli_exec_allowed(server, argv):
|
||||
assert server._cli_exec_blocked(argv) is None
|
||||
|
||||
|
||||
# ── slash.exec skill command interception ────────────────────────────
|
||||
|
||||
|
||||
def test_slash_exec_rejects_skill_commands(server):
    """slash.exec answers a known skill command with error 4018.

    The error is what lets the TUI fall through to command.dispatch.
    """
    session_id = "test-session"
    # Register a mock session
    server._sessions[session_id] = {"session_key": session_id, "agent": None}

    # Make get_skill_commands report exactly one known skill
    known_skills = {
        "/hermes-agent-dev": {"name": "hermes-agent-dev", "description": "Dev workflow"},
    }
    request = {
        "id": "r1",
        "method": "slash.exec",
        "params": {"command": "hermes-agent-dev", "session_id": session_id},
    }

    with patch("agent.skill_commands.get_skill_commands", return_value=known_skills):
        resp = server.handle_request(request)

    # Should return an error so the TUI's .catch() fires command.dispatch
    assert "error" in resp
    failure = resp["error"]
    assert failure["code"] == 4018
    assert "skill command" in failure["message"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cmd", ["retry", "queue hello", "q hello", "steer fix the test", "plan"])
def test_slash_exec_rejects_pending_input_commands(server, cmd):
    """slash.exec refuses each command that relies on _pending_input in the CLI."""
    session_id = "test-session"
    server._sessions[session_id] = {"session_key": session_id, "agent": None}

    request = {
        "id": "r1",
        "method": "slash.exec",
        "params": {"command": cmd, "session_id": session_id},
    }
    resp = server.handle_request(request)

    assert "error" in resp
    failure = resp["error"]
    assert failure["code"] == 4018
    assert "pending-input command" in failure["message"]
|
||||
|
||||
|
||||
def test_command_dispatch_queue_sends_message(server):
    """command.dispatch /queue hands the argument back as a 'send' payload."""
    session_id = "test-session"
    server._sessions[session_id] = {"session_key": session_id}

    params = {
        "name": "queue",
        "arg": "tell me about quantum computing",
        "session_id": session_id,
    }
    resp = server.handle_request(
        {"id": "r1", "method": "command.dispatch", "params": params}
    )

    assert "error" not in resp
    payload = resp["result"]
    assert payload["type"] == "send"
    assert payload["message"] == "tell me about quantum computing"
|
||||
|
||||
|
||||
def test_command_dispatch_queue_requires_arg(server):
    """command.dispatch /queue with an empty argument yields error 4004."""
    session_id = "test-session"
    server._sessions[session_id] = {"session_key": session_id}

    params = {"name": "queue", "arg": "", "session_id": session_id}
    resp = server.handle_request(
        {"id": "r2", "method": "command.dispatch", "params": params}
    )

    assert "error" in resp
    assert resp["error"]["code"] == 4004
|
||||
|
||||
|
||||
def test_command_dispatch_steer_fallback_sends_message(server):
    """command.dispatch /steer on a session whose agent is None replies with 'send'."""
    session_id = "test-session"
    server._sessions[session_id] = {"session_key": session_id, "agent": None}

    params = {"name": "steer", "arg": "focus on testing", "session_id": session_id}
    resp = server.handle_request(
        {"id": "r3", "method": "command.dispatch", "params": params}
    )

    assert "error" not in resp
    payload = resp["result"]
    assert payload["type"] == "send"
    assert payload["message"] == "focus on testing"
|
||||
|
||||
|
||||
def test_command_dispatch_retry_finds_last_user_message(server):
    """command.dispatch /retry resends the most recent user message.

    It also checks that history is cut back to the two entries before that
    message and that history_version goes from 0 to 1.
    """
    session_id = "test-session"
    turns = [
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": "first answer"},
        {"role": "user", "content": "second question"},
        {"role": "assistant", "content": "second answer"},
    ]
    server._sessions[session_id] = {
        "session_key": session_id,
        "agent": None,
        "history": turns,
        "history_lock": threading.Lock(),
        "history_version": 0,
    }

    request = {
        "id": "r4",
        "method": "command.dispatch",
        "params": {"name": "retry", "session_id": session_id},
    }
    resp = server.handle_request(request)

    assert "error" not in resp
    payload = resp["result"]
    assert payload["type"] == "send"
    assert payload["message"] == "second question"

    # Verify history was truncated: everything from last user message onward removed
    session_after = server._sessions[session_id]
    assert len(session_after["history"]) == 2
    assert session_after["history"][-1]["role"] == "assistant"
    assert session_after["history_version"] == 1
|
||||
|
||||
|
||||
def test_command_dispatch_retry_empty_history(server):
    """command.dispatch /retry on a session with no history yields error 4018."""
    session_id = "test-session"
    server._sessions[session_id] = {
        "session_key": session_id,
        "agent": None,
        "history": [],
        "history_lock": threading.Lock(),
        "history_version": 0,
    }

    request = {
        "id": "r5",
        "method": "command.dispatch",
        "params": {"name": "retry", "session_id": session_id},
    }
    resp = server.handle_request(request)

    assert "error" in resp
    assert resp["error"]["code"] == 4018
|
||||
|
||||
|
||||
def test_command_dispatch_retry_handles_multipart_content(server):
    """command.dispatch /retry returns the text part of a multipart user message."""
    session_id = "test-session"
    multipart_user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": "analyze this"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        ],
    }
    assistant_turn = {"role": "assistant", "content": "I see the image."}
    server._sessions[session_id] = {
        "session_key": session_id,
        "agent": None,
        "history": [multipart_user_turn, assistant_turn],
        "history_lock": threading.Lock(),
        "history_version": 0,
    }

    request = {
        "id": "r6",
        "method": "command.dispatch",
        "params": {"name": "retry", "session_id": session_id},
    }
    resp = server.handle_request(request)

    assert "error" not in resp
    payload = resp["result"]
    assert payload["type"] == "send"
    assert payload["message"] == "analyze this"
|
||||
|
||||
|
||||
def test_command_dispatch_returns_skill_payload(server):
    """command.dispatch on a skill name returns a 'skill' payload with name and message."""
    session_id = "test-session"
    server._sessions[session_id] = {"session_key": session_id}

    known_skills = {
        "/hermes-agent-dev": {"name": "hermes-agent-dev", "description": "Dev workflow"},
    }
    skill_text = "Loaded skill content here"
    request = {
        "id": "r2",
        "method": "command.dispatch",
        "params": {"name": "hermes-agent-dev", "session_id": session_id},
    }

    # Both skill helpers are stubbed so only the dispatch wiring is exercised.
    with patch("agent.skill_commands.scan_skill_commands", return_value=known_skills):
        with patch(
            "agent.skill_commands.build_skill_invocation_message",
            return_value=skill_text,
        ):
            resp = server.handle_request(request)

    assert "error" not in resp
    payload = resp["result"]
    assert payload["type"] == "skill"
    assert payload["message"] == skill_text
    assert payload["name"] == "hermes-agent-dev"
|
||||
|
||||
|
||||
# ── dispatch(): pool routing for long handlers (#12546) ──────────────
|
||||
|
||||
|
||||
def test_dispatch_runs_short_handlers_inline(server):
    """dispatch() returns the response of a non-long handler directly."""

    def ping(rid, params):
        return server._ok(rid, {"pong": True})

    server._methods["fast.ping"] = ping

    request = {"id": "r1", "method": "fast.ping", "params": {}}
    resp = server.dispatch(request)

    assert resp == {"jsonrpc": "2.0", "id": "r1", "result": {"pong": True}}
|
||||
|
||||
|
||||
def test_dispatch_offloads_long_handlers_and_emits_via_stdout(capture):
    """dispatch() returns None for slash.exec; the response appears in the captured buffer."""
    server, buf = capture

    def say_hi(rid, params):
        return server._ok(rid, {"output": "hi"})

    server._methods["slash.exec"] = say_hi

    resp = server.dispatch({"id": "r2", "method": "slash.exec", "params": {}})
    assert resp is None

    # Poll up to 50 times, 10 ms apart, until something has been written.
    polls_left = 50
    while polls_left and not buf.getvalue():
        time.sleep(0.01)
        polls_left -= 1

    written = json.loads(buf.getvalue())
    assert written == {"jsonrpc": "2.0", "id": "r2", "result": {"output": "hi"}}
|
||||
|
||||
|
||||
def test_dispatch_long_handler_does_not_block_fast_handler(server):
    """A slow long handler must not prevent a concurrent fast handler from completing.

    The slash.exec handler blocks on an Event (at most 5 s).  While it is
    blocked, a fast.ping dispatch must return its result in under 0.5 s.
    """
    released = threading.Event()

    def slow_exec(rid, params):
        released.wait(timeout=5)
        return server._ok(rid, {"done": True})

    def ping(rid, params):
        return server._ok(rid, {"pong": True})

    server._methods["slash.exec"] = slow_exec
    server._methods["fast.ping"] = ping

    try:
        t0 = time.monotonic()
        assert server.dispatch({"id": "slow", "method": "slash.exec", "params": {}}) is None

        fast_resp = server.dispatch({"id": "fast", "method": "fast.ping", "params": {}})
        fast_elapsed = time.monotonic() - t0

        assert fast_resp["result"] == {"pong": True}
        assert fast_elapsed < 0.5, f"fast handler blocked for {fast_elapsed:.2f}s behind slow handler"
    finally:
        # Always unblock the slow handler, even when an assertion failed, so
        # it does not sit in wait() for the full 5 s timeout.
        released.set()
|
||||
|
||||
|
||||
def test_dispatch_long_handler_exception_produces_error_response(capture):
    """A slash.exec handler that raises still results in a written JSON-RPC error."""
    server, buf = capture

    def exploding_handler(rid, params):
        raise RuntimeError("kaboom")

    server._methods["slash.exec"] = exploding_handler

    server.dispatch({"id": "r3", "method": "slash.exec", "params": {}})

    # Poll up to 50 times, 10 ms apart, until something has been written.
    polls_left = 50
    while polls_left and not buf.getvalue():
        time.sleep(0.01)
        polls_left -= 1

    written = json.loads(buf.getvalue())
    assert written["id"] == "r3"
    failure = written["error"]
    assert failure["code"] == -32000
    assert "kaboom" in failure["message"]
|
||||
|
||||
|
||||
def test_dispatch_unknown_long_method_still_goes_inline(server):
    """dispatch() returns the result directly for a method outside _LONG_HANDLERS."""

    def acknowledge(rid, params):
        return server._ok(rid, {"ok": True})

    server._methods["some.method"] = acknowledge

    request = {"id": "r4", "method": "some.method", "params": {}}
    resp = server.dispatch(request)

    assert resp["result"] == {"ok": True}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue