mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-08 08:11:38 +00:00
Merge branch 'main' into fix/show-reasoning-per-platform
This commit is contained in:
commit
ed7b42f889
401 changed files with 66696 additions and 1966 deletions
|
|
@ -42,9 +42,10 @@ class TestToolProgressCallback:
|
|||
def test_emits_tool_call_start(self, mock_conn, event_loop_fixture):
|
||||
"""Tool progress should emit a ToolCallStart update."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
# Run callback in the event loop context
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
|
|
@ -66,9 +67,10 @@ class TestToolProgressCallback:
|
|||
def test_handles_string_args(self, mock_conn, event_loop_fixture):
|
||||
"""If args is a JSON string, it should be parsed."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -82,9 +84,10 @@ class TestToolProgressCallback:
|
|||
def test_handles_non_dict_args(self, mock_conn, event_loop_fixture):
|
||||
"""If args is not a dict, it should be wrapped."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -98,10 +101,11 @@ class TestToolProgressCallback:
|
|||
def test_duplicate_same_name_tool_calls_use_fifo_ids(self, mock_conn, event_loop_fixture):
|
||||
"""Multiple same-name tool calls should be tracked independently in order."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -163,7 +167,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {"terminal": "tc-abc123"}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -181,7 +185,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
cb(1, [{"name": "unknown_tool", "result": "ok"}])
|
||||
|
|
@ -193,7 +197,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {"read_file": "tc-def456"}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -212,7 +216,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {"terminal": deque(["tc-xyz789"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
|
|
@ -224,7 +228,7 @@ class TestStepCallback:
|
|||
cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}])
|
||||
|
||||
mock_btc.assert_called_once_with(
|
||||
"tc-xyz789", "terminal", result='{"output": "hello"}'
|
||||
"tc-xyz789", "terminal", result='{"output": "hello"}', function_args=None, snapshot=None
|
||||
)
|
||||
|
||||
def test_none_result_passed_through(self, mock_conn, event_loop_fixture):
|
||||
|
|
@ -234,7 +238,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {"web_search": deque(["tc-aaa"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
|
|
@ -244,7 +248,50 @@ class TestStepCallback:
|
|||
|
||||
cb(1, [{"name": "web_search", "result": None}])
|
||||
|
||||
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None)
|
||||
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None, function_args=None, snapshot=None)
|
||||
|
||||
def test_step_callback_passes_arguments_and_snapshot(self, mock_conn, event_loop_fixture):
|
||||
from collections import deque
|
||||
|
||||
tool_call_ids = {"write_file": deque(["tc-write"])}
|
||||
tool_call_meta = {"tc-write": {"args": {"path": "fallback.txt"}, "snapshot": "snap"}}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb(1, [{"name": "write_file", "result": '{"bytes_written": 23}', "arguments": {"path": "diff-test.txt"}}])
|
||||
|
||||
mock_btc.assert_called_once_with(
|
||||
"tc-write",
|
||||
"write_file",
|
||||
result='{"bytes_written": 23}',
|
||||
function_args={"path": "diff-test.txt"},
|
||||
snapshot="snap",
|
||||
)
|
||||
|
||||
def test_tool_progress_captures_snapshot_metadata(self, mock_conn, event_loop_fixture):
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
with patch("acp_adapter.events.make_tool_call_id", return_value="tc-meta"), \
|
||||
patch("acp_adapter.events._send_update") as mock_send, \
|
||||
patch("agent.display.capture_local_edit_snapshot", return_value="snapshot"):
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
cb("tool.started", "write_file", None, {"path": "diff-test.txt", "content": "hello"})
|
||||
|
||||
assert list(tool_call_ids["write_file"]) == ["tc-meta"]
|
||||
assert tool_call_meta["tc-meta"] == {
|
||||
"args": {"path": "diff-test.txt", "content": "hello"},
|
||||
"snapshot": "snapshot",
|
||||
}
|
||||
mock_send.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from acp.schema import (
|
|||
|
||||
from acp_adapter.server import HermesACPAgent
|
||||
from acp_adapter.session import SessionManager
|
||||
from acp_adapter.tools import build_tool_start
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -181,6 +182,25 @@ class TestMcpRegistrationE2E:
|
|||
assert complete_event.raw_output is not None
|
||||
assert "hello" in str(complete_event.raw_output)
|
||||
|
||||
def test_patch_mode_tool_start_emits_diff_blocks_for_v4a_patch(self):
|
||||
update = build_tool_start(
|
||||
"tc-1",
|
||||
"patch",
|
||||
{
|
||||
"mode": "patch",
|
||||
"patch": "*** Begin Patch\n*** Update File: src/app.py\n@@\n-old line\n+new line\n*** Add File: src/new.py\n+hello\n*** End Patch",
|
||||
},
|
||||
)
|
||||
|
||||
assert len(update.content) == 2
|
||||
assert update.content[0].type == "diff"
|
||||
assert update.content[0].path == "src/app.py"
|
||||
assert update.content[0].old_text == "old line"
|
||||
assert update.content[0].new_text == "new line"
|
||||
assert update.content[1].type == "diff"
|
||||
assert update.content[1].path == "src/new.py"
|
||||
assert update.content[1].new_text == "hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_tool_results_paired_by_call_id(self, acp_agent, mock_manager):
|
||||
"""The ToolCallUpdate's toolCallId must match the ToolCallStart's."""
|
||||
|
|
|
|||
|
|
@ -20,7 +20,9 @@ from acp.schema import (
|
|||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
SessionModelState,
|
||||
SetSessionConfigOptionResponse,
|
||||
SetSessionModelResponse,
|
||||
SetSessionModeResponse,
|
||||
SessionInfo,
|
||||
TextContentBlock,
|
||||
|
|
@ -127,6 +129,25 @@ class TestSessionOps:
|
|||
assert state is not None
|
||||
assert state.cwd == "/home/user/project"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_session_returns_model_state(self):
|
||||
manager = SessionManager(
|
||||
agent_factory=lambda: SimpleNamespace(model="gpt-5.4", provider="openai-codex")
|
||||
)
|
||||
acp_agent = HermesACPAgent(session_manager=manager)
|
||||
|
||||
with patch(
|
||||
"hermes_cli.models.curated_models_for_provider",
|
||||
return_value=[("gpt-5.4", "recommended"), ("gpt-5.4-mini", "")],
|
||||
):
|
||||
resp = await acp_agent.new_session(cwd="/tmp")
|
||||
|
||||
assert isinstance(resp.models, SessionModelState)
|
||||
assert resp.models.current_model_id == "openai-codex:gpt-5.4"
|
||||
assert resp.models.available_models[0].model_id == "openai-codex:gpt-5.4"
|
||||
assert resp.models.available_models[0].description is not None
|
||||
assert "Provider:" in resp.models.available_models[0].description
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_available_commands_include_help(self, agent):
|
||||
help_cmd = next(
|
||||
|
|
@ -204,6 +225,33 @@ class TestListAndFork:
|
|||
assert fork_resp.session_id
|
||||
assert fork_resp.session_id != new_resp.session_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_includes_title_and_updated_at(self, agent):
|
||||
with patch.object(
|
||||
agent.session_manager,
|
||||
"list_sessions",
|
||||
return_value=[
|
||||
{
|
||||
"session_id": "session-1",
|
||||
"cwd": "/tmp/project",
|
||||
"title": "Fix Zed session history",
|
||||
"updated_at": 123.0,
|
||||
}
|
||||
],
|
||||
):
|
||||
resp = await agent.list_sessions(cwd="/tmp/project")
|
||||
|
||||
assert isinstance(resp.sessions[0], SessionInfo)
|
||||
assert resp.sessions[0].title == "Fix Zed session history"
|
||||
assert resp.sessions[0].updated_at == "123.0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_passes_cwd_filter(self, agent):
|
||||
with patch.object(agent.session_manager, "list_sessions", return_value=[]) as mock_list:
|
||||
await agent.list_sessions(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||
|
||||
mock_list.assert_called_once_with(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session configuration / model routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -257,6 +305,53 @@ class TestSessionConfiguration:
|
|||
assert result == {}
|
||||
assert state.model == "gpt-5.4"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_session_model_accepts_provider_prefixed_choice(self, tmp_path, monkeypatch):
|
||||
runtime_calls = []
|
||||
|
||||
def fake_resolve_runtime_provider(requested=None, **kwargs):
|
||||
runtime_calls.append(requested)
|
||||
provider = requested or "openrouter"
|
||||
return {
|
||||
"provider": provider,
|
||||
"api_mode": "anthropic_messages" if provider == "anthropic" else "chat_completions",
|
||||
"base_url": f"https://{provider}.example/v1",
|
||||
"api_key": f"{provider}-key",
|
||||
"command": None,
|
||||
"args": [],
|
||||
}
|
||||
|
||||
def fake_agent(**kwargs):
|
||||
return SimpleNamespace(
|
||||
model=kwargs.get("model"),
|
||||
provider=kwargs.get("provider"),
|
||||
base_url=kwargs.get("base_url"),
|
||||
api_mode=kwargs.get("api_mode"),
|
||||
)
|
||||
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: {
|
||||
"model": {"provider": "openrouter", "default": "openrouter/gpt-5"}
|
||||
})
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
fake_resolve_runtime_provider,
|
||||
)
|
||||
manager = SessionManager(db=SessionDB(tmp_path / "state.db"))
|
||||
|
||||
with patch("run_agent.AIAgent", side_effect=fake_agent):
|
||||
acp_agent = HermesACPAgent(session_manager=manager)
|
||||
state = manager.create_session(cwd="/tmp")
|
||||
result = await acp_agent.set_session_model(
|
||||
model_id="anthropic:claude-sonnet-4-6",
|
||||
session_id=state.session_id,
|
||||
)
|
||||
|
||||
assert isinstance(result, SetSessionModelResponse)
|
||||
assert state.model == "claude-sonnet-4-6"
|
||||
assert state.agent.provider == "anthropic"
|
||||
assert state.agent.base_url == "https://anthropic.example/v1"
|
||||
assert runtime_calls[-1] == "anthropic"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prompt
|
||||
|
|
@ -354,6 +449,31 @@ class TestPrompt:
|
|||
update = last_call[1].get("update") or last_call[0][1]
|
||||
assert update.session_update == "agent_message_chunk"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_auto_titles_session(self, agent):
|
||||
new_resp = await agent.new_session(cwd=".")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.agent.run_conversation = MagicMock(return_value={
|
||||
"final_response": "Here is the fix.",
|
||||
"messages": [
|
||||
{"role": "user", "content": "fix the broken ACP history"},
|
||||
{"role": "assistant", "content": "Here is the fix."},
|
||||
],
|
||||
})
|
||||
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
with patch("agent.title_generator.maybe_auto_title") as mock_title:
|
||||
prompt = [TextContentBlock(type="text", text="fix the broken ACP history")]
|
||||
await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
mock_title.assert_called_once()
|
||||
assert mock_title.call_args.args[1] == new_resp.session_id
|
||||
assert mock_title.call_args.args[2] == "fix the broken ACP history"
|
||||
assert mock_title.call_args.args[3] == "Here is the fix."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_populates_usage_from_top_level_run_conversation_fields(self, agent):
|
||||
"""ACP should map top-level token fields into PromptResponse.usage."""
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import contextlib
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
|
@ -100,15 +101,23 @@ class TestListAndCleanup:
|
|||
def test_list_sessions_returns_created(self, manager):
|
||||
s1 = manager.create_session(cwd="/a")
|
||||
s2 = manager.create_session(cwd="/b")
|
||||
s1.history.append({"role": "user", "content": "hello from a"})
|
||||
s2.history.append({"role": "user", "content": "hello from b"})
|
||||
listing = manager.list_sessions()
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert s1.session_id in ids
|
||||
assert s2.session_id in ids
|
||||
assert len(listing) == 2
|
||||
|
||||
def test_list_sessions_hides_empty_threads(self, manager):
|
||||
manager.create_session(cwd="/empty")
|
||||
assert manager.list_sessions() == []
|
||||
|
||||
def test_cleanup_clears_all(self, manager):
|
||||
manager.create_session()
|
||||
manager.create_session()
|
||||
s1 = manager.create_session()
|
||||
s2 = manager.create_session()
|
||||
s1.history.append({"role": "user", "content": "one"})
|
||||
s2.history.append({"role": "user", "content": "two"})
|
||||
assert len(manager.list_sessions()) == 2
|
||||
manager.cleanup()
|
||||
assert manager.list_sessions() == []
|
||||
|
|
@ -194,6 +203,8 @@ class TestPersistence:
|
|||
def test_list_sessions_includes_db_only(self, manager):
|
||||
"""Sessions only in DB (not in memory) appear in list_sessions."""
|
||||
state = manager.create_session(cwd="/db-only")
|
||||
state.history.append({"role": "user", "content": "database only thread"})
|
||||
manager.save_session(state.session_id)
|
||||
sid = state.session_id
|
||||
|
||||
# Drop from memory.
|
||||
|
|
@ -204,6 +215,53 @@ class TestPersistence:
|
|||
ids = {s["session_id"] for s in listing}
|
||||
assert sid in ids
|
||||
|
||||
def test_list_sessions_filters_by_cwd(self, manager):
|
||||
keep = manager.create_session(cwd="/keep")
|
||||
drop = manager.create_session(cwd="/drop")
|
||||
keep.history.append({"role": "user", "content": "keep me"})
|
||||
drop.history.append({"role": "user", "content": "drop me"})
|
||||
|
||||
listing = manager.list_sessions(cwd="/keep")
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert keep.session_id in ids
|
||||
assert drop.session_id not in ids
|
||||
|
||||
def test_list_sessions_matches_windows_and_wsl_paths(self, manager):
|
||||
state = manager.create_session(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||
state.history.append({"role": "user", "content": "same project from WSL"})
|
||||
|
||||
listing = manager.list_sessions(cwd=r"E:\Projects\AI\browser-link-3")
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert state.session_id in ids
|
||||
|
||||
def test_list_sessions_prefers_title_then_preview(self, manager):
|
||||
state = manager.create_session(cwd="/named")
|
||||
state.history.append({"role": "user", "content": "Investigate broken ACP history in Zed"})
|
||||
manager.save_session(state.session_id)
|
||||
db = manager._get_db()
|
||||
db.set_session_title(state.session_id, "Fix Zed ACP history")
|
||||
|
||||
listing = manager.list_sessions(cwd="/named")
|
||||
assert listing[0]["title"] == "Fix Zed ACP history"
|
||||
|
||||
db.set_session_title(state.session_id, "")
|
||||
listing = manager.list_sessions(cwd="/named")
|
||||
assert listing[0]["title"].startswith("Investigate broken ACP history")
|
||||
|
||||
def test_list_sessions_sorted_by_most_recent_activity(self, manager):
|
||||
older = manager.create_session(cwd="/ordered")
|
||||
older.history.append({"role": "user", "content": "older"})
|
||||
manager.save_session(older.session_id)
|
||||
time.sleep(0.02)
|
||||
newer = manager.create_session(cwd="/ordered")
|
||||
newer.history.append({"role": "user", "content": "newer"})
|
||||
manager.save_session(newer.session_id)
|
||||
|
||||
listing = manager.list_sessions(cwd="/ordered")
|
||||
assert [item["session_id"] for item in listing[:2]] == [newer.session_id, older.session_id]
|
||||
assert listing[0]["updated_at"]
|
||||
assert listing[1]["updated_at"]
|
||||
|
||||
def test_fork_restores_source_from_db(self, manager):
|
||||
"""Forking a session that is only in DB should work."""
|
||||
original = manager.create_session()
|
||||
|
|
|
|||
|
|
@ -215,6 +215,46 @@ class TestBuildToolComplete:
|
|||
assert len(display_text) < 6000
|
||||
assert "truncated" in display_text
|
||||
|
||||
def test_build_tool_complete_for_patch_uses_diff_blocks(self):
|
||||
"""Completed patch calls should keep structured diff content for Zed."""
|
||||
patch_result = (
|
||||
'{"success": true, "diff": "--- a/README.md\\n+++ b/README.md\\n@@ -1 +1,2 @@\\n old line\\n+new line\\n", '
|
||||
'"files_modified": ["README.md"]}'
|
||||
)
|
||||
result = build_tool_complete("tc-p1", "patch", patch_result)
|
||||
assert isinstance(result, ToolCallProgress)
|
||||
assert len(result.content) == 1
|
||||
diff_item = result.content[0]
|
||||
assert isinstance(diff_item, FileEditToolCallContent)
|
||||
assert diff_item.path == "README.md"
|
||||
assert diff_item.old_text == "old line"
|
||||
assert diff_item.new_text == "old line\nnew line"
|
||||
|
||||
def test_build_tool_complete_for_patch_falls_back_to_text_when_no_diff(self):
|
||||
result = build_tool_complete("tc-p2", "patch", '{"success": true}')
|
||||
assert isinstance(result, ToolCallProgress)
|
||||
assert isinstance(result.content[0], ContentToolCallContent)
|
||||
|
||||
def test_build_tool_complete_for_write_file_uses_snapshot_diff(self, tmp_path):
|
||||
target = tmp_path / "diff-test.txt"
|
||||
snapshot = type("Snapshot", (), {"paths": [target], "before": {str(target): None}})()
|
||||
target.write_text("hello from hermes\n", encoding="utf-8")
|
||||
|
||||
result = build_tool_complete(
|
||||
"tc-wf1",
|
||||
"write_file",
|
||||
'{"bytes_written": 18, "dirs_created": false}',
|
||||
function_args={"path": str(target), "content": "hello from hermes\n"},
|
||||
snapshot=snapshot,
|
||||
)
|
||||
assert isinstance(result, ToolCallProgress)
|
||||
assert len(result.content) == 1
|
||||
diff_item = result.content[0]
|
||||
assert isinstance(diff_item, FileEditToolCallContent)
|
||||
assert diff_item.path.endswith("diff-test.txt")
|
||||
assert diff_item.old_text is None
|
||||
assert diff_item.new_text == "hello from hermes"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_locations
|
||||
|
|
|
|||
|
|
@ -696,6 +696,95 @@ class TestIsConnectionError:
|
|||
assert _is_connection_error(err) is False
|
||||
|
||||
|
||||
class TestKimiForCodingTemperature:
|
||||
"""kimi-for-coding now requires temperature=0.6 exactly."""
|
||||
|
||||
def test_build_call_kwargs_forces_fixed_temperature(self):
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
provider="kimi-coding",
|
||||
model="kimi-for-coding",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
assert kwargs["temperature"] == 0.6
|
||||
|
||||
def test_build_call_kwargs_injects_temperature_when_missing(self):
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
provider="kimi-coding",
|
||||
model="kimi-for-coding",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=None,
|
||||
)
|
||||
|
||||
assert kwargs["temperature"] == 0.6
|
||||
|
||||
def test_auto_routed_kimi_for_coding_sync_call_uses_fixed_temperature(self):
|
||||
client = MagicMock()
|
||||
client.base_url = "https://api.kimi.com/coding/v1"
|
||||
response = MagicMock()
|
||||
client.chat.completions.create.return_value = response
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._get_cached_client",
|
||||
return_value=(client, "kimi-for-coding"),
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", "kimi-for-coding", None, None, None),
|
||||
):
|
||||
result = call_llm(
|
||||
task="session_search",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
assert result is response
|
||||
kwargs = client.chat.completions.create.call_args.kwargs
|
||||
assert kwargs["model"] == "kimi-for-coding"
|
||||
assert kwargs["temperature"] == 0.6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_routed_kimi_for_coding_async_call_uses_fixed_temperature(self):
|
||||
client = MagicMock()
|
||||
client.base_url = "https://api.kimi.com/coding/v1"
|
||||
response = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(return_value=response)
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._get_cached_client",
|
||||
return_value=(client, "kimi-for-coding"),
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", "kimi-for-coding", None, None, None),
|
||||
):
|
||||
result = await async_call_llm(
|
||||
task="session_search",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
assert result is response
|
||||
kwargs = client.chat.completions.create.call_args.kwargs
|
||||
assert kwargs["model"] == "kimi-for-coding"
|
||||
assert kwargs["temperature"] == 0.6
|
||||
|
||||
def test_non_kimi_model_still_preserves_temperature(self):
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
provider="kimi-coding",
|
||||
model="kimi-k2.5",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
assert kwargs["temperature"] == 0.3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# async_call_llm payment / connection fallback (#7512 bug 2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
311
tests/agent/test_auxiliary_main_first.py
Normal file
311
tests/agent/test_auxiliary_main_first.py
Normal file
|
|
@ -0,0 +1,311 @@
|
|||
"""Regression tests for the ``auto`` → main-model-first policy.
|
||||
|
||||
Prior to this change, aggregator users (OpenRouter / Nous Portal) had aux
|
||||
tasks routed through a cheap provider-side default (Gemini Flash) while
|
||||
non-aggregator users got their main model. This made behavior inconsistent
|
||||
and surprising — users picked Claude but got Gemini Flash summaries.
|
||||
|
||||
The current policy: ``auto`` means "use my main chat model" for every user,
|
||||
regardless of provider type. Explicit per-task overrides in ``config.yaml``
|
||||
(``auxiliary.<task>.provider``) still win. The cheap fallback chain only
|
||||
runs when the main provider has no working client.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ── Text aux tasks — _resolve_auto ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestResolveAutoMainFirst:
|
||||
"""_resolve_auto() must prefer main provider + main model for every user."""
|
||||
|
||||
def test_openrouter_main_uses_main_model_for_aux(self, monkeypatch):
|
||||
"""OpenRouter main user → aux uses their picked OR model, not Gemini Flash."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-test-key")
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider",
|
||||
return_value="openrouter",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="anthropic/claude-sonnet-4.6",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve:
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "anthropic/claude-sonnet-4.6")
|
||||
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is mock_client
|
||||
assert model == "anthropic/claude-sonnet-4.6"
|
||||
# Verify it asked resolve_provider_client for the MAIN provider+model,
|
||||
# not a fallback-chain provider
|
||||
mock_resolve.assert_called_once()
|
||||
assert mock_resolve.call_args.args[0] == "openrouter"
|
||||
assert mock_resolve.call_args.args[1] == "anthropic/claude-sonnet-4.6"
|
||||
|
||||
def test_nous_main_uses_main_model_for_aux(self, monkeypatch):
|
||||
"""Nous Portal main user → aux uses their picked Nous model, not free-tier MiMo."""
|
||||
# No OPENROUTER_API_KEY → ensures if main failed we'd fall to chain
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="nous",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="anthropic/claude-opus-4.6",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve:
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "anthropic/claude-opus-4.6")
|
||||
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is mock_client
|
||||
assert model == "anthropic/claude-opus-4.6"
|
||||
assert mock_resolve.call_args.args[0] == "nous"
|
||||
|
||||
def test_non_aggregator_main_still_uses_main(self, monkeypatch):
|
||||
"""Non-aggregator main (DeepSeek) → unchanged behavior, main model used."""
|
||||
monkeypatch.setenv("DEEPSEEK_API_KEY", "ds-test")
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="deepseek",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model", return_value="deepseek-chat",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve:
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "deepseek-chat")
|
||||
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is mock_client
|
||||
assert model == "deepseek-chat"
|
||||
assert mock_resolve.call_args.args[0] == "deepseek"
|
||||
|
||||
def test_main_unavailable_falls_through_to_chain(self, monkeypatch):
|
||||
"""Main provider with no working client → fall back to aux chain."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
chain_client = MagicMock()
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="anthropic",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model", return_value="claude-opus",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(None, None), # main provider has no client
|
||||
), patch(
|
||||
"agent.auxiliary_client._try_openrouter",
|
||||
return_value=(chain_client, "google/gemini-3-flash-preview"),
|
||||
):
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is chain_client
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_no_main_config_uses_chain_directly(self):
|
||||
"""No main provider configured → skip step 1, use chain (no regression)."""
|
||||
chain_client = MagicMock()
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model", return_value="",
|
||||
), patch(
|
||||
"agent.auxiliary_client._try_openrouter",
|
||||
return_value=(chain_client, "google/gemini-3-flash-preview"),
|
||||
):
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is chain_client
|
||||
|
||||
def test_runtime_override_wins_over_config(self, monkeypatch):
    """A ``main_runtime`` kwarg takes precedence over the config-read provider/model.

    Config readers are stubbed to report openrouter / config-model, yet the
    provider resolver must be called with the runtime's anthropic /
    runtime-model pair.
    """
    resolver = MagicMock(return_value=(MagicMock(), "runtime-model"))
    runtime = {
        "provider": "anthropic",
        "model": "runtime-model",
        "base_url": "",
        "api_key": "",
        "api_mode": "",
    }

    with patch.multiple(
        "agent.auxiliary_client",
        _read_main_provider=MagicMock(return_value="openrouter"),
        _read_main_model=MagicMock(return_value="config-model"),
        resolve_provider_client=resolver,
    ):
        from agent.auxiliary_client import _resolve_auto

        _resolve_auto(main_runtime=runtime)

    # The first two positional args are the provider and model actually used.
    positional = resolver.call_args.args
    assert positional[0] == "anthropic"
    assert positional[1] == "runtime-model"
|
||||
|
||||
|
||||
# ── Vision — resolve_vision_provider_client ─────────────────────────────────
|
||||
|
||||
|
||||
class TestResolveVisionMainFirst:
    """Vision auto-detection tries the main provider with the main model first."""

    @staticmethod
    def _aux_patches(main_provider, main_model, task_provider="auto", **extra):
        """Return a context manager patching ``agent.auxiliary_client``.

        Stubs the main provider/model readers and the task-level provider
        resolution (a 5-tuple whose first item is ``task_provider``).
        ``extra`` maps further attribute names to ready-made mocks.
        """
        stubs = {
            "_read_main_provider": MagicMock(return_value=main_provider),
            "_read_main_model": MagicMock(return_value=main_model),
            "_resolve_task_provider_model": MagicMock(
                return_value=(task_provider, None, None, None, None),
            ),
        }
        stubs.update(extra)
        return patch.multiple("agent.auxiliary_client", **stubs)

    def test_openrouter_main_vision_uses_main_model(self, monkeypatch):
        """OpenRouter main with vision-capable model → aux vision uses main model."""
        monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
        main_client = MagicMock()
        resolver = MagicMock(
            return_value=(main_client, "anthropic/claude-sonnet-4.6"),
        )

        with self._aux_patches(
            "openrouter",
            "anthropic/claude-sonnet-4.6",
            resolve_provider_client=resolver,
        ):
            from agent.auxiliary_client import resolve_vision_provider_client

            provider, client, model = resolve_vision_provider_client()

        assert provider == "openrouter"
        assert client is main_client
        assert model == "anthropic/claude-sonnet-4.6"
        # Exactly one resolver call, made with the main provider and model,
        # rather than a strict vision backend with its own default model.
        resolver.assert_called_once()
        assert resolver.call_args.args[0] == "openrouter"
        assert resolver.call_args.args[1] == "anthropic/claude-sonnet-4.6"

    def test_nous_main_vision_uses_main_model(self):
        """Nous Portal main → aux vision uses main model, not free-tier MiMo-V2-Omni."""
        main_client = MagicMock()
        resolver = MagicMock(return_value=(main_client, "openai/gpt-5"))

        with self._aux_patches(
            "nous", "openai/gpt-5", resolve_provider_client=resolver,
        ):
            from agent.auxiliary_client import resolve_vision_provider_client

            provider, client, model = resolve_vision_provider_client()

        assert provider == "nous"
        assert model == "openai/gpt-5"

    def test_exotic_provider_with_vision_override_preserved(self):
        """xiaomi → mimo-v2-omni override still wins over main_model."""
        resolver = MagicMock(return_value=(MagicMock(), "mimo-v2-omni"))

        # "mimo-v2-pro" is the text main model configured for this provider.
        with self._aux_patches(
            "xiaomi", "mimo-v2-pro", resolve_provider_client=resolver,
        ):
            from agent.auxiliary_client import resolve_vision_provider_client

            provider, client, model = resolve_vision_provider_client()

        assert provider == "xiaomi"
        # The resolver must be asked for the vision override, not the text main.
        assert resolver.call_args.args[1] == "mimo-v2-omni"

    def test_main_unavailable_vision_falls_through_to_aggregators(self):
        """Main provider fails → fall back to OpenRouter/Nous strict backends."""
        fallback_client = MagicMock()

        with self._aux_patches(
            "deepseek",
            "deepseek-chat",
            resolve_provider_client=MagicMock(return_value=(None, None)),
            _resolve_strict_vision_backend=MagicMock(
                return_value=(fallback_client, "google/gemini-3-flash-preview"),
            ),
        ):
            from agent.auxiliary_client import resolve_vision_provider_client

            provider, client, model = resolve_vision_provider_client()

        assert client is fallback_client
        assert provider in ("openrouter", "nous")

    def test_explicit_provider_override_still_wins(self):
        """Explicit config override bypasses main-first policy."""
        strict_backend = MagicMock(
            return_value=(MagicMock(), "nous-default-model"),
        )

        # task_provider="nous" plays the role of the explicit override.
        with self._aux_patches(
            "openrouter",
            "anthropic/claude-opus-4.6",
            task_provider="nous",
            _resolve_strict_vision_backend=strict_backend,
        ):
            from agent.auxiliary_client import resolve_vision_provider_client

            provider, client, model = resolve_vision_provider_client()

        # The strict backend is used for "nous" instead of the main-model path.
        assert provider == "nous"
        strict_backend.assert_called_once_with("nous")
|
||||
|
||||
|
||||
# ── Constant cleanup ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_aggregator_providers_constant_removed():
    """Guard against the dead ``_AGGREGATOR_PROVIDERS`` constant coming back.

    It was removed when the main-first policy made the aggregator-skip
    guard obsolete.
    """
    import agent.auxiliary_client as aux_mod

    regression_hint = (
        "_AGGREGATOR_PROVIDERS was removed when _resolve_auto stopped "
        "treating aggregators specially. If you re-added it, the main-first "
        "policy may have regressed."
    )
    still_defined = hasattr(aux_mod, "_AGGREGATOR_PROVIDERS")
    assert not still_defined, regression_hint
|
||||
|
|
@ -826,6 +826,160 @@ class TestGeminiCloudCodeClient:
|
|||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
class TestGeminiHttpErrorParsing:
    """Regression coverage for how ``_gemini_http_error`` parses Google error envelopes.

    These mirror failures users hit during Google-side throttling
    (April 2026: gemini-2.5-pro MODEL_CAPACITY_EXHAUSTED, gemma-4-26b-it
    returning 404). The resulting error has to expose ``status_code`` and
    ``response`` so the main loop's error classifier and Retry-After
    handling keep working.
    """

    @staticmethod
    def _fake_response(status: int, body: dict | str = "", headers=None):
        """Build a duck-typed response exposing ``status_code``, ``text`` and ``headers``.

        A dict ``body`` is serialized with ``json.dumps``; a string is used
        verbatim. A falsy ``headers`` becomes an empty dict.
        """
        class _FakeResponse:
            def __init__(self):
                self.status_code = status
                self.text = json.dumps(body) if isinstance(body, dict) else body
                self.headers = headers or {}

        return _FakeResponse()

    def test_model_capacity_exhausted_produces_friendly_message(self):
        """429 with a MODEL_CAPACITY_EXHAUSTED detail yields a readable capacity error."""
        from agent.gemini_cloudcode_adapter import _gemini_http_error

        capacity_detail = {
            "@type": "type.googleapis.com/google.rpc.ErrorInfo",
            "reason": "MODEL_CAPACITY_EXHAUSTED",
            "domain": "googleapis.com",
            "metadata": {"model": "gemini-2.5-pro"},
        }
        retry_detail = {
            "@type": "type.googleapis.com/google.rpc.RetryInfo",
            "retryDelay": "30s",
        }
        envelope = {
            "error": {
                "code": 429,
                "message": "Resource has been exhausted (e.g. check quota).",
                "status": "RESOURCE_EXHAUSTED",
                "details": [capacity_detail, retry_detail],
            }
        }

        err = _gemini_http_error(self._fake_response(429, envelope))

        assert err.status_code == 429
        assert err.code == "code_assist_capacity_exhausted"
        assert err.retry_after == 30.0
        assert err.details["reason"] == "MODEL_CAPACITY_EXHAUSTED"
        # The rendered message names the model, the condition and the delay.
        rendered = str(err)
        assert "gemini-2.5-pro" in rendered
        assert "capacity exhausted" in rendered.lower()
        assert "30s" in rendered
        # The response object stays attached for the Retry-After header path.
        assert err.response is not None

    def test_resource_exhausted_without_reason(self):
        """429 RESOURCE_EXHAUSTED with no details is reported as rate limited."""
        from agent.gemini_cloudcode_adapter import _gemini_http_error

        envelope = {
            "error": {
                "code": 429,
                "message": "Quota exceeded for requests per minute.",
                "status": "RESOURCE_EXHAUSTED",
            }
        }

        err = _gemini_http_error(self._fake_response(429, envelope))

        assert err.status_code == 429
        assert err.code == "code_assist_rate_limited"
        assert "quota" in str(err).lower()

    def test_404_model_not_found_produces_model_retired_message(self):
        """404 NOT_FOUND yields a 'not available' / 'retired' style message."""
        from agent.gemini_cloudcode_adapter import _gemini_http_error

        envelope = {
            "error": {
                "code": 404,
                "message": "models/gemma-4-26b-it is not found for API version v1internal",
                "status": "NOT_FOUND",
            }
        }

        err = _gemini_http_error(self._fake_response(404, envelope))

        assert err.status_code == 404
        rendered = str(err)
        lowered = rendered.lower()
        assert any(phrase in lowered for phrase in ("not available", "retired"))
        # The model text reported by Google must be carried into the message.
        assert "gemma-4-26b-it" in rendered

    def test_unauthorized_preserves_status_code(self):
        """401 UNAUTHENTICATED keeps its status code and maps to the unauthorized code."""
        from agent.gemini_cloudcode_adapter import _gemini_http_error

        envelope = {
            "error": {"code": 401, "message": "Invalid token", "status": "UNAUTHENTICATED"},
        }

        err = _gemini_http_error(self._fake_response(401, envelope))

        assert err.status_code == 401
        assert err.code == "code_assist_unauthorized"

    def test_retry_after_header_fallback(self):
        """If the body has no RetryInfo detail, fall back to Retry-After header."""
        from agent.gemini_cloudcode_adapter import _gemini_http_error

        envelope = {
            "error": {"code": 429, "message": "Rate limited", "status": "RESOURCE_EXHAUSTED"},
        }
        response = self._fake_response(429, envelope, headers={"Retry-After": "45"})

        err = _gemini_http_error(response)

        assert err.retry_after == 45.0

    def test_malformed_body_still_produces_structured_error(self):
        """Non-JSON body must not swallow status_code — we still want the classifier path."""
        from agent.gemini_cloudcode_adapter import _gemini_http_error

        response = self._fake_response(500, "<html>internal error</html>")

        err = _gemini_http_error(response)

        assert err.status_code == 500
        # The status code must also show up in the rendered message.
        assert "500" in str(err)

    def test_status_code_flows_through_error_classifier(self):
        """End-to-end: a CodeAssistError built from a 429 must classify as rate_limit.

        This is why status_code exists on the error: the classifier has to
        see it and report FailoverReason.rate_limit so the main loop can
        trigger fallback_providers.
        """
        from agent.gemini_cloudcode_adapter import _gemini_http_error
        from agent.error_classifier import classify_api_error, FailoverReason

        capacity_detail = {
            "@type": "type.googleapis.com/google.rpc.ErrorInfo",
            "reason": "MODEL_CAPACITY_EXHAUSTED",
            "metadata": {"model": "gemini-2.5-pro"},
        }
        envelope = {
            "error": {
                "code": 429,
                "message": "Resource has been exhausted",
                "status": "RESOURCE_EXHAUSTED",
                "details": [capacity_detail],
            }
        }
        err = _gemini_http_error(self._fake_response(429, envelope))

        classified = classify_api_error(
            err, provider="google-gemini-cli", model="gemini-2.5-pro",
        )

        assert classified.status_code == 429
        assert classified.reason == FailoverReason.rate_limit
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Provider registration
|
||||
# =============================================================================
|
||||
|
|
|
|||
71
tests/cli/test_cli_copy_command.py
Normal file
71
tests/cli/test_cli_copy_command.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""Tests for CLI /copy command."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
def _make_cli() -> HermesCLI:
    """Return a ``HermesCLI`` created without running ``__init__``.

    Only the attributes listed below are populated, in this order.
    """
    stub = HermesCLI.__new__(HermesCLI)
    initial_state = {
        "config": {},
        "console": MagicMock(),
        "agent": None,
        "conversation_history": [],
        "session_id": "sess-copy-test",
        "_pending_input": MagicMock(),
        "_app": None,
    }
    for attr_name, attr_value in initial_state.items():
        setattr(stub, attr_name, attr_value)
    return stub
|
||||
|
||||
|
||||
def test_copy_copies_latest_assistant_message():
    """A bare ``/copy`` passes the last assistant message to the clipboard writer."""
    shell = _make_cli()
    turns = (("user", "hi"), ("assistant", "first"), ("assistant", "latest"))
    shell.conversation_history = [
        {"role": role, "content": text} for role, text in turns
    ]

    with patch.object(shell, "_write_osc52_clipboard") as clipboard:
        handled = shell.process_command("/copy")

    assert handled is True
    clipboard.assert_called_once_with("latest")
|
||||
|
||||
|
||||
def test_copy_with_index_uses_requested_assistant_message():
    """``/copy 1`` copies the first of two assistant messages."""
    shell = _make_cli()
    shell.conversation_history = [
        {"role": "assistant", "content": text} for text in ("one", "two")
    ]

    with patch.object(shell, "_write_osc52_clipboard") as clipboard:
        shell.process_command("/copy 1")

    clipboard.assert_called_once_with("one")
|
||||
|
||||
|
||||
def test_copy_strips_reasoning_blocks_before_copy():
    """A ``<REASONING_SCRATCHPAD>`` block is removed; only the visible answer is copied."""
    shell = _make_cli()
    reply_with_scratchpad = {
        "role": "assistant",
        "content": "<REASONING_SCRATCHPAD>internal</REASONING_SCRATCHPAD>\nVisible answer",
    }
    shell.conversation_history = [reply_with_scratchpad]

    with patch.object(shell, "_write_osc52_clipboard") as clipboard:
        shell.process_command("/copy")

    clipboard.assert_called_once_with("Visible answer")
|
||||
|
||||
|
||||
def test_copy_invalid_index_does_not_copy():
    """An out-of-range index copies nothing and prints an 'Invalid response number' notice."""
    shell = _make_cli()
    shell.conversation_history = [{"role": "assistant", "content": "only"}]

    with patch.object(shell, "_write_osc52_clipboard") as clipboard:
        with patch("cli._cprint") as printer:
            shell.process_command("/copy 99")

    clipboard.assert_not_called()
    printed = [str(recorded) for recorded in printer.call_args_list]
    assert any("Invalid response number" in line for line in printed)
|
||||
|
|
@ -64,6 +64,24 @@ class TestSaveConfigValueAtomic:
|
|||
result = yaml.safe_load(config_env.read_text())
|
||||
assert result["display"]["skin"] == "ares"
|
||||
|
||||
def test_preserves_env_ref_templates_in_unrelated_fields(self, config_env):
    """The /model --global persistence path must not inline env-backed secrets.

    Saving ``model.default`` has to leave the ``${TU_ZI_API_KEY}`` template
    under ``custom_providers`` exactly as written.
    """
    env_backed_provider = {
        "name": "tuzi",
        "api_key": "${TU_ZI_API_KEY}",
        "model": "claude-opus-4-6",
    }
    initial_config = {
        "custom_providers": [env_backed_provider],
        "model": {"default": "test-model", "provider": "openrouter"},
    }
    config_env.write_text(yaml.dump(initial_config))

    from cli import save_config_value

    save_config_value("model.default", "doubao-pro")

    saved = yaml.safe_load(config_env.read_text())
    assert saved["model"]["default"] == "doubao-pro"
    first_provider = saved["custom_providers"][0]
    assert first_provider["api_key"] == "${TU_ZI_API_KEY}"
|
||||
|
||||
def test_file_not_truncated_on_error(self, config_env, monkeypatch):
|
||||
"""If atomic_yaml_write raises, the original file is untouched."""
|
||||
original_content = config_env.read_text()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@
|
|||
|
||||
Surrogates (U+D800..U+DFFF) are invalid in UTF-8 and crash json.dumps()
|
||||
inside the OpenAI SDK. They can appear via clipboard paste from rich-text
|
||||
editors like Google Docs.
|
||||
editors like Google Docs, OR from byte-level reasoning models (xiaomi/mimo,
|
||||
kimi, glm) emitting lone halves in reasoning output.
|
||||
"""
|
||||
import json
|
||||
import pytest
|
||||
|
|
@ -11,6 +12,7 @@ from unittest.mock import MagicMock, patch
|
|||
from run_agent import (
|
||||
_sanitize_surrogates,
|
||||
_sanitize_messages_surrogates,
|
||||
_sanitize_structure_surrogates,
|
||||
_SURROGATE_RE,
|
||||
)
|
||||
|
||||
|
|
@ -109,6 +111,186 @@ class TestSanitizeMessagesSurrogates:
|
|||
assert "\ufffd" in msgs[0]["content"]
|
||||
|
||||
|
||||
class TestReasoningFieldSurrogates:
    """Surrogates in reasoning fields (byte-level reasoning models).

    xiaomi/mimo, kimi, glm and similar byte-level tokenizers can emit lone
    surrogates in reasoning output. These fields are carried through to the
    API as `reasoning_content` on assistant messages, and must be sanitized
    or json.dumps() crashes with 'utf-8' codec can't encode surrogates.
    """

    def test_reasoning_field_sanitized(self):
        """A lone surrogate in ``reasoning`` is replaced with U+FFFD."""
        msgs = [
            {"role": "assistant", "content": "ok", "reasoning": "thought \udce2 here"},
        ]
        assert _sanitize_messages_surrogates(msgs) is True
        assert "\udce2" not in msgs[0]["reasoning"]
        assert "\ufffd" in msgs[0]["reasoning"]

    def test_reasoning_content_field_sanitized(self):
        """api_messages carry `reasoning_content` built from `reasoning`."""
        msgs = [
            {"role": "assistant", "content": "ok", "reasoning_content": "thought \udce2 here"},
        ]
        assert _sanitize_messages_surrogates(msgs) is True
        assert "\udce2" not in msgs[0]["reasoning_content"]
        assert "\ufffd" in msgs[0]["reasoning_content"]

    def test_reasoning_details_nested_sanitized(self):
        """reasoning_details is a list of dicts with nested string fields."""
        msgs = [
            {
                "role": "assistant",
                "content": "ok",
                "reasoning_details": [
                    {"type": "reasoning.summary", "summary": "summary \udce2 text"},
                    {"type": "reasoning.text", "text": "chain \udc00 of thought"},
                ],
            },
        ]
        assert _sanitize_messages_surrogates(msgs) is True
        assert "\udce2" not in msgs[0]["reasoning_details"][0]["summary"]
        assert "\ufffd" in msgs[0]["reasoning_details"][0]["summary"]
        assert "\udc00" not in msgs[0]["reasoning_details"][1]["text"]
        assert "\ufffd" in msgs[0]["reasoning_details"][1]["text"]

    def test_deeply_nested_reasoning_sanitized(self):
        """Nested dicts / lists inside extra fields are recursed into."""
        msgs = [
            {
                "role": "assistant",
                "content": "ok",
                "reasoning_details": [
                    {
                        "type": "reasoning.encrypted",
                        "content": {
                            "encrypted_content": "opaque",
                            "text_parts": ["part1", "part2 \udce2 part"],
                        },
                    },
                ],
            },
        ]
        assert _sanitize_messages_surrogates(msgs) is True
        assert (
            msgs[0]["reasoning_details"][0]["content"]["text_parts"][1]
            == "part2 \ufffd part"
        )

    def test_reasoning_end_to_end_json_serialization(self):
        """After sanitization, the full message dict must serialize clean."""
        msgs = [
            {
                "role": "assistant",
                "content": "answer",
                "reasoning_content": "reasoning with \udce2 surrogate",
                "reasoning_details": [
                    {"summary": "nested \udcb0 surrogate"},
                ],
            },
        ]
        _sanitize_messages_surrogates(msgs)
        # A lone surrogate cannot be encoded as UTF-8, so reaching the next
        # line at all proves none survived sanitization.
        payload = json.dumps(msgs, ensure_ascii=False).encode("utf-8")
        # The previous check (`b"\\" not in payload[:0]`) inspected an empty
        # slice and could never fail; verify the payload for real instead.
        assert json.loads(payload.decode("utf-8")) == msgs
        assert "\ufffd" in msgs[0]["reasoning_content"]
        assert "\ufffd" in msgs[0]["reasoning_details"][0]["summary"]

    def test_no_surrogates_returns_false(self):
        """Clean reasoning fields don't trigger a modification."""
        msgs = [
            {
                "role": "assistant",
                "content": "ok",
                "reasoning": "clean thought",
                "reasoning_content": "also clean",
                "reasoning_details": [{"summary": "clean summary"}],
            },
        ]
        assert _sanitize_messages_surrogates(msgs) is False
|
||||
|
||||
|
||||
class TestSanitizeStructureSurrogates:
    """Exercise ``_sanitize_structure_surrogates()`` on flat and nested payloads."""

    def test_empty_payload(self):
        """Empty containers report no modification."""
        for empty in ({}, []):
            assert _sanitize_structure_surrogates(empty) is False

    def test_flat_dict(self):
        """Only the dirty value of a flat dict is rewritten."""
        data = {"a": "clean", "b": "dirty \udce2 text"}
        changed = _sanitize_structure_surrogates(data)
        assert changed is True
        assert data["a"] == "clean"
        assert "\ufffd" in data["b"]

    def test_flat_list(self):
        """Only the dirty item of a flat list is rewritten."""
        items = ["clean", "dirty \udce2"]
        changed = _sanitize_structure_surrogates(items)
        assert changed is True
        assert items[0] == "clean"
        assert "\ufffd" in items[1]

    def test_nested_dict_in_list(self):
        """Dicts inside a list are sanitized in place."""
        dirty_entry = {"x": "dirty \udce2"}
        clean_entry = {"x": "clean"}
        changed = _sanitize_structure_surrogates([dirty_entry, clean_entry])
        assert changed is True
        assert "\ufffd" in dirty_entry["x"]
        assert clean_entry["x"] == "clean"

    def test_deeply_nested(self):
        """A surrogate three levels down (dict → list → dict) is reached."""
        leaf = {"level3": "deep \udce2 surrogate"}
        tree = {"level1": {"level2": [leaf]}}
        changed = _sanitize_structure_surrogates(tree)
        assert changed is True
        assert "\ufffd" in tree["level1"]["level2"][0]["level3"]

    def test_clean_payload_returns_false(self):
        """A payload without surrogates reports no modification."""
        data = {"a": "clean", "b": [{"c": "also clean"}]}
        changed = _sanitize_structure_surrogates(data)
        assert changed is False

    def test_non_string_values_ignored(self):
        """Ints, lists of ints, None and bools pass through untouched."""
        data = {"int": 42, "list": [1, 2, 3], "dict": {"none": None}, "bool": True}
        changed = _sanitize_structure_surrogates(data)
        assert changed is False
        # Non-string values keep their original contents.
        assert data["int"] == 42
        assert data["list"] == [1, 2, 3]
|
||||
|
||||
|
||||
class TestApiMessagesSurrogateRecovery:
    """Integration: the sanitizer used by the recovery block handles api_messages.

    Guards against this bug: a surrogate in `reasoning_content` on
    api_messages (derived from `reasoning` during build) crashes the OpenAI
    SDK's json.dumps(). The recovery block used to sanitize only the
    canonical `messages` list, not `api_messages`, so each retry resent
    the same broken payload and failed 3 times.
    """

    def test_api_messages_reasoning_content_sanitized(self):
        """The extended sanitizer catches reasoning_content in api_messages."""
        tool_call = {
            "id": "call_1",
            "function": {"name": "tool", "arguments": "{}"},
        }
        assistant_turn = {
            "role": "assistant",
            "content": "response",
            "reasoning_content": "thought \udce2 trail",
            "tool_calls": [tool_call],
        }
        api_messages = [
            {"role": "system", "content": "sys"},
            assistant_turn,
            {"role": "tool", "content": "result", "tool_call_id": "call_1"},
        ]

        changed = _sanitize_messages_surrogates(api_messages)

        assert changed is True
        assert "\udce2" not in api_messages[1]["reasoning_content"]
        # Encoding raises if any lone surrogate is still present.
        serialized = json.dumps(api_messages, ensure_ascii=False)
        serialized.encode("utf-8")
|
||||
|
||||
|
||||
class TestRunConversationSurrogateSanitization:
|
||||
"""Integration: verify run_conversation sanitizes user_message."""
|
||||
|
||||
|
|
|
|||
|
|
@ -184,6 +184,8 @@ _HERMES_BEHAVIORAL_VARS = frozenset({
|
|||
"HERMES_BACKGROUND_NOTIFICATIONS",
|
||||
"HERMES_EXEC_ASK",
|
||||
"HERMES_HOME_MODE",
|
||||
"BROWSER_CDP_URL",
|
||||
"CAMOFOX_URL",
|
||||
})
|
||||
|
||||
|
||||
|
|
@ -229,6 +231,15 @@ def _hermetic_environment(tmp_path, monkeypatch):
|
|||
monkeypatch.setenv("LC_ALL", "C.UTF-8")
|
||||
monkeypatch.setenv("PYTHONHASHSEED", "0")
|
||||
|
||||
# 4b. Disable AWS IMDS lookups. Without this, any test that ends up
|
||||
# calling has_aws_credentials() / resolve_aws_auth_env_var()
|
||||
# (e.g. provider auto-detect, status command, cron run_job) burns
|
||||
# ~2s waiting for the metadata service at 169.254.169.254 to time
|
||||
# out. Tests don't run on EC2 — IMDS is always unreachable here.
|
||||
monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true")
|
||||
monkeypatch.setenv("AWS_METADATA_SERVICE_TIMEOUT", "1")
|
||||
monkeypatch.setenv("AWS_METADATA_SERVICE_NUM_ATTEMPTS", "1")
|
||||
|
||||
# 5. Reset plugin singleton so tests don't leak plugins from
|
||||
# ~/.hermes/plugins/ (which, per step 3, is now empty — but the
|
||||
# singleton might still be cached from a previous test).
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from unittest.mock import patch
|
|||
|
||||
from gateway.channel_directory import (
|
||||
build_channel_directory,
|
||||
lookup_channel_type,
|
||||
resolve_channel_name,
|
||||
format_directory_for_display,
|
||||
load_directory,
|
||||
|
|
@ -285,3 +286,49 @@ class TestFormatDirectoryForDisplay:
|
|||
assert "Discord (Server1):" in result
|
||||
assert "Discord (Server2):" in result
|
||||
assert "discord:#general" in result
|
||||
|
||||
|
||||
class TestLookupChannelType:
    """``lookup_channel_type()`` against a directory file written under ``tmp_path``."""

    def _setup(self, tmp_path, platforms):
        """Write ``platforms`` with ``_write_directory`` and patch DIRECTORY_PATH to that file."""
        directory_file = _write_directory(tmp_path, platforms)
        return patch("gateway.channel_directory.DIRECTORY_PATH", directory_file)

    @staticmethod
    def _discord(entry):
        """Wrap a single channel entry as a discord-only platforms mapping."""
        return {"discord": [entry]}

    def test_forum_channel(self, tmp_path):
        """An entry typed ``forum`` is reported as ``forum``."""
        entry = {"id": "100", "name": "ideas", "guild": "Server1", "type": "forum"}
        with self._setup(tmp_path, self._discord(entry)):
            found = lookup_channel_type("discord", "100")
            assert found == "forum"

    def test_regular_channel(self, tmp_path):
        """An entry typed ``channel`` is reported as ``channel``."""
        entry = {"id": "200", "name": "general", "guild": "Server1", "type": "channel"}
        with self._setup(tmp_path, self._discord(entry)):
            found = lookup_channel_type("discord", "200")
            assert found == "channel"

    def test_unknown_chat_id_returns_none(self, tmp_path):
        """An id absent from the directory yields ``None``."""
        entry = {"id": "200", "name": "general", "guild": "Server1", "type": "channel"}
        with self._setup(tmp_path, self._discord(entry)):
            found = lookup_channel_type("discord", "999")
            assert found is None

    def test_unknown_platform_returns_none(self, tmp_path):
        """A platform absent from the directory yields ``None``."""
        with self._setup(tmp_path, {}):
            found = lookup_channel_type("discord", "100")
            assert found is None

    def test_channel_without_type_key_returns_none(self, tmp_path):
        """An entry with no ``type`` key yields ``None``."""
        entry = {"id": "300", "name": "general", "guild": "Server1"}
        with self._setup(tmp_path, self._discord(entry)):
            found = lookup_channel_type("discord", "300")
            assert found is None
|
||||
|
|
|
|||
|
|
@ -160,6 +160,30 @@ class TestCommandBypassActiveSession:
|
|||
assert sk not in adapter._pending_messages
|
||||
assert any("handled:status" in r for r in adapter.sent_responses)
|
||||
|
||||
@pytest.mark.asyncio
async def test_agents_bypasses_guard(self):
    """/agents must bypass so active-task queries don't interrupt runs."""
    adapter = _make_adapter()
    session = _session_key()
    # Mark the session as active before the command arrives.
    adapter._active_sessions[session] = asyncio.Event()

    event = _make_event("/agents")
    await adapter.handle_message(event)

    assert session not in adapter._pending_messages
    dispatched = [
        reply for reply in adapter.sent_responses if "handled:agents" in reply
    ]
    assert dispatched
|
||||
|
||||
@pytest.mark.asyncio
async def test_tasks_alias_bypasses_guard(self):
    """/tasks alias must bypass active-session guard too."""
    adapter = _make_adapter()
    session = _session_key()
    # Mark the session as active before the command arrives.
    adapter._active_sessions[session] = asyncio.Event()

    event = _make_event("/tasks")
    await adapter.handle_message(event)

    assert session not in adapter._pending_messages
    dispatched = [
        reply for reply in adapter.sent_responses if "handled:tasks" in reply
    ]
    assert dispatched
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_bypasses_guard(self):
|
||||
"""/background must bypass so it spawns a parallel task, not an interrupt."""
|
||||
|
|
@ -176,6 +200,38 @@ class TestCommandBypassActiveSession:
|
|||
"/background response was not sent back to the user"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
async def test_help_bypasses_guard(self):
    """/help must bypass so it is not silently dropped as pending slash text."""
    adapter = _make_adapter()
    session = _session_key()
    # Mark the session as active before the command arrives.
    adapter._active_sessions[session] = asyncio.Event()

    event = _make_event("/help")
    await adapter.handle_message(event)

    was_queued = session in adapter._pending_messages
    assert not was_queued, (
        "/help was queued as a pending message instead of being dispatched"
    )
    dispatched = [
        reply for reply in adapter.sent_responses if "handled:help" in reply
    ]
    assert dispatched, (
        "/help response was not sent back to the user"
    )
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_bypasses_guard(self):
    """/update must bypass so it is not discarded by the pending-command safety net."""
    adapter = _make_adapter()
    session = _session_key()
    # Mark the session as active before the command arrives.
    adapter._active_sessions[session] = asyncio.Event()

    event = _make_event("/update")
    await adapter.handle_message(event)

    was_queued = session in adapter._pending_messages
    assert not was_queued, (
        "/update was queued as a pending message instead of being dispatched"
    )
    dispatched = [
        reply for reply in adapter.sent_responses if "handled:update" in reply
    ]
    assert dispatched, (
        "/update response was not sent back to the user"
    )
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queue_bypasses_guard(self):
|
||||
"""/queue must bypass so it can queue without interrupting."""
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ class TestSend:
|
|||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
adapter._http_client = mock_client
|
||||
adapter._session_webhooks["chat-123"] = "https://cached.example/webhook"
|
||||
adapter._session_webhooks["chat-123"] = ("https://cached.example/webhook", 9999999999999)
|
||||
|
||||
result = await adapter.send("chat-123", "Hello!")
|
||||
assert result.success is True
|
||||
|
|
@ -681,3 +681,290 @@ class TestIncomingHandlerProcess:
|
|||
processing_gate.set()
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text extraction — mention preservation + platform sanity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractTextMentions:
    """Text extraction keeps @-handles intact; also a platform enum sanity check."""

    def test_preserves_at_mentions_in_text(self):
        """@mentions are routing signals (via isInAtList), not text to strip.

        Stripping every @handle would also damage emails, SSH URLs, and
        literal references the user wrote.
        """
        from gateway.platforms.dingtalk import DingTalkAdapter

        # Every input is expected to come back unchanged.
        untouched_inputs = [
            "@bot hello",
            "contact alice@example.com",
            "git@github.com:foo/bar.git",
            "what does @openai think",
            "@机器人 转发给 @老王",
        ]
        for text in untouched_inputs:
            expected = text
            msg = MagicMock(text=text, rich_text=None, rich_text_content=None)
            assert DingTalkAdapter._extract_text(msg) == expected, (
                f"mangled: {text!r} -> {DingTalkAdapter._extract_text(msg)!r}"
            )

    def test_dingtalk_in_platform_enum(self):
        """``Platform.DINGTALK`` carries the value ``"dingtalk"``."""
        dingtalk_value = Platform.DINGTALK.value
        assert dingtalk_value == "dingtalk"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concurrency — chat-scoped message context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageContextIsolation:
    """Message contexts stored on the adapter are kept per chat id."""

    def test_contexts_keyed_by_chat_id(self):
        """Two concurrent chats must not clobber each other's context."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(PlatformConfig(enabled=True))

        by_chat = {
            "chat-A": MagicMock(conversation_id="chat-A", sender_staff_id="user-A"),
            "chat-B": MagicMock(conversation_id="chat-B", sender_staff_id="user-B"),
        }
        # Store both contexts first, then verify each slot still holds its own object.
        for chat_id, message in by_chat.items():
            adapter._message_contexts[chat_id] = message

        for chat_id, message in by_chat.items():
            assert adapter._message_contexts[chat_id] is message
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Card lifecycle: finalize via metadata["streaming"]
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCardLifecycle:
    """Card open/finalize behaviour of DingTalkAdapter against a mocked card SDK."""

    @pytest.fixture
    def adapter_with_card(self):
        """Return an adapter with a card template, mocked SDK and one known chat."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(
            PlatformConfig(enabled=True, extra={"card_template_id": "tmpl-1"})
        )
        adapter._card_sdk = card_sdk = MagicMock()
        card_sdk.create_card_with_options_async = AsyncMock()
        card_sdk.deliver_card_with_options_async = AsyncMock()
        card_sdk.streaming_update_with_options_async = AsyncMock()
        adapter._http_client = AsyncMock()
        adapter._get_access_token = AsyncMock(return_value="token")
        # Minimal message context
        adapter._message_contexts["chat-1"] = MagicMock(
            conversation_id="chat-1",
            conversation_type="1",
            sender_staff_id="staff-1",
            message_id="user-msg-1",
        )
        adapter._session_webhooks["chat-1"] = (
            "https://api.dingtalk.com/x",
            9999999999999,
        )
        return adapter

    @pytest.mark.asyncio
    async def test_final_reply_finalizes_card(self, adapter_with_card):
        """send(reply_to=...) creates a closed card (final response path)."""
        adapter = adapter_with_card
        outcome = await adapter.send("chat-1", "Hello", reply_to="user-msg-1")
        assert outcome.success
        update_call = adapter._card_sdk.streaming_update_with_options_async.call_args
        assert update_call.args[0].is_finalize is True
        # Not tracked as streaming — it's already closed.
        assert "chat-1" not in adapter._streaming_cards

    @pytest.mark.asyncio
    async def test_intermediate_send_stays_streaming(self, adapter_with_card):
        """send() without reply_to creates an OPEN card (tool progress /
        commentary / streaming first chunk). No flicker closed→streaming
        when edit_message follows."""
        adapter = adapter_with_card
        outcome = await adapter.send("chat-1", "💻 terminal: ls")
        assert outcome.success
        update_call = adapter._card_sdk.streaming_update_with_options_async.call_args
        assert update_call.args[0].is_finalize is False
        # Tracked for sibling cleanup.
        open_cards = adapter._streaming_cards.get("chat-1", {})
        assert outcome.message_id in open_cards

    @pytest.mark.asyncio
    async def test_done_fires_only_when_reply_to_is_set(self, adapter_with_card):
        """reply_to distinguishes final response (base.py) from tool-progress
        sends (run.py). Done must only fire for the former."""
        adapter = adapter_with_card
        fired: list[str] = []

        def _record(cid):
            fired.append(cid)

        adapter._fire_done_reaction = _record

        # Tool-progress / commentary path: no reply_to — no Done.
        await adapter.send("chat-1", "tool line")
        assert fired == []

        # Final response path: reply_to set — Done fires.
        await adapter.send("chat-1", "final", reply_to="user-msg-1")
        assert fired == ["chat-1"]

    @pytest.mark.asyncio
    async def test_edit_message_finalize_fires_done(self, adapter_with_card):
        """Stream consumer's final edit_message(finalize=True) fires Done."""
        adapter = adapter_with_card
        fired: list[str] = []

        def _record(cid):
            fired.append(cid)

        adapter._fire_done_reaction = _record

        await adapter.send("chat-1", "initial")
        # Reopen via edit_message(finalize=False) then close.
        for content, finalize in (("streaming...", False), ("final", True)):
            await adapter.edit_message(
                chat_id="chat-1",
                message_id="track-X",
                content=content,
                finalize=finalize,
            )
        assert "chat-1" in fired

    @pytest.mark.asyncio
    async def test_edit_message_finalize_false_tracks_sibling(self, adapter_with_card):
        """After edit_message(finalize=False), card is tracked as open."""
        adapter = adapter_with_card
        await adapter.edit_message(
            chat_id="chat-1",
            message_id="track-1",
            content="partial",
            finalize=False,
        )
        assert "chat-1" in adapter._streaming_cards
        tracked = adapter._streaming_cards["chat-1"]
        assert tracked.get("track-1") == "partial"

    @pytest.mark.asyncio
    async def test_next_send_auto_closes_sibling_streaming_cards(
        self, adapter_with_card,
    ):
        """Tool-progress card left open (send without reply_to + edits) must
        be auto-closed when the final-reply send arrives."""
        adapter = adapter_with_card
        streaming_update = adapter._card_sdk.streaming_update_with_options_async

        # First tool: intermediate send — card stays open.
        first = await adapter.send("chat-1", "💻 tool1")
        # Second tool: edit_message(finalize=False) — keeps streaming.
        await adapter.edit_message(
            chat_id="chat-1",
            message_id=first.message_id,
            content="💻 tool1\n💻 tool2",
            finalize=False,
        )
        assert first.message_id in adapter._streaming_cards.get("chat-1", {})
        streaming_update.reset_mock()

        # Final response send auto-closes the sibling.
        await adapter.send("chat-1", "final answer", reply_to="user-msg")

        recorded = streaming_update.call_args_list
        assert len(recorded) >= 2
        # First call was the sibling close with last-seen tool-progress content.
        sibling_close = recorded[0].args[0]
        assert sibling_close.out_track_id == first.message_id
        assert sibling_close.is_finalize is True
        assert "tool1" in sibling_close.content
        # Streaming tracking is cleared after close.
        assert "chat-1" not in adapter._streaming_cards

    @pytest.mark.asyncio
    async def test_edit_message_requires_message_id(self, adapter_with_card):
        """An empty message_id yields a failed result and no SDK update call."""
        adapter = adapter_with_card
        outcome = await adapter.edit_message(
            chat_id="chat-1", message_id="", content="x", finalize=True,
        )
        assert outcome.success is False
        adapter._card_sdk.streaming_update_with_options_async.assert_not_called()

    def test_fire_done_reaction_is_idempotent(self, adapter_with_card):
        """Firing Done twice for one chat hands only one coroutine to _spawn_bg."""
        adapter = adapter_with_card
        spawned = []
        adapter._spawn_bg = lambda coro: spawned.append(coro)

        for _ in range(2):
            adapter._fire_done_reaction("chat-1")

        assert len(spawned) == 1
        # Close the never-awaited coroutine captured above.
        spawned[0].close()
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AI Card Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDingTalkAdapterAICards:
    """AI-card send path of DingTalkAdapter when a card template is configured."""

    @pytest.fixture
    def config(self):
        """Platform config carrying client credentials and a card template id."""
        extra = {
            "client_id": "test_id",
            "client_secret": "test_secret",
            "card_template_id": "test_card_template",
        }
        return PlatformConfig(enabled=True, extra=extra)

    @pytest.fixture
    def mock_stream_client(self):
        """Stream client stub whose get_access_token returns a fixed token."""
        stub = MagicMock()
        stub.get_access_token = MagicMock(return_value="test_token")
        return stub

    @pytest.fixture
    def mock_http_client(self):
        """Async HTTP client stub."""
        return AsyncMock()

    @pytest.fixture
    def mock_message(self):
        """Incoming-message stub populated with the attributes set below."""
        fields = {
            "message_id": "test_msg_id",
            "conversation_id": "test_conv_id",
            "conversation_type": "1",
            "sender_id": "sender1",
            "sender_nick": "Test User",
            "sender_staff_id": "staff1",
            "text": MagicMock(content="Hello"),
            "session_webhook": "https://api.dingtalk.com/robot/sendBySession?session=test",
            "session_webhook_expired_time": 999999999999,
            # Current UTC time expressed in milliseconds.
            "create_at": int(datetime.now(tz=timezone.utc).timestamp() * 1000),
            "at_users": [],
        }
        msg = MagicMock()
        for attr, value in fields.items():
            setattr(msg, attr, value)
        return msg

    @pytest.mark.asyncio
    async def test_send_uses_ai_card_if_configured(self, config, mock_stream_client, mock_http_client, mock_message):
        """send() calls create, deliver and streaming-update on the card SDK once each."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(config)
        adapter._stream_client = mock_stream_client
        adapter._http_client = mock_http_client
        adapter._message_contexts["test_conv_id"] = mock_message
        webhook = "https://api.dingtalk.com/robot/sendBySession?session=test"
        adapter._session_webhooks = {"test_conv_id": (webhook, 9999999999999)}
        adapter._card_template_id = "test_card_template"

        # Mock the card SDK with proper async methods
        sdk_methods = (
            "create_card_with_options_async",
            "deliver_card_with_options_async",
            "streaming_update_with_options_async",
        )
        card_sdk = MagicMock()
        for method_name in sdk_methods:
            setattr(card_sdk, method_name, AsyncMock())
        adapter._card_sdk = card_sdk

        # Mock access token
        adapter._get_access_token = AsyncMock(return_value="test_token")

        result = await adapter.send("test_conv_id", "Hello World")

        for method_name in sdk_methods:
            getattr(card_sdk, method_name).assert_called_once()
        assert result.success is True
|
||||
|
|
|
|||
|
|
@ -157,3 +157,232 @@ async def test_send_does_not_retry_on_unrelated_errors():
|
|||
# Only the first attempt happens — no reference-retry replay.
|
||||
assert channel.send.await_count == 1
|
||||
assert send_calls[0]["reference"] is reference_obj
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Forum channel tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import discord as _discord_mod # noqa: E402 — imported after _ensure_discord_mock
|
||||
|
||||
|
||||
class TestIsForumParent:
    """DiscordAdapter._is_forum_parent for None, forum objects and numeric channel types."""

    def test_none_returns_false(self):
        """None is not a forum parent."""
        adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
        verdict = adapter._is_forum_parent(None)
        assert verdict is False

    def test_forum_channel_class_instance(self):
        """An instance of discord.ForumChannel is recognised as a forum parent."""
        adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
        if getattr(_discord_mod, "ForumChannel", None) is None:
            # Re-create a type for the mock
            _discord_mod.ForumChannel = type("ForumChannel", (), {})
        forum_cls = _discord_mod.ForumChannel
        channel = forum_cls()
        assert adapter._is_forum_parent(channel) is True

    def test_type_value_15(self):
        """A channel whose type is 15 counts as a forum parent."""
        adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
        channel = SimpleNamespace(type=15)
        assert adapter._is_forum_parent(channel) is True

    def test_regular_channel_returns_false(self):
        """A channel whose type is 0 is not a forum parent."""
        adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
        channel = SimpleNamespace(type=0)
        assert adapter._is_forum_parent(channel) is False

    def test_thread_returns_false(self):
        """A channel whose type is 11 (public thread) is not a forum parent."""
        adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
        channel = SimpleNamespace(type=11)  # public thread
        assert adapter._is_forum_parent(channel) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_to_forum_creates_thread_post():
    """send() to a forum channel awaits create_thread once and returns the starter message id."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    # thread object has no 'send' so _send_to_forum uses thread.thread
    inner_thread = SimpleNamespace(
        id=555,
        send=AsyncMock(return_value=SimpleNamespace(id=600)),
    )
    created = SimpleNamespace(
        id=555,
        message=SimpleNamespace(id=500),
        thread=inner_thread,
    )

    forum = _discord_mod.ForumChannel()
    forum.id = 999
    forum.name = "ideas"
    forum.create_thread = AsyncMock(return_value=created)
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: forum,
        fetch_channel=AsyncMock(),
    )

    outcome = await adapter.send("999", "Hello forum!")

    assert outcome.success is True
    assert outcome.message_id == "500"
    forum.create_thread.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_to_forum_sends_remaining_chunks():
    """A message longer than MAX_MESSAGE_LENGTH triggers at least one follow-up send in the thread."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    # Force a small max message length so the message splits
    adapter.MAX_MESSAGE_LENGTH = 20

    starter_msg = SimpleNamespace(id=500)
    follow_up_msg = SimpleNamespace(id=501)
    inner_thread = SimpleNamespace(id=555, send=AsyncMock(return_value=follow_up_msg))
    # thread object has no 'send' so _send_to_forum uses thread.thread
    created = SimpleNamespace(id=555, message=starter_msg, thread=inner_thread)

    forum = _discord_mod.ForumChannel()
    forum.id = 999
    forum.name = "ideas"
    forum.create_thread = AsyncMock(return_value=created)
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: forum,
        fetch_channel=AsyncMock(),
    )

    outcome = await adapter.send("999", "A" * 50)

    assert outcome.success is True
    assert outcome.message_id == "500"
    # Should have sent at least one follow-up chunk
    assert inner_thread.send.await_count >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_to_forum_create_thread_failure():
    """When create_thread raises, send() reports failure and surfaces the error text."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    forum = _discord_mod.ForumChannel()
    forum.id = 999
    forum.name = "ideas"
    forum.create_thread = AsyncMock(side_effect=Exception("rate limited"))
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: forum,
        fetch_channel=AsyncMock(),
    )

    outcome = await adapter.send("999", "Hello forum!")

    assert outcome.success is False
    assert "rate limited" in outcome.error
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Forum follow-up chunk failure reporting + media on forum paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_to_forum_follow_up_chunk_failures_collected_as_warnings():
    """Partial-send chunk failures surface in raw_response['warnings']."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    adapter.MAX_MESSAGE_LENGTH = 20

    starter_msg = SimpleNamespace(id=500)
    # Every follow-up chunk fails — we should collect a warning per failure
    failing_thread = SimpleNamespace(
        id=555,
        send=AsyncMock(side_effect=Exception("rate limited")),
    )
    created = SimpleNamespace(id=555, message=starter_msg, thread=failing_thread)

    forum = _discord_mod.ForumChannel()
    forum.id = 999
    forum.name = "ideas"
    forum.create_thread = AsyncMock(return_value=created)
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: forum,
        fetch_channel=AsyncMock(),
    )

    # Long enough to produce multiple chunks
    outcome = await adapter.send("999", "A" * 60)

    # Starter message (first chunk) was delivered via create_thread, so send is
    # successful overall — but follow-up chunks all failed and are reported.
    assert outcome.success is True
    assert outcome.message_id == "500"
    raw = outcome.raw_response or {}
    warnings = raw.get("warnings") or []
    assert len(warnings) >= 1
    assert all("rate limited" in warning for warning in warnings)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_forum_post_file_creates_thread_with_attachment():
    """_forum_post_file routes file-bearing sends to create_thread with file kwarg."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    inner_thread = SimpleNamespace(id=777, send=AsyncMock())
    created = SimpleNamespace(id=777, message=SimpleNamespace(id=800), thread=inner_thread)
    forum = _discord_mod.ForumChannel()
    forum.id = 999
    forum.name = "ideas"
    forum.create_thread = AsyncMock(return_value=created)

    # discord.File is a real class; a SimpleNamespace carrying only a
    # filename stands in for it here.
    attachment = SimpleNamespace(filename="photo.png")

    outcome = await adapter._forum_post_file(
        forum,
        content="here is a photo",
        file=attachment,
    )

    assert outcome.success is True
    assert outcome.message_id == "800"
    forum.create_thread.assert_awaited_once()
    sent_kwargs = forum.create_thread.await_args.kwargs
    assert sent_kwargs["file"] is attachment
    assert sent_kwargs["content"] == "here is a photo"
    # Thread name derived from content's first line
    assert sent_kwargs["name"] == "here is a photo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_forum_post_file_uses_filename_when_no_content():
    """Thread name falls back to file.filename when no content is provided."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    inner_thread = SimpleNamespace(id=1, send=AsyncMock())
    created = SimpleNamespace(id=1, message=SimpleNamespace(id=2), thread=inner_thread)
    forum = _discord_mod.ForumChannel()
    forum.id = 10
    forum.name = "forum"
    forum.create_thread = AsyncMock(return_value=created)

    attachment = SimpleNamespace(filename="voice-message.ogg")
    outcome = await adapter._forum_post_file(forum, content="", file=attachment)

    assert outcome.success is True
    sent_kwargs = forum.create_thread.await_args.kwargs
    # Content was empty → thread name derived from filename
    assert sent_kwargs["name"] == "voice-message.ogg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_forum_post_file_creation_failure():
    """_forum_post_file returns a failed SendResult when create_thread raises."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    forum = _discord_mod.ForumChannel()
    forum.id = 999
    forum.create_thread = AsyncMock(side_effect=Exception("missing perms"))

    attachment = SimpleNamespace(filename="x.png")
    outcome = await adapter._forum_post_file(forum, content="hi", file=attachment)

    assert outcome.success is False
    # error may be None, so fall back to an empty string before searching it.
    error_text = outcome.error or ""
    assert "missing perms" in error_text
|
||||
|
|
|
|||
|
|
@ -601,6 +601,10 @@ class TestAdapterBehavior(unittest.TestCase):
|
|||
calls.append("message_recalled")
|
||||
return self
|
||||
|
||||
def register_p2_customized_event(self, event_key, _handler):
|
||||
calls.append(f"customized:{event_key}")
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
calls.append("build")
|
||||
return "handler"
|
||||
|
|
@ -628,6 +632,7 @@ class TestAdapterBehavior(unittest.TestCase):
|
|||
"bot_deleted",
|
||||
"p2p_chat_entered",
|
||||
"message_recalled",
|
||||
"customized:drive.notice.comment_add_v1",
|
||||
"build",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
261
tests/gateway/test_feishu_comment.py
Normal file
261
tests/gateway/test_feishu_comment.py
Normal file
|
|
@ -0,0 +1,261 @@
|
|||
"""Tests for feishu_comment — event filtering, access control integration, wiki reverse lookup."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from gateway.platforms.feishu_comment import (
|
||||
parse_drive_comment_event,
|
||||
_ALLOWED_NOTICE_TYPES,
|
||||
_sanitize_comment_text,
|
||||
)
|
||||
|
||||
|
||||
def _make_event(
    comment_id="c1",
    reply_id="r1",
    notice_type="add_reply",
    file_token="docx_token",
    file_type="docx",
    from_open_id="ou_user",
    to_open_id="ou_bot",
    is_mentioned=True,
):
    """Build a minimal drive comment event SimpleNamespace.

    The returned object has a single ``event`` attribute: a dict with fixed
    ``event_id``/``timestamp`` values, the comment fields passed in, and a
    nested ``notice_meta`` dict holding the file and sender/receiver data.
    """
    notice_meta = {
        "file_token": file_token,
        "file_type": file_type,
        "notice_type": notice_type,
        "from_user_id": {"open_id": from_open_id},
        "to_user_id": {"open_id": to_open_id},
    }
    payload = {
        "event_id": "evt_1",
        "comment_id": comment_id,
        "reply_id": reply_id,
        "is_mentioned": is_mentioned,
        "timestamp": "1713200000",
        "notice_meta": notice_meta,
    }
    return SimpleNamespace(event=payload)
|
||||
|
||||
|
||||
class TestParseEvent(unittest.TestCase):
    """parse_drive_comment_event on a well-formed event and on objects lacking ``event``."""

    def test_parse_valid_event(self):
        """A default event parses into a mapping exposing the key fields."""
        parsed = parse_drive_comment_event(_make_event())
        self.assertIsNotNone(parsed)
        expected_fields = (
            ("comment_id", "c1"),
            ("file_type", "docx"),
            ("from_open_id", "ou_user"),
            ("to_open_id", "ou_bot"),
        )
        for key, value in expected_fields:
            self.assertEqual(parsed[key], value)

    def test_parse_missing_event_attr(self):
        """A bare object() has no ``event`` attribute and parses to None."""
        bare = object()
        self.assertIsNone(parse_drive_comment_event(bare))

    def test_parse_none_event(self):
        """An empty SimpleNamespace (no ``event`` attribute set) parses to None."""
        empty = SimpleNamespace()
        self.assertIsNone(parse_drive_comment_event(empty))
|
||||
|
||||
|
||||
class TestEventFiltering(unittest.TestCase):
    """Test the filtering logic in handle_drive_comment_event."""

    def _run(self, coro):
        """Run *coro* to completion on a private event loop and return its result.

        A fresh loop is created and closed per call instead of relying on
        ``asyncio.get_event_loop()``, which is deprecated (and an error on
        newer Python versions) when no current loop exists and breaks if an
        earlier test closed or unset the current loop. The global
        current-loop setting is deliberately left untouched.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    @patch("gateway.platforms.feishu_comment_rules.load_config")
    @patch("gateway.platforms.feishu_comment_rules.resolve_rule")
    @patch("gateway.platforms.feishu_comment_rules.is_user_allowed")
    def test_self_reply_filtered(self, mock_allowed, mock_resolve, mock_load):
        """Events where from_open_id == self_open_id should be dropped."""
        from gateway.platforms.feishu_comment import handle_drive_comment_event

        evt = _make_event(from_open_id="ou_bot", to_open_id="ou_bot")
        self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
        # Dropped events must never reach the rules configuration.
        mock_load.assert_not_called()

    @patch("gateway.platforms.feishu_comment_rules.load_config")
    @patch("gateway.platforms.feishu_comment_rules.resolve_rule")
    @patch("gateway.platforms.feishu_comment_rules.is_user_allowed")
    def test_wrong_receiver_filtered(self, mock_allowed, mock_resolve, mock_load):
        """Events where to_open_id != self_open_id should be dropped."""
        from gateway.platforms.feishu_comment import handle_drive_comment_event

        evt = _make_event(to_open_id="ou_other_bot")
        self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
        mock_load.assert_not_called()

    @patch("gateway.platforms.feishu_comment_rules.load_config")
    @patch("gateway.platforms.feishu_comment_rules.resolve_rule")
    @patch("gateway.platforms.feishu_comment_rules.is_user_allowed")
    def test_empty_to_open_id_filtered(self, mock_allowed, mock_resolve, mock_load):
        """Events with empty to_open_id should be dropped."""
        from gateway.platforms.feishu_comment import handle_drive_comment_event

        evt = _make_event(to_open_id="")
        self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
        mock_load.assert_not_called()

    @patch("gateway.platforms.feishu_comment_rules.load_config")
    @patch("gateway.platforms.feishu_comment_rules.resolve_rule")
    @patch("gateway.platforms.feishu_comment_rules.is_user_allowed")
    def test_invalid_notice_type_filtered(self, mock_allowed, mock_resolve, mock_load):
        """Events with unsupported notice_type should be dropped."""
        from gateway.platforms.feishu_comment import handle_drive_comment_event

        evt = _make_event(notice_type="resolve_comment")
        self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
        mock_load.assert_not_called()

    def test_allowed_notice_types(self):
        """Only add_comment and add_reply are in the allowed notice types."""
        self.assertIn("add_comment", _ALLOWED_NOTICE_TYPES)
        self.assertIn("add_reply", _ALLOWED_NOTICE_TYPES)
        self.assertNotIn("resolve_comment", _ALLOWED_NOTICE_TYPES)
|
||||
|
||||
|
||||
class TestAccessControlIntegration(unittest.TestCase):
    """handle_drive_comment_event with the access-control rule functions patched."""

    def _run(self, coro):
        """Run *coro* to completion on a private event loop and return its result.

        A fresh loop is created and closed per call instead of relying on
        ``asyncio.get_event_loop()``, which is deprecated (and an error on
        newer Python versions) when no current loop exists and breaks if an
        earlier test closed or unset the current loop. The global
        current-loop setting is deliberately left untouched.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    @patch("gateway.platforms.feishu_comment_rules.has_wiki_keys", return_value=False)
    @patch("gateway.platforms.feishu_comment_rules.is_user_allowed", return_value=False)
    @patch("gateway.platforms.feishu_comment_rules.resolve_rule")
    @patch("gateway.platforms.feishu_comment_rules.load_config")
    def test_denied_user_no_side_effects(self, mock_load, mock_resolve, mock_allowed, mock_wiki_keys):
        """Denied user should not trigger typing reaction or agent."""
        from gateway.platforms.feishu_comment import handle_drive_comment_event
        from gateway.platforms.feishu_comment_rules import ResolvedCommentRule

        # First positional field is True here and False in the disabled test
        # (presumably the rule's "enabled" flag — confirm against ResolvedCommentRule).
        mock_resolve.return_value = ResolvedCommentRule(True, "allowlist", frozenset(), "top")
        mock_load.return_value = Mock()

        client = Mock()
        evt = _make_event()
        self._run(handle_drive_comment_event(client, evt, self_open_id="ou_bot"))

        # No API calls should be made for denied users
        client.request.assert_not_called()

    @patch("gateway.platforms.feishu_comment_rules.has_wiki_keys", return_value=False)
    @patch("gateway.platforms.feishu_comment_rules.is_user_allowed", return_value=False)
    @patch("gateway.platforms.feishu_comment_rules.resolve_rule")
    @patch("gateway.platforms.feishu_comment_rules.load_config")
    def test_disabled_comment_skipped(self, mock_load, mock_resolve, mock_allowed, mock_wiki_keys):
        """Disabled comments should return immediately."""
        from gateway.platforms.feishu_comment import handle_drive_comment_event
        from gateway.platforms.feishu_comment_rules import ResolvedCommentRule

        mock_resolve.return_value = ResolvedCommentRule(False, "allowlist", frozenset(), "top")
        mock_load.return_value = Mock()

        evt = _make_event()
        self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
        # The user check must not even be consulted for a disabled rule.
        mock_allowed.assert_not_called()
|
||||
|
||||
|
||||
class TestSanitizeCommentText(unittest.TestCase):
    """_sanitize_comment_text is expected to entity-escape ``&``, ``<`` and ``>``.

    NOTE(review): the expected strings below assume HTML-entity output
    (``&amp;``, ``&lt;``, ``&gt;``) as implied by the test names — confirm
    against the implementation of _sanitize_comment_text.
    """

    def test_angle_brackets_escaped(self):
        """Angle brackets are replaced by their entities."""
        self.assertEqual(_sanitize_comment_text("List<String>"), "List&lt;String&gt;")

    def test_ampersand_escaped_first(self):
        """A bare ampersand becomes ``&amp;``."""
        self.assertEqual(_sanitize_comment_text("a & b"), "a &amp; b")

    def test_ampersand_not_double_escaped(self):
        """Entities produced for < and > must not have their own & escaped again."""
        result = _sanitize_comment_text("a < b & c > d")
        self.assertEqual(result, "a &lt; b &amp; c &gt; d")
        self.assertNotIn("&amp;lt;", result)
        self.assertNotIn("&amp;gt;", result)

    def test_plain_text_unchanged(self):
        """Text without special characters passes through untouched."""
        self.assertEqual(_sanitize_comment_text("hello world"), "hello world")

    def test_empty_string(self):
        """The empty string stays empty."""
        self.assertEqual(_sanitize_comment_text(""), "")

    def test_code_snippet(self):
        """Raw angle brackets disappear from a code snippet; their entities appear."""
        text = 'if (a < b && c > 0) { return "ok"; }'
        result = _sanitize_comment_text(text)
        self.assertNotIn("<", result)
        self.assertNotIn(">", result)
        self.assertIn("&lt;", result)
        self.assertIn("&gt;", result)
|
||||
|
||||
|
||||
class TestWikiReverseLookup(unittest.TestCase):
    """_reverse_lookup_wiki_token results and its use inside handle_drive_comment_event."""

    def _run(self, coro):
        """Run *coro* to completion on a private event loop and return its result.

        A fresh loop is created and closed per call instead of relying on
        ``asyncio.get_event_loop()``, which is deprecated (and an error on
        newer Python versions) when no current loop exists and breaks if an
        earlier test closed or unset the current loop. The global
        current-loop setting is deliberately left untouched.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    @patch("gateway.platforms.feishu_comment._exec_request")
    def test_reverse_lookup_success(self, mock_exec):
        """A code-0 response yields the node_token and the query carries token/obj_type."""
        from gateway.platforms.feishu_comment import _reverse_lookup_wiki_token

        mock_exec.return_value = (0, "Success", {
            "node": {"node_token": "WIKI_TOKEN_123", "obj_token": "docx_abc"},
        })
        result = self._run(_reverse_lookup_wiki_token(Mock(), "docx", "docx_abc"))
        self.assertEqual(result, "WIKI_TOKEN_123")
        # Verify correct API params
        call_args = mock_exec.call_args
        # Accept queries passed either by keyword or as the fourth positional argument.
        queries = call_args[1].get("queries") or call_args[0][3]
        query_dict = dict(queries)
        self.assertEqual(query_dict["token"], "docx_abc")
        self.assertEqual(query_dict["obj_type"], "docx")

    @patch("gateway.platforms.feishu_comment._exec_request")
    def test_reverse_lookup_not_wiki(self, mock_exec):
        """A 131001 "not found" response yields None."""
        from gateway.platforms.feishu_comment import _reverse_lookup_wiki_token

        mock_exec.return_value = (131001, "not found", {})
        result = self._run(_reverse_lookup_wiki_token(Mock(), "docx", "docx_abc"))
        self.assertIsNone(result)

    @patch("gateway.platforms.feishu_comment._exec_request")
    def test_reverse_lookup_service_error(self, mock_exec):
        """A 500 "internal error" response yields None."""
        from gateway.platforms.feishu_comment import _reverse_lookup_wiki_token

        mock_exec.return_value = (500, "internal error", {})
        result = self._run(_reverse_lookup_wiki_token(Mock(), "docx", "docx_abc"))
        self.assertIsNone(result)

    @patch("gateway.platforms.feishu_comment._reverse_lookup_wiki_token", new_callable=AsyncMock)
    @patch("gateway.platforms.feishu_comment_rules.has_wiki_keys", return_value=True)
    @patch("gateway.platforms.feishu_comment_rules.is_user_allowed", return_value=True)
    @patch("gateway.platforms.feishu_comment_rules.resolve_rule")
    @patch("gateway.platforms.feishu_comment_rules.load_config")
    @patch("gateway.platforms.feishu_comment.add_comment_reaction", new_callable=AsyncMock)
    @patch("gateway.platforms.feishu_comment.batch_query_comment", new_callable=AsyncMock)
    @patch("gateway.platforms.feishu_comment.query_document_meta", new_callable=AsyncMock)
    def test_wiki_lookup_triggered_when_no_exact_match(
        self, mock_meta, mock_batch, mock_reaction,
        mock_load, mock_resolve, mock_allowed, mock_wiki_keys, mock_lookup,
    ):
        """Wiki reverse lookup should fire when rule falls to wildcard/top and wiki keys exist."""
        from gateway.platforms.feishu_comment import handle_drive_comment_event
        from gateway.platforms.feishu_comment_rules import ResolvedCommentRule

        # First resolve returns wildcard (no exact match), second returns exact wiki match
        mock_resolve.side_effect = [
            ResolvedCommentRule(True, "allowlist", frozenset(), "wildcard"),
            ResolvedCommentRule(True, "allowlist", frozenset(), "exact:wiki:WIKI123"),
        ]
        mock_load.return_value = Mock()
        mock_lookup.return_value = "WIKI123"
        mock_meta.return_value = {"title": "Test", "url": ""}
        mock_batch.return_value = {"is_whole": False, "quote": ""}

        evt = _make_event()
        # Will proceed past access control but fail later — that's OK, we just test the lookup
        try:
            self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
        except Exception:
            # Deliberate best-effort: only the lookup and rule calls are asserted below.
            pass

        mock_lookup.assert_called_once_with(unittest.mock.ANY, "docx", "docx_token")
        self.assertEqual(mock_resolve.call_count, 2)
        # Second call should include wiki_token
        second_call_kwargs = mock_resolve.call_args_list[1]
        self.assertEqual(second_call_kwargs[1].get("wiki_token") or second_call_kwargs[0][3], "WIKI123")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
320
tests/gateway/test_feishu_comment_rules.py
Normal file
320
tests/gateway/test_feishu_comment_rules.py
Normal file
|
|
@ -0,0 +1,320 @@
|
|||
"""Tests for feishu_comment_rules — 3-tier access control rule engine."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from gateway.platforms.feishu_comment_rules import (
|
||||
CommentsConfig,
|
||||
CommentDocumentRule,
|
||||
ResolvedCommentRule,
|
||||
_MtimeCache,
|
||||
_parse_document_rule,
|
||||
has_wiki_keys,
|
||||
is_user_allowed,
|
||||
load_config,
|
||||
pairing_add,
|
||||
pairing_list,
|
||||
pairing_remove,
|
||||
resolve_rule,
|
||||
)
|
||||
|
||||
|
||||
class TestCommentDocumentRuleParsing(unittest.TestCase):
    """``_parse_document_rule``: raw mapping in, rule object out; absent fields read as None."""

    def test_parse_full_rule(self):
        """All three fields are taken over; ``allow_from`` comes back as a frozenset."""
        raw = {
            "enabled": False,
            "policy": "allowlist",
            "allow_from": ["ou_a", "ou_b"],
        }
        parsed = _parse_document_rule(raw)
        self.assertFalse(parsed.enabled)
        self.assertEqual(parsed.policy, "allowlist")
        self.assertEqual(parsed.allow_from, frozenset(["ou_a", "ou_b"]))

    def test_parse_partial_rule(self):
        """Only ``policy`` given: the other two fields stay None."""
        parsed = _parse_document_rule({"policy": "allowlist"})
        self.assertIsNone(parsed.enabled)
        self.assertEqual(parsed.policy, "allowlist")
        self.assertIsNone(parsed.allow_from)

    def test_parse_empty_rule(self):
        """An empty mapping yields a rule with every field None."""
        parsed = _parse_document_rule({})
        for value in (parsed.enabled, parsed.policy, parsed.allow_from):
            self.assertIsNone(value)

    def test_invalid_policy_ignored(self):
        """An unrecognised policy string is dropped rather than stored."""
        parsed = _parse_document_rule({"policy": "invalid_value"})
        self.assertIsNone(parsed.policy)
|
||||
|
||||
|
||||
class TestResolveRule(unittest.TestCase):
    """``resolve_rule`` precedence: exact document key, wiki key, ``"*"``, then top level."""

    def test_exact_match(self):
        """A ``docx:abc`` entry wins for that document and is reported as exact."""
        documents = {"docx:abc": CommentDocumentRule(policy="allowlist")}
        cfg = CommentsConfig(
            policy="pairing",
            allow_from=frozenset(["ou_top"]),
            documents=documents,
        )
        resolved = resolve_rule(cfg, "docx", "abc")
        self.assertEqual(resolved.policy, "allowlist")
        self.assertTrue(resolved.match_source.startswith("exact:"))

    def test_wildcard_match(self):
        """With no exact key, the ``"*"`` entry supplies the policy."""
        documents = {"*": CommentDocumentRule(policy="allowlist")}
        cfg = CommentsConfig(policy="pairing", documents=documents)
        resolved = resolve_rule(cfg, "docx", "unknown")
        self.assertEqual(resolved.policy, "allowlist")
        self.assertEqual(resolved.match_source, "wildcard")

    def test_top_level_fallback(self):
        """Without any document entries the top-level values are used."""
        top_allow = frozenset(["ou_top"])
        cfg = CommentsConfig(policy="pairing", allow_from=top_allow)
        resolved = resolve_rule(cfg, "docx", "whatever")
        self.assertEqual(resolved.policy, "pairing")
        self.assertEqual(resolved.allow_from, frozenset(["ou_top"]))
        self.assertEqual(resolved.match_source, "top")

    def test_exact_overrides_wildcard(self):
        """When both ``"*"`` and the exact key set a policy, the exact one is kept."""
        documents = {
            "*": CommentDocumentRule(policy="pairing"),
            "docx:abc": CommentDocumentRule(policy="allowlist"),
        }
        cfg = CommentsConfig(policy="pairing", documents=documents)
        resolved = resolve_rule(cfg, "docx", "abc")
        self.assertEqual(resolved.policy, "allowlist")
        self.assertTrue(resolved.match_source.startswith("exact:"))

    def test_field_by_field_fallback(self):
        """Exact sets policy, wildcard sets allow_from, enabled from top."""
        documents = {
            "*": CommentDocumentRule(allow_from=frozenset(["ou_wildcard"])),
            "docx:abc": CommentDocumentRule(policy="allowlist"),
        }
        cfg = CommentsConfig(
            enabled=True,
            policy="pairing",
            allow_from=frozenset(["ou_top"]),
            documents=documents,
        )
        resolved = resolve_rule(cfg, "docx", "abc")
        self.assertEqual(resolved.policy, "allowlist")
        self.assertEqual(resolved.allow_from, frozenset(["ou_wildcard"]))
        self.assertTrue(resolved.enabled)

    def test_explicit_empty_allow_from_does_not_fall_through(self):
        """allow_from=[] on exact should NOT inherit from wildcard or top."""
        exact_rule = CommentDocumentRule(policy="allowlist", allow_from=frozenset())
        documents = {
            "*": CommentDocumentRule(allow_from=frozenset(["ou_wildcard"])),
            "docx:abc": exact_rule,
        }
        cfg = CommentsConfig(allow_from=frozenset(["ou_top"]), documents=documents)
        resolved = resolve_rule(cfg, "docx", "abc")
        self.assertEqual(resolved.allow_from, frozenset())

    def test_wiki_token_match(self):
        """A ``wiki:<token>`` entry is matched through the ``wiki_token`` argument."""
        documents = {"wiki:WIKI123": CommentDocumentRule(policy="allowlist")}
        cfg = CommentsConfig(policy="pairing", documents=documents)
        resolved = resolve_rule(cfg, "docx", "obj_token", wiki_token="WIKI123")
        self.assertEqual(resolved.policy, "allowlist")
        self.assertTrue(resolved.match_source.startswith("exact:wiki:"))

    def test_exact_takes_priority_over_wiki(self):
        """If both the document key and the wiki key exist, the document key wins."""
        documents = {
            "docx:abc": CommentDocumentRule(policy="allowlist"),
            "wiki:WIKI123": CommentDocumentRule(policy="pairing"),
        }
        cfg = CommentsConfig(documents=documents)
        resolved = resolve_rule(cfg, "docx", "abc", wiki_token="WIKI123")
        self.assertEqual(resolved.policy, "allowlist")
        self.assertTrue(resolved.match_source.startswith("exact:docx:"))

    def test_default_config(self):
        """A default ``CommentsConfig`` resolves to enabled / pairing / empty allow list."""
        resolved = resolve_rule(CommentsConfig(), "docx", "anything")
        self.assertTrue(resolved.enabled)
        self.assertEqual(resolved.policy, "pairing")
        self.assertEqual(resolved.allow_from, frozenset())
|
||||
|
||||
|
||||
class TestHasWikiKeys(unittest.TestCase):
    """``has_wiki_keys`` reports whether any document key is a ``wiki:`` key."""

    def test_no_wiki_keys(self):
        """Only a docx key and the wildcard are present, so the answer is False."""
        documents = {
            "docx:abc": CommentDocumentRule(policy="allowlist"),
            "*": CommentDocumentRule(policy="pairing"),
        }
        cfg = CommentsConfig(documents=documents)
        self.assertFalse(has_wiki_keys(cfg))

    def test_has_wiki_keys(self):
        """A single ``wiki:WIKI123`` entry is enough for True."""
        documents = {"wiki:WIKI123": CommentDocumentRule(policy="allowlist")}
        cfg = CommentsConfig(documents=documents)
        self.assertTrue(has_wiki_keys(cfg))

    def test_empty_documents(self):
        """A default config (no documents given) has no wiki keys."""
        self.assertFalse(has_wiki_keys(CommentsConfig()))
|
||||
|
||||
|
||||
class TestIsUserAllowed(unittest.TestCase):
    """``is_user_allowed`` under the ``allowlist`` and ``pairing`` policies."""

    def test_allowlist_allows_listed(self):
        """A user present in ``allow_from`` passes an allowlist rule."""
        allowlist_rule = ResolvedCommentRule(True, "allowlist", frozenset(["ou_a"]), "top")
        self.assertTrue(is_user_allowed(allowlist_rule, "ou_a"))

    def test_allowlist_denies_unlisted(self):
        """A user absent from ``allow_from`` is rejected by an allowlist rule."""
        allowlist_rule = ResolvedCommentRule(True, "allowlist", frozenset(["ou_a"]), "top")
        self.assertFalse(is_user_allowed(allowlist_rule, "ou_b"))

    def test_allowlist_empty_denies_all(self):
        """An allowlist rule with an empty ``allow_from`` admits nobody."""
        empty_rule = ResolvedCommentRule(True, "allowlist", frozenset(), "top")
        self.assertFalse(is_user_allowed(empty_rule, "ou_anyone"))

    def test_pairing_allows_in_allow_from(self):
        """Under pairing, membership in ``allow_from`` is sufficient."""
        pairing_rule = ResolvedCommentRule(True, "pairing", frozenset(["ou_a"]), "top")
        self.assertTrue(is_user_allowed(pairing_rule, "ou_a"))

    def test_pairing_checks_store(self):
        """Under pairing with an empty ``allow_from``, the approved-pairing set decides."""
        pairing_rule = ResolvedCommentRule(True, "pairing", frozenset(), "top")
        approved_stub = patch(
            "gateway.platforms.feishu_comment_rules._load_pairing_approved",
            return_value={"ou_approved"},
        )
        with approved_stub:
            self.assertTrue(is_user_allowed(pairing_rule, "ou_approved"))
            self.assertFalse(is_user_allowed(pairing_rule, "ou_unknown"))
|
||||
|
||||
|
||||
class TestMtimeCache(unittest.TestCase):
    """``_MtimeCache.load`` against a missing file, an unchanged file and a rewritten file."""

    @staticmethod
    def _write_temp_json(payload):
        """Dump *payload* as JSON into a new temporary ``.json`` file and return its path.

        The file is created with ``delete=False`` so it survives the ``with``
        block; the calling test is responsible for unlinking it.
        """
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
            json.dump(payload, handle)
        return Path(handle.name)

    def test_returns_empty_dict_for_missing_file(self):
        """Loading from a path that does not exist yields an empty dict."""
        cache = _MtimeCache(Path("/nonexistent/path.json"))
        self.assertEqual(cache.load(), {})

    def test_reads_file_and_caches(self):
        """Two consecutive loads of an untouched file return the same content."""
        path = self._write_temp_json({"key": "value"})
        try:
            cache = _MtimeCache(path)
            self.assertEqual(cache.load(), {"key": "value"})
            # File untouched in between: the second load is expected to be
            # answered from the cache (same mtime).
            self.assertEqual(cache.load(), {"key": "value"})
        finally:
            path.unlink()

    def test_reloads_on_mtime_change(self):
        """After the file is rewritten with a newer mtime, ``load`` returns the new content."""
        path = self._write_temp_json({"v": 1})
        try:
            cache = _MtimeCache(path)
            self.assertEqual(cache.load(), {"v": 1})

            with open(path, "w") as rewritten:
                json.dump({"v": 2}, rewritten)

            # Make the change visible regardless of filesystem timestamp
            # granularity: move the mtime one second past what the rewrite
            # produced.  This is relative to the file itself, so no sleep and
            # no dependence on the wall clock is needed.
            bumped = path.stat().st_mtime + 1
            os.utime(path, (bumped, bumped))

            self.assertEqual(cache.load(), {"v": 2})
        finally:
            path.unlink()
|
||||
|
||||
|
||||
class TestLoadConfig(unittest.TestCase):
    """``load_config`` with a populated rules file and with no rules file at all."""

    def test_load_with_documents(self):
        """Top-level fields and both document entries are read from the rules file."""
        raw_rules = {
            "enabled": True,
            "policy": "allowlist",
            "allow_from": ["ou_a"],
            "documents": {
                "*": {"policy": "pairing"},
                "docx:abc": {"policy": "allowlist", "allow_from": ["ou_b"]},
            },
        }
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
            json.dump(raw_rules, handle)
        rules_path = Path(handle.name)

        try:
            # Point both the module-level path and its cache at the temp file.
            with patch("gateway.platforms.feishu_comment_rules.RULES_FILE", rules_path), \
                    patch("gateway.platforms.feishu_comment_rules._rules_cache", _MtimeCache(rules_path)):
                cfg = load_config()

            self.assertTrue(cfg.enabled)
            self.assertEqual(cfg.policy, "allowlist")
            self.assertEqual(cfg.allow_from, frozenset(["ou_a"]))
            for expected_key in ("*", "docx:abc"):
                self.assertIn(expected_key, cfg.documents)
            self.assertEqual(cfg.documents["docx:abc"].policy, "allowlist")
        finally:
            rules_path.unlink()

    def test_load_missing_file_returns_defaults(self):
        """With the cache pointed at a nonexistent path, the defaults come back."""
        missing_cache = _MtimeCache(Path("/nonexistent"))
        with patch("gateway.platforms.feishu_comment_rules._rules_cache", missing_cache):
            cfg = load_config()

        self.assertTrue(cfg.enabled)
        self.assertEqual(cfg.policy, "pairing")
        self.assertEqual(cfg.allow_from, frozenset())
        self.assertEqual(cfg.documents, {})
|
||||
|
||||
|
||||
class TestPairingStore(unittest.TestCase):
    """``pairing_add`` / ``pairing_list`` / ``pairing_remove`` against a throwaway store file."""

    def setUp(self):
        """Create an empty pairing file in a temp dir and redirect the module to it.

        Everything is registered with ``addCleanup`` so it is undone even when
        ``setUp`` itself fails part-way.  Cleanups run in reverse order: the
        patches are stopped first, then the whole directory is removed, which
        also works when the store leaves extra files next to ``pairing.json``.
        """
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self._tmpdir = tmp.name

        self._pairing_file = Path(self._tmpdir) / "pairing.json"
        with open(self._pairing_file, "w") as f:
            json.dump({"approved": {}}, f)

        self._patcher_file = patch(
            "gateway.platforms.feishu_comment_rules.PAIRING_FILE",
            self._pairing_file,
        )
        self._patcher_cache = patch(
            "gateway.platforms.feishu_comment_rules._pairing_cache",
            _MtimeCache(self._pairing_file),
        )
        for patcher in (self._patcher_file, self._patcher_cache):
            patcher.start()
            self.addCleanup(patcher.stop)

    def test_add_and_list(self):
        """A newly added id is reported as added and shows up in the listing."""
        self.assertTrue(pairing_add("ou_new"))
        approved = pairing_list()
        self.assertIn("ou_new", approved)

    def test_add_duplicate(self):
        """Adding the same id a second time returns False."""
        pairing_add("ou_a")
        self.assertFalse(pairing_add("ou_a"))

    def test_remove(self):
        """Removing a present id returns True and drops it from the listing."""
        pairing_add("ou_a")
        self.assertTrue(pairing_remove("ou_a"))
        self.assertNotIn("ou_a", pairing_list())

    def test_remove_nonexistent(self):
        """Removing an id that was never added returns False."""
        self.assertFalse(pairing_remove("ou_nobody"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -179,7 +179,7 @@ class TestVoiceAttachmentSSRFProtection:
|
|||
from gateway.platforms.qqbot import QQAdapter, _ssrf_redirect_guard
|
||||
|
||||
client = mock.AsyncMock()
|
||||
with mock.patch("gateway.platforms.qqbot.httpx.AsyncClient", return_value=client) as async_client_cls:
|
||||
with mock.patch("gateway.platforms.qqbot.adapter.httpx.AsyncClient", return_value=client) as async_client_cls:
|
||||
adapter = QQAdapter(_make_config(app_id="a", client_secret="b"))
|
||||
adapter._ensure_token = mock.AsyncMock(side_effect=RuntimeError("stop after client creation"))
|
||||
|
||||
|
|
|
|||
|
|
@ -202,3 +202,120 @@ async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_p
|
|||
|
||||
assert ok is True
|
||||
assert calls == [(42, False), (42, True)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_gateway_replace_writes_takeover_marker_before_sigterm(
    monkeypatch, tmp_path
):
    """--replace must write a takeover marker BEFORE sending SIGTERM.

    The marker lets the target's shutdown handler identify the signal as a
    planned takeover (→ exit 0) rather than an unexpected kill (→ exit 1).
    Without the marker, PR #5646's signal-recovery path would revive the
    target via systemd Restart=on-failure, starting a flap loop.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    marker_file = tmp_path / ".gateway-takeover.json"

    # Record the ORDER of marker-write + terminate_pid calls
    events: list[str] = []
    # One entry per write_marker call: True when no marker file existed yet
    # at the moment the replacer asked for it to be written.
    marker_absent_before_write: list[bool] = []

    def record_write_marker(target_pid: int) -> bool:
        """Stand-in for write_takeover_marker that logs the call and writes a real file."""
        events.append(f"write_marker(target_pid={target_pid})")
        marker_absent_before_write.append(not marker_file.exists())
        # Actually write the marker so we can verify cleanup later
        from gateway.status import _get_takeover_marker_path, _write_json_file
        _write_json_file(_get_takeover_marker_path(), {
            "target_pid": target_pid,
            "target_start_time": 0,
            "replacer_pid": 100,
            "written_at": "2026-04-17T00:00:00+00:00",
        })
        return True

    def record_terminate(pid, force=False):
        """Stand-in for terminate_pid that only logs the call."""
        events.append(f"terminate_pid(pid={pid}, force={force})")

    def process_already_gone(pid, sig):
        """Stand-in for os.kill that reports the old process as no longer existing."""
        raise ProcessLookupError()

    class _CleanExitRunner:
        """Minimal GatewayRunner replacement whose start/stop succeed immediately."""

        def __init__(self, config):
            self.config = config
            self.should_exit_cleanly = True
            self.exit_reason = None
            self.adapters = {}

        async def start(self):
            return True

        async def stop(self):
            return None

    monkeypatch.setattr("gateway.status.get_running_pid", lambda: 42)
    monkeypatch.setattr("gateway.status.remove_pid_file", lambda: None)
    monkeypatch.setattr("gateway.status.release_all_scoped_locks", lambda: 0)
    monkeypatch.setattr("gateway.status.write_takeover_marker", record_write_marker)
    monkeypatch.setattr("gateway.status.terminate_pid", record_terminate)
    monkeypatch.setattr("gateway.run.os.getpid", lambda: 100)
    # Simulate old process exiting on first check so we don't loop into force-kill
    monkeypatch.setattr("gateway.run.os.kill", process_already_gone)
    monkeypatch.setattr("time.sleep", lambda _: None)
    monkeypatch.setattr("tools.skills_sync.sync_skills", lambda quiet=True: None)
    monkeypatch.setattr("hermes_logging.setup_logging", lambda hermes_home, mode: tmp_path)
    monkeypatch.setattr("hermes_logging._add_rotating_handler", lambda *args, **kwargs: None)
    monkeypatch.setattr("gateway.run.GatewayRunner", _CleanExitRunner)

    from gateway.run import start_gateway

    ok = await start_gateway(config=GatewayConfig(), replace=True, verbosity=None)

    assert ok is True
    # Ordering: marker written BEFORE SIGTERM
    assert events[0] == "write_marker(target_pid=42)"
    assert any(e.startswith("terminate_pid(pid=42") for e in events[1:])
    # The first marker write started from a clean slate (no stale marker file).
    assert marker_absent_before_write and marker_absent_before_write[0]
    # Marker file cleanup: replacer cleans it after loop completes
    assert not marker_file.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_gateway_replace_clears_marker_on_permission_denied(
    monkeypatch, tmp_path
):
    """A PermissionError from terminate_pid makes start_gateway return False
    and must not leave the takeover marker file behind."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    def _fake_write_marker(target_pid: int) -> bool:
        """Write a real marker file so its later removal can be observed."""
        from gateway.status import _get_takeover_marker_path, _write_json_file

        payload = {
            "target_pid": target_pid,
            "target_start_time": 0,
            "replacer_pid": 100,
            "written_at": "2026-04-17T00:00:00+00:00",
        }
        _write_json_file(_get_takeover_marker_path(), payload)
        return True

    def _deny_terminate(pid, force=False):
        """Stand-in for terminate_pid that always fails with EPERM."""
        raise PermissionError("simulated EPERM")

    replacements = (
        ("gateway.status.get_running_pid", lambda: 42),
        ("gateway.status.write_takeover_marker", _fake_write_marker),
        ("gateway.status.terminate_pid", _deny_terminate),
        ("gateway.run.os.getpid", lambda: 100),
        ("tools.skills_sync.sync_skills", lambda quiet=True: None),
        ("hermes_logging.setup_logging", lambda hermes_home, mode: tmp_path),
        ("hermes_logging._add_rotating_handler", lambda *args, **kwargs: None),
    )
    for dotted_target, stand_in in replacements:
        monkeypatch.setattr(dotted_target, stand_in)

    from gateway.run import start_gateway

    # The permission error is expected to surface as a False return value.
    started = await start_gateway(config=GatewayConfig(), replace=True, verbosity=None)

    assert started is False
    # No stale marker may remain after the failed takeover.
    leftover_marker = tmp_path / ".gateway-takeover.json"
    assert not leftover_marker.exists()
|
||||
|
|
|
|||
|
|
@ -288,6 +288,38 @@ async def test_command_messages_do_not_leave_sentinel():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("command_text", "handler_attr", "handler_result"),
    [
        ("/help", "_handle_help_command", "Help text"),
        ("/commands", "_handle_commands_command", "Commands text"),
        ("/update", "_handle_update_command", "Update text"),
        ("/profile", "_handle_profile_command", "Profile text"),
    ],
)
async def test_active_session_bypass_commands_dispatch_without_interrupt(
    command_text,
    handler_attr,
    handler_result,
):
    """With an agent registered as running, each listed command is answered by
    its own handler: the handler's result is returned, the agent is not
    interrupted, and nothing is queued in the adapter's pending messages."""
    runner = _make_runner()
    event = _make_event(text=command_text)
    session_key = build_session_key(event.source)

    # Register a busy agent for this session so the active-session path is taken.
    busy_agent = MagicMock(
        **{"get_activity_summary.return_value": {"seconds_since_activity": 0}}
    )
    runner._running_agents[session_key] = busy_agent

    # Replace the command's handler with a stub returning a known value.
    stub_handler = AsyncMock(return_value=handler_result)
    setattr(runner, handler_attr, stub_handler)

    reply = await runner._handle_message(event)

    assert reply == handler_result
    busy_agent.interrupt.assert_not_called()
    pending = runner.adapters[Platform.TELEGRAM]._pending_messages
    assert session_key not in pending
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 6: /stop during sentinel force-cleans and unlocks session
|
||||
# ------------------------------------------------------------------
|
||||
|
|
|
|||
231
tests/gateway/test_session_state_cleanup.py
Normal file
231
tests/gateway/test_session_state_cleanup.py
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
"""Regression tests for _release_running_agent_state and SessionDB shutdown.
|
||||
|
||||
Before this change, running-agent state lived in three dicts that drifted
|
||||
out of sync:
|
||||
|
||||
self._running_agents — AIAgent instance per session key
|
||||
self._running_agents_ts — start timestamp per session key
|
||||
self._busy_ack_ts — last busy-ack timestamp per session key
|
||||
|
||||
Six cleanup sites did ``del self._running_agents[key]`` without touching
|
||||
the other two; one site only popped ``_running_agents`` and
|
||||
``_running_agents_ts``; and only the stale-eviction site cleaned all
|
||||
three. Each missed entry was a small persistent leak.
|
||||
|
||||
Also: SessionDB connections were never closed on gateway shutdown,
|
||||
leaving WAL locks in place until Python actually exited.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_runner():
    """Return a GatewayRunner created without ``__init__``, carrying only three empty state dicts."""
    from gateway.run import GatewayRunner

    bare_runner = GatewayRunner.__new__(GatewayRunner)
    # Only the dicts touched by _release_running_agent_state are initialised.
    for state_attr in ("_running_agents", "_running_agents_ts", "_busy_ack_ts"):
        setattr(bare_runner, state_attr, {})
    return bare_runner
|
||||
|
||||
|
||||
class TestReleaseRunningAgentStateUnit:
    """Unit checks for ``GatewayRunner._release_running_agent_state``."""

    def test_pops_all_three_dicts(self):
        """Releasing a key removes it from the agent, timestamp and busy-ack dicts."""
        runner = _make_runner()
        runner._running_agents["k"] = MagicMock()
        runner._running_agents_ts["k"] = 123.0
        runner._busy_ack_ts["k"] = 456.0

        runner._release_running_agent_state("k")

        for state in (runner._running_agents, runner._running_agents_ts, runner._busy_ack_ts):
            assert "k" not in state

    def test_idempotent_on_missing_key(self):
        """Releasing an absent key, even repeatedly, must not raise."""
        runner = _make_runner()
        for _ in range(2):
            runner._release_running_agent_state("missing")

    def test_noop_on_empty_session_key(self):
        """An empty-string key is ignored: an entry stored under "" is left alone."""
        runner = _make_runner()
        runner._running_agents[""] = "guard"

        runner._release_running_agent_state("")

        # The helper did not process the empty key, so the guard value is still there.
        assert runner._running_agents[""] == "guard"

    def test_preserves_other_sessions(self):
        """Releasing one session leaves the entries of the other sessions untouched."""
        runner = _make_runner()
        for key in ("a", "b", "c"):
            runner._running_agents[key] = MagicMock()
            runner._running_agents_ts[key] = 1.0
            runner._busy_ack_ts[key] = 1.0

        runner._release_running_agent_state("b")

        survivors = {"a", "c"}
        assert set(runner._running_agents.keys()) == survivors
        assert set(runner._running_agents_ts.keys()) == survivors
        assert set(runner._busy_ack_ts.keys()) == survivors

    def test_handles_missing_busy_ack_attribute(self):
        """A runner object without a ``_busy_ack_ts`` attribute is still released cleanly."""
        runner = _make_runner()
        del runner._busy_ack_ts  # simulate older version
        runner._running_agents["k"] = MagicMock()
        runner._running_agents_ts["k"] = 1.0

        runner._release_running_agent_state("k")  # should not raise

        assert "k" not in runner._running_agents
        assert "k" not in runner._running_agents_ts

    def test_concurrent_release_is_safe(self):
        """Five threads each release a disjoint tenth-stride slice of 50 keys; all dicts end empty."""
        runner = _make_runner()
        session_keys = [f"s{i}" for i in range(50)]
        for index, key in enumerate(session_keys):
            runner._running_agents[key] = MagicMock()
            runner._running_agents_ts[key] = float(index)
            runner._busy_ack_ts[key] = float(index)

        def release_all(keys):
            for key in keys:
                runner._release_running_agent_state(key)

        # Thread n handles keys n, n+5, n+10, ... so no key is shared between threads.
        workers = [
            threading.Thread(target=release_all, args=(session_keys[offset::5],))
            for offset in range(5)
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(timeout=5)
            assert not worker.is_alive()

        assert runner._running_agents == {}
        assert runner._running_agents_ts == {}
        assert runner._busy_ack_ts == {}
|
||||
|
||||
|
||||
class TestNoMoreBareDeleteSites:
|
||||
"""Regression: all bare `del self._running_agents[key]` sites were
|
||||
converted to use the helper. If a future contributor reverts one,
|
||||
this test flags it. Docstrings / comments mentioning the old
|
||||
pattern are allowed.
|
||||
"""
|
||||
|
||||
def test_no_bare_del_of_running_agents_in_gateway_run(self):
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
gateway_run = (Path(__file__).parent.parent.parent / "gateway" / "run.py").read_text()
|
||||
# Match `del self._running_agents[...]` that is NOT inside a
|
||||
# triple-quoted docstring. We scan non-docstring lines only.
|
||||
lines = gateway_run.splitlines()
|
||||
|
||||
in_docstring = False
|
||||
docstring_delim = None
|
||||
offenders = []
|
||||
for idx, line in enumerate(lines, start=1):
|
||||
stripped = line.strip()
|
||||
if not in_docstring:
|
||||
if stripped.startswith('"""') or stripped.startswith("'''"):
|
||||
delim = stripped[:3]
|
||||
# single-line docstring?
|
||||
if stripped.count(delim) >= 2:
|
||||
continue
|
||||
in_docstring = True
|
||||
docstring_delim = delim
|
||||
continue
|
||||
if re.search(r"\bdel\s+self\._running_agents\[", line):
|
||||
offenders.append((idx, line.rstrip()))
|
||||
else:
|
||||
if docstring_delim and docstring_delim in stripped:
|
||||
in_docstring = False
|
||||
docstring_delim = None
|
||||
|
||||
assert offenders == [], (
|
||||
"Found bare `del self._running_agents[...]` sites in gateway/run.py. "
|
||||
"Use self._release_running_agent_state(session_key) instead so "
|
||||
"_running_agents_ts and _busy_ack_ts are popped in lockstep.\n"
|
||||
+ "\n".join(f" line {n}: {l}" for n, l in offenders)
|
||||
)
|
||||
|
||||
|
||||
class TestSessionDbCloseOnShutdown:
    """Checks the shutdown loop that calls ``.close()`` on the ``_db`` held by
    the runner and by ``runner.session_store``.

    NOTE(review): the loop is re-implemented here rather than invoked through
    ``_stop_impl`` — confirm it still mirrors the production code.
    """

    @staticmethod
    def _close_session_dbs(runner, *, swallow_errors=False):
        """Close ``_db`` on *runner* and on its ``session_store`` (when present).

        Holders that are missing/falsy, or whose ``_db`` is None or lacks a
        ``close`` attribute, are skipped.  With ``swallow_errors`` set, an
        exception from one ``close()`` is ignored so the next holder is still
        processed.
        """
        holders = (runner, getattr(runner, "session_store", None))
        for holder in holders:
            db = getattr(holder, "_db", None) if holder else None
            if db is None or not hasattr(db, "close"):
                continue
            if not swallow_errors:
                db.close()
                continue
            try:
                db.close()
            except Exception:
                pass

    def test_stop_impl_closes_both_session_dbs(self):
        """Both the runner's ``_db`` and the session store's ``_db`` get closed exactly once."""
        from gateway.run import GatewayRunner

        runner = GatewayRunner.__new__(GatewayRunner)
        runner_db, store_db = MagicMock(), MagicMock()
        runner._db = runner_db
        runner.session_store = MagicMock()
        runner.session_store._db = store_db

        self._close_session_dbs(runner)

        runner_db.close.assert_called_once()
        store_db.close.assert_called_once()

    def test_shutdown_tolerates_missing_session_store(self):
        """A runner with no ``session_store`` attribute still has its own ``_db`` closed."""
        from gateway.run import GatewayRunner

        runner = GatewayRunner.__new__(GatewayRunner)
        runner._db = MagicMock()
        # session_store is intentionally never assigned.

        self._close_session_dbs(runner)

        runner._db.close.assert_called_once()

    def test_shutdown_tolerates_close_raising(self):
        """When the first ``close()`` raises, the second database is closed regardless."""
        from gateway.run import GatewayRunner

        runner = GatewayRunner.__new__(GatewayRunner)
        flaky_db = MagicMock()
        flaky_db.close.side_effect = RuntimeError("simulated lock error")
        healthy_db = MagicMock()

        runner._db = flaky_db
        runner.session_store = MagicMock()
        runner.session_store._db = healthy_db

        # Each close() is individually guarded, as in the original test loop.
        self._close_session_dbs(runner, swallow_errors=True)

        flaky_db.close.assert_called_once()
        healthy_db.close.assert_called_once()
|
||||
270
tests/gateway/test_session_store_prune.py
Normal file
270
tests/gateway/test_session_store_prune.py
Normal file
|
|
@ -0,0 +1,270 @@
|
|||
"""Tests for SessionStore.prune_old_entries and the gateway watcher that calls it.
|
||||
|
||||
The SessionStore in-memory dict (and its backing sessions.json) grew
|
||||
unbounded — every unique (platform, chat_id, thread_id, user_id) tuple
|
||||
ever seen was kept forever, regardless of how stale it became. These
|
||||
tests pin the prune behaviour:
|
||||
|
||||
* Entries older than max_age_days (by updated_at) are removed
|
||||
* Entries marked ``suspended`` are preserved (user-paused)
|
||||
* Entries with an active process attached are preserved
|
||||
* max_age_days <= 0 disables pruning entirely
|
||||
* sessions.json is rewritten with the post-prune dict
|
||||
* The ``updated_at`` field — not ``created_at`` — drives the decision
|
||||
(so a long-running-but-still-active session isn't pruned)
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, SessionResetPolicy
|
||||
from gateway.session import SessionEntry, SessionStore
|
||||
|
||||
|
||||
def _make_store(tmp_path, max_age_days: int = 90, has_active_processes_fn=None):
    """Construct a SessionStore under *tmp_path* with ``_ensure_loaded`` patched out.

    The store gets a config with reset mode ``"none"`` and the given
    ``session_store_max_age_days``; afterwards its ``_db`` is set to None and
    it is flagged as already loaded.
    """
    no_reset = SessionResetPolicy(mode="none")
    store_config = GatewayConfig(
        default_reset_policy=no_reset,
        session_store_max_age_days=max_age_days,
    )
    store_kwargs = {
        "sessions_dir": tmp_path,
        "config": store_config,
        "has_active_processes_fn": has_active_processes_fn,
    }
    # _ensure_loaded is replaced by a mock for the duration of construction.
    with patch("gateway.session.SessionStore._ensure_loaded"):
        session_store = SessionStore(**store_kwargs)
    session_store._db = None
    session_store._loaded = True
    return session_store
|
||||
|
||||
|
||||
def _entry(key: str, age_days: float, *, suspended: bool = False,
           session_id: str | None = None) -> SessionEntry:
    """Build a Telegram DM SessionEntry last updated *age_days* days ago.

    ``created_at`` is placed a further 30 days before ``updated_at``.  When
    *session_id* is falsy it defaults to ``"sid_<key>"``.
    """
    reference = datetime.now()
    last_update = reference - timedelta(days=age_days)
    creation = reference - timedelta(days=age_days + 30)  # arbitrary older
    effective_session_id = session_id if session_id else f"sid_{key}"
    return SessionEntry(
        session_key=key,
        session_id=effective_session_id,
        created_at=creation,
        updated_at=last_update,
        platform=Platform.TELEGRAM,
        chat_type="dm",
        suspended=suspended,
    )
|
||||
|
||||
|
||||
class TestPruneBasics:
    """Eligibility rules exercised against ``SessionStore.prune_old_entries``."""

    def test_prune_removes_entries_past_max_age(self, tmp_path):
        """A 100-day-old entry is evicted at a 90-day cutoff; a 5-day-old one stays."""
        session_store = _make_store(tmp_path)
        for key, age in (("old", 100), ("fresh", 5)):
            session_store._entries[key] = _entry(key, age_days=age)

        evicted = session_store.prune_old_entries(max_age_days=90)

        assert evicted == 1
        assert "old" not in session_store._entries
        assert "fresh" in session_store._entries

    def test_prune_uses_updated_at_not_created_at(self, tmp_path):
        """A session created long ago but updated recently must be kept."""
        session_store = _make_store(tmp_path)
        reference = datetime.now()
        created = reference - timedelta(days=365)  # ancient
        touched = reference - timedelta(days=3)  # but just chatted
        session_store._entries["long-lived"] = SessionEntry(
            session_key="long-lived",
            session_id="sid",
            created_at=created,
            updated_at=touched,
            platform=Platform.TELEGRAM,
            chat_type="dm",
        )

        evicted = session_store.prune_old_entries(max_age_days=30)

        assert evicted == 0
        assert "long-lived" in session_store._entries

    def test_prune_disabled_when_max_age_is_zero(self, tmp_path):
        """A max age of 0 leaves every entry in place."""
        session_store = _make_store(tmp_path, max_age_days=0)
        for key in [f"s{i}" for i in range(5)]:
            session_store._entries[key] = _entry(key, age_days=365)

        evicted = session_store.prune_old_entries(0)

        assert evicted == 0
        assert len(session_store._entries) == 5

    def test_prune_disabled_when_max_age_is_negative(self, tmp_path):
        """A negative max age leaves the entry in place."""
        session_store = _make_store(tmp_path)
        session_store._entries["s"] = _entry("s", age_days=365)

        evicted = session_store.prune_old_entries(-1)

        assert evicted == 0
        assert "s" in session_store._entries

    def test_prune_skips_suspended_entries(self, tmp_path):
        """/stop-suspended sessions must be kept for later resume."""
        session_store = _make_store(tmp_path)
        session_store._entries["suspended"] = _entry(
            "suspended", age_days=1000, suspended=True
        )
        session_store._entries["idle"] = _entry("idle", age_days=1000)

        evicted = session_store.prune_old_entries(max_age_days=90)

        assert evicted == 1
        assert "suspended" in session_store._entries
        assert "idle" not in session_store._entries

    def test_prune_skips_entries_with_active_processes(self, tmp_path):
        """Sessions with active bg processes aren't pruned even if old."""
        busy_ids = {"sid_active"}

        def _is_busy(session_id: str) -> bool:
            return session_id in busy_ids

        session_store = _make_store(tmp_path, has_active_processes_fn=_is_busy)
        for key in ("active", "idle"):
            session_store._entries[key] = _entry(
                key, age_days=1000, session_id=f"sid_{key}"
            )

        evicted = session_store.prune_old_entries(max_age_days=90)

        assert evicted == 1
        assert "active" in session_store._entries
        assert "idle" not in session_store._entries

    def test_prune_does_not_write_disk_when_no_removals(self, tmp_path):
        """If nothing is evictable, _save() should NOT be called."""
        session_store = _make_store(tmp_path)
        for key, age in (("fresh1", 1), ("fresh2", 2)):
            session_store._entries[key] = _entry(key, age_days=age)

        # Replace _save with a recorder so any disk write becomes visible.
        recorded_saves = []
        session_store._save = lambda: recorded_saves.append(1)

        evicted = session_store.prune_old_entries(max_age_days=90)

        assert evicted == 0
        assert recorded_saves == []

    def test_prune_writes_disk_after_removal(self, tmp_path):
        """Evicting an entry triggers exactly one _save() call."""
        session_store = _make_store(tmp_path)
        for key, age in (("stale", 500), ("fresh", 1)):
            session_store._entries[key] = _entry(key, age_days=age)

        recorded_saves = []
        session_store._save = lambda: recorded_saves.append(1)

        session_store.prune_old_entries(max_age_days=90)

        assert recorded_saves == [1]

    def test_prune_is_thread_safe(self, tmp_path):
        """Prune acquires _lock internally; concurrent update_session is safe."""
        session_store = _make_store(tmp_path)
        # Even indices are stale (1000 days), odd indices are fresh (1 day).
        for index in range(20):
            age = 1 if index % 2 else 1000
            session_store._entries[f"s{index}"] = _entry(f"s{index}", age_days=age)

        prune_counts = []

        def _prune_once():
            prune_counts.append(session_store.prune_old_entries(max_age_days=90))

        def _snapshot_keys():
            # Mimic a concurrent update_session reader iterating under lock.
            with session_store._lock:
                list(session_store._entries.keys())

        workers = [threading.Thread(target=_prune_once)]
        workers.extend(threading.Thread(target=_snapshot_keys) for _ in range(4))
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(timeout=5)
            assert not worker.is_alive()

        # Exactly one pruner ran; removed exactly the 10 stale entries.
        assert prune_counts == [10]
        assert len(session_store._entries) == 10
        for index in range(1, 20, 2):  # fresh
            assert f"s{index}" in session_store._entries
|
||||
|
||||
|
||||
class TestPrunePersistsToDisk:
    """Pruning must be reflected in the sessions.json file on disk."""

    def test_prune_rewrites_sessions_json(self, tmp_path):
        """After prune, sessions.json on disk reflects the new dict."""
        sessions_file = tmp_path / "sessions.json"
        gateway_config = GatewayConfig(
            default_reset_policy=SessionResetPolicy(mode="none"),
            session_store_max_age_days=90,
        )
        session_store = SessionStore(sessions_dir=tmp_path, config=gateway_config)
        session_store._db = None
        # Force-populate without calling get_or_create to avoid DB side-effects
        for key, age in (("stale", 500), ("fresh", 1)):
            session_store._entries[key] = _entry(key, age_days=age)
        session_store._loaded = True
        session_store._save()

        # Verify pre-prune state on disk.
        before = json.loads(sessions_file.read_text())
        assert set(before.keys()) == {"stale", "fresh"}

        # Prune and check disk.
        session_store.prune_old_entries(max_age_days=90)
        after = json.loads(sessions_file.read_text())
        assert set(after.keys()) == {"fresh"}
|
||||
|
||||
|
||||
class TestGatewayConfigSerialization:
    """Defaults, round-tripping and coercion of ``session_store_max_age_days``."""

    def test_session_store_max_age_days_defaults_to_90(self):
        """A freshly constructed config uses 90 days."""
        default_config = GatewayConfig()
        assert default_config.session_store_max_age_days == 90

    def test_session_store_max_age_days_roundtrips(self):
        """to_dict followed by from_dict preserves an explicit value."""
        original = GatewayConfig(session_store_max_age_days=30)
        serialized = original.to_dict()
        rebuilt = GatewayConfig.from_dict(serialized)
        assert rebuilt.session_store_max_age_days == 30

    def test_session_store_max_age_days_missing_defaults_90(self):
        """Loading an old config (pre-this-field) falls back to default."""
        rebuilt = GatewayConfig.from_dict({})
        assert rebuilt.session_store_max_age_days == 90

    def test_session_store_max_age_days_negative_coerced_to_zero(self):
        """A negative value (accidental or hostile) becomes 0 (disabled)."""
        raw = {"session_store_max_age_days": -5}
        rebuilt = GatewayConfig.from_dict(raw)
        assert rebuilt.session_store_max_age_days == 0

    def test_session_store_max_age_days_bad_type_falls_back(self):
        """Non-int values fall back to the default, not a crash."""
        raw = {"session_store_max_age_days": "nope"}
        rebuilt = GatewayConfig.from_dict(raw)
        assert rebuilt.session_store_max_age_days == 90
|
||||
|
||||
|
||||
class TestGatewayWatcherCallsPrune:
    """The session_expiry_watcher should call prune_old_entries once per hour."""

    def test_prune_gate_fires_on_first_tick(self):
        """First watcher tick has _last_prune_ts=0, so the gate opens."""
        import time as _t

        interval_seconds = 3600.0
        previous_prune = 0.0
        current = _t.time()

        # Mirror the production gate check in _session_expiry_watcher.
        elapsed = current - previous_prune
        gate_open = elapsed > interval_seconds
        assert gate_open is True

    def test_prune_gate_suppresses_within_interval(self):
        """A prune that happened ten minutes ago keeps the gate closed."""
        import time as _t

        previous_prune = _t.time() - 600  # 10 minutes ago
        interval_seconds = 3600.0
        current = _t.time()

        elapsed = current - previous_prune
        gate_open = elapsed > interval_seconds
        assert gate_open is False
|
||||
|
|
@ -63,6 +63,24 @@ class TestGatewayPidState:
|
|||
|
||||
assert status.get_running_pid() == os.getpid()
|
||||
|
||||
def test_get_running_pid_accepts_explicit_pid_path_without_cleanup(self, tmp_path, monkeypatch):
    """An explicit pid file is honoured and stays on disk with cleanup_stale=False."""
    profile_home = tmp_path / "profile-home"
    profile_home.mkdir()
    pid_file = profile_home / "gateway.pid"
    record = {
        "pid": os.getpid(),
        "kind": "hermes-gateway",
        "argv": ["python", "-m", "hermes_cli.main", "gateway"],
        "start_time": 123,
    }
    pid_file.write_text(json.dumps(record))

    # Stub the process probes: kill is a no-op, the start time matches the
    # record above, and no command line is reported.
    monkeypatch.setattr(status.os, "kill", lambda pid, sig: None)
    monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 123)
    monkeypatch.setattr(status, "_read_process_cmdline", lambda pid: None)

    reported_pid = status.get_running_pid(pid_file, cleanup_stale=False)

    assert reported_pid == os.getpid()
    assert pid_file.exists()
|
||||
|
||||
|
||||
class TestGatewayRuntimeStatus:
|
||||
def test_write_runtime_status_overwrites_stale_pid_on_restart(self, tmp_path, monkeypatch):
|
||||
|
|
@ -246,3 +264,181 @@ class TestScopedLocks:
|
|||
|
||||
status.release_scoped_lock("telegram-bot-token", "secret")
|
||||
assert not lock_path.exists()
|
||||
|
||||
|
||||
class TestTakeoverMarker:
    """Tests for the --replace takeover marker.

    Two gateway services fighting over one bot token used to flap after
    #5646. The replacer now writes a marker file naming the target PID and
    start_time; the target's shutdown handler reads it and exits 0 rather
    than 1, so systemd's Restart=on-failure doesn't bring it back.
    """

    def test_write_marker_records_target_identity(self, tmp_path, monkeypatch):
        """The written marker carries target pid/start time, replacer pid and a timestamp."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 42)

        written = status.write_takeover_marker(target_pid=12345)

        assert written is True
        marker_file = tmp_path / ".gateway-takeover.json"
        assert marker_file.exists()
        contents = json.loads(marker_file.read_text())
        assert contents["target_pid"] == 12345
        assert contents["target_start_time"] == 42
        assert contents["replacer_pid"] == os.getpid()
        assert "written_at" in contents

    def test_consume_returns_true_when_marker_names_self(self, tmp_path, monkeypatch):
        """Primary happy path: planned takeover is recognised."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        # Mark THIS process as the target
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 100)
        written = status.write_takeover_marker(target_pid=os.getpid())
        assert written is True

        # Call consume as if this process just got SIGTERMed
        consumed = status.consume_takeover_marker_for_self()

        assert consumed is True
        # Marker must be unlinked after consumption
        marker_file = tmp_path / ".gateway-takeover.json"
        assert not marker_file.exists()

    def test_consume_returns_false_for_different_pid(self, tmp_path, monkeypatch):
        """A marker naming a DIFFERENT process must not be consumed as ours."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 100)
        # Marker names a different PID
        foreign_pid = os.getpid() + 9999
        written = status.write_takeover_marker(target_pid=foreign_pid)
        assert written is True

        consumed = status.consume_takeover_marker_for_self()

        assert consumed is False
        # Marker IS unlinked even on non-match (the record has been consumed
        # and isn't relevant to us — leaving it around would grief a later
        # legitimate check).
        marker_file = tmp_path / ".gateway-takeover.json"
        assert not marker_file.exists()

    def test_consume_returns_false_on_start_time_mismatch(self, tmp_path, monkeypatch):
        """PID reuse defence: old marker's start_time mismatches current process."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        # Marker says target started at time 100 with our PID
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 100)
        status.write_takeover_marker(target_pid=os.getpid())

        # Now change the reported start_time to simulate PID reuse
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 9999)

        consumed = status.consume_takeover_marker_for_self()

        assert consumed is False

    def test_consume_returns_false_when_marker_missing(self, tmp_path, monkeypatch):
        """With no marker file present, consume reports False."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))

        consumed = status.consume_takeover_marker_for_self()

        assert consumed is False

    def test_consume_returns_false_for_stale_marker(self, tmp_path, monkeypatch):
        """A marker older than 60s must be ignored."""
        from datetime import datetime, timedelta, timezone

        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        marker_file = tmp_path / ".gateway-takeover.json"
        # Hand-craft a marker written 2 minutes ago
        two_minutes_ago = datetime.now(timezone.utc) - timedelta(minutes=2)
        stale_record = {
            "target_pid": os.getpid(),
            "target_start_time": 123,
            "replacer_pid": 99999,
            "written_at": two_minutes_ago.isoformat(),
        }
        marker_file.write_text(json.dumps(stale_record))
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 123)

        consumed = status.consume_takeover_marker_for_self()

        assert consumed is False
        # Stale markers are unlinked so a later legit shutdown isn't griefed
        assert not marker_file.exists()

    def test_consume_handles_malformed_marker_gracefully(self, tmp_path, monkeypatch):
        """Unparseable marker content yields False instead of raising."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        marker_file = tmp_path / ".gateway-takeover.json"
        marker_file.write_text("not valid json{")

        # Must not raise
        consumed = status.consume_takeover_marker_for_self()

        assert consumed is False

    def test_consume_handles_marker_with_missing_fields(self, tmp_path, monkeypatch):
        """A JSON marker lacking the expected keys yields False and is removed."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        marker_file = tmp_path / ".gateway-takeover.json"
        incomplete_record = {"only_replacer_pid": 99999}
        marker_file.write_text(json.dumps(incomplete_record))

        consumed = status.consume_takeover_marker_for_self()

        assert consumed is False
        # Malformed marker should be cleaned up
        assert not marker_file.exists()

    def test_clear_takeover_marker_is_idempotent(self, tmp_path, monkeypatch):
        """Clearing works with no marker, with a marker, and again afterwards."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        marker_file = tmp_path / ".gateway-takeover.json"

        # Nothing to clear — must not raise
        status.clear_takeover_marker()

        # Write then clear
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 100)
        status.write_takeover_marker(target_pid=12345)
        assert marker_file.exists()

        status.clear_takeover_marker()
        assert not marker_file.exists()

        # Clear again — still no error
        status.clear_takeover_marker()

    def test_write_marker_returns_false_on_write_failure(self, tmp_path, monkeypatch):
        """write_takeover_marker is best-effort; returns False but doesn't raise."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))

        def _failing_writer(*args, **kwargs):
            raise OSError("simulated write failure")

        monkeypatch.setattr(status, "_write_json_file", _failing_writer)

        written = status.write_takeover_marker(target_pid=12345)

        assert written is False

    def test_consume_ignores_marker_for_different_process_and_prevents_stale_grief(
        self, tmp_path, monkeypatch
    ):
        """Regression: a stale marker from a dead replacer naming a dead
        target must not accidentally cause an unrelated future gateway to
        exit 0 on legitimate SIGTERM.

        The distinguishing check is ``target_pid == our_pid AND
        target_start_time == our_start_time``. Different PID always wins.
        """
        from datetime import datetime, timezone

        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        marker_file = tmp_path / ".gateway-takeover.json"
        # Fresh marker (timestamp is recent) but names a totally different PID
        foreign_record = {
            "target_pid": os.getpid() + 10000,
            "target_start_time": 42,
            "replacer_pid": 99999,
            "written_at": datetime.now(timezone.utc).isoformat(),
        }
        marker_file.write_text(json.dumps(foreign_record))
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 42)

        consumed = status.consume_takeover_marker_for_self()

        # We are not the target — must NOT consume as planned
        assert consumed is False
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Tests for gateway /status behavior and token persistence."""
|
||||
|
||||
from datetime import datetime
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
|
@ -111,6 +112,75 @@ async def test_status_command_includes_session_title_when_present():
|
|||
assert "**Title:** My titled session" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_agents_command_reports_active_agents_and_processes(monkeypatch):
    """/agents lists one running agent and one background process without interrupting."""
    session_key = build_session_key(_make_source())
    runner = _make_runner(
        SessionEntry(
            session_key=session_key,
            session_id="sess-1",
            created_at=datetime.now(),
            updated_at=datetime.now(),
            platform=Platform.TELEGRAM,
            chat_type="dm",
            total_tokens=0,
        )
    )
    active_agent = SimpleNamespace(
        session_id="sess-running",
        model="openrouter/test-model",
        interrupt=MagicMock(),
        get_activity_summary=lambda: {"seconds_since_activity": 0},
    )
    runner._running_agents[session_key] = active_agent
    runner._running_agents_ts = {session_key: time.time() - 8}
    runner._background_tasks = set()

    class _StubRegistry:
        """Process registry stand-in reporting a single running process."""

        def list_sessions(self):
            process = {
                "session_id": "proc-1",
                "status": "running",
                "uptime_seconds": 17,
                "command": "sleep 30",
            }
            return [process]

    monkeypatch.setattr("tools.process_registry.process_registry", _StubRegistry())

    reply = await runner._handle_message(_make_event("/agents"))

    assert "**Active agents:** 1" in reply
    assert "**Running background processes:** 1" in reply
    assert "proc-1" in reply
    active_agent.interrupt.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tasks_alias_routes_to_agents_command(monkeypatch):
    """/tasks produces the "Active Agents & Tasks" report."""
    runner = _make_runner(
        SessionEntry(
            session_key=build_session_key(_make_source()),
            session_id="sess-1",
            created_at=datetime.now(),
            updated_at=datetime.now(),
            platform=Platform.TELEGRAM,
            chat_type="dm",
            total_tokens=0,
        )
    )
    runner._background_tasks = set()

    class _StubRegistry:
        """Process registry stand-in with no sessions."""

        def list_sessions(self):
            return []

    monkeypatch.setattr("tools.process_registry.process_registry", _StubRegistry())

    reply = await runner._handle_message(_make_event("/tasks"))

    assert "Active Agents & Tasks" in reply
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_persists_agent_token_counts(monkeypatch):
|
||||
import gateway.run as gateway_run
|
||||
|
|
|
|||
|
|
@ -88,6 +88,51 @@ class TestCleanForDisplay:
|
|||
# ── Integration: _send_or_edit strips MEDIA: ─────────────────────────────
|
||||
|
||||
|
||||
class TestFinalizeCapabilityGate:
    """Verify REQUIRES_EDIT_FINALIZE gates the redundant final edit.

    Adapters that need no explicit finalize signal (Telegram, Slack,
    Matrix, …) should skip the final edit when a mid-stream edit already
    delivered the final content. Adapters that *do* need it (DingTalk AI
    Cards) must always receive a finalize=True edit when the stream ends.
    """

    @pytest.mark.asyncio
    async def test_identical_text_skip_respects_adapter_flag(self):
        """_send_or_edit short-circuits identical-text only when the
        adapter doesn't require an explicit finalize signal."""

        def _build_adapter(requires_finalize, edit_message):
            # Mock adapter whose send() reports success with message id "m1".
            adapter = MagicMock()
            adapter.REQUIRES_EDIT_FINALIZE = requires_finalize
            adapter.send = AsyncMock(return_value=SimpleNamespace(
                success=True, message_id="m1",
            ))
            adapter.edit_message = edit_message
            adapter.MAX_MESSAGE_LENGTH = 4096
            return adapter

        # Adapter without finalize requirement — should skip identical edit.
        lenient = _build_adapter(False, AsyncMock())
        lenient_consumer = GatewayStreamConsumer(lenient, "chat_1")
        await lenient_consumer._send_or_edit("hello")  # first send
        await lenient_consumer._send_or_edit("hello", finalize=True)  # identical → skip
        lenient.edit_message.assert_not_called()

        # Adapter that requires finalize — must still fire the edit.
        strict = _build_adapter(
            True,
            AsyncMock(return_value=SimpleNamespace(
                success=True, message_id="m1",
            )),
        )
        strict_consumer = GatewayStreamConsumer(strict, "chat_1")
        await strict_consumer._send_or_edit("hello")
        await strict_consumer._send_or_edit("hello", finalize=True)
        # Finalize edit must go through even on identical content.
        strict.edit_message.assert_called_once()
        edit_kwargs = strict.edit_message.call_args.kwargs
        assert edit_kwargs["finalize"] is True
|
||||
|
||||
|
||||
class TestSendOrEditMediaStripping:
|
||||
"""Verify _send_or_edit strips MEDIA: before sending to the platform."""
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,12 @@ def _ensure_telegram_mock():
|
|||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2 # noqa: E402
|
||||
from gateway.platforms.telegram import ( # noqa: E402
|
||||
TelegramAdapter,
|
||||
_escape_mdv2,
|
||||
_strip_mdv2,
|
||||
_wrap_markdown_tables,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -535,6 +540,152 @@ class TestStripMdv2:
|
|||
assert _strip_mdv2("||hidden text||") == "hidden text"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Markdown table auto-wrap
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestWrapMarkdownTables:
    """_wrap_markdown_tables puts GFM pipe tables inside ``` fences so
    Telegram shows them as monospace preformatted text rather than the
    backslash-escaped pipes MarkdownV2 would otherwise produce."""

    def test_basic_table_wrapped(self):
        """A table between two prose paragraphs is fenced; the prose is untouched."""
        source = (
            "Scores:\n\n"
            "| Player | Score |\n"
            "|--------|-------|\n"
            "| Alice | 150 |\n"
            "| Bob | 120 |\n"
            "\nEnd."
        )
        wrapped = _wrap_markdown_tables(source)
        # Table is now wrapped in a fence
        assert "```\n| Player | Score |" in wrapped
        assert "| Bob | 120 |\n```" in wrapped
        # Surrounding prose is preserved
        assert wrapped.startswith("Scores:")
        assert wrapped.endswith("End.")

    def test_bare_pipe_table_wrapped(self):
        """Tables without outer pipes (GFM allows this) are still detected."""
        source = "head1 | head2\n--- | ---\na | b\nc | d"
        wrapped = _wrap_markdown_tables(source)
        assert wrapped.startswith("```\n")
        assert wrapped.rstrip().endswith("```")
        assert "head1 | head2" in wrapped

    def test_alignment_separators(self):
        """Separator rows with :--- / ---: / :---: alignment markers match."""
        source = (
            "| Name | Age | City |\n"
            "|:-----|----:|:----:|\n"
            "| Ada | 30 | NYC |"
        )
        wrapped = _wrap_markdown_tables(source)
        fence_count = wrapped.count("```")
        assert fence_count == 2

    def test_two_consecutive_tables_wrapped_separately(self):
        """Two tables separated by a blank line each get their own fence pair."""
        source = (
            "| A | B |\n"
            "|---|---|\n"
            "| 1 | 2 |\n"
            "\n"
            "| X | Y |\n"
            "|---|---|\n"
            "| 9 | 8 |"
        )
        wrapped = _wrap_markdown_tables(source)
        # Four fences total — one opening + closing per table
        fence_count = wrapped.count("```")
        assert fence_count == 4

    def test_plain_text_with_pipes_not_wrapped(self):
        """A bare pipe in prose must NOT trigger wrapping."""
        source = "Use the | pipe operator to chain commands."
        wrapped = _wrap_markdown_tables(source)
        assert wrapped == source

    def test_horizontal_rule_not_wrapped(self):
        """A lone '---' horizontal rule must not be mistaken for a separator."""
        source = "Section A\n\n---\n\nSection B"
        wrapped = _wrap_markdown_tables(source)
        assert wrapped == source

    def test_existing_code_block_with_pipes_left_alone(self):
        """A table already inside a fenced code block must not be re-wrapped."""
        source = (
            "```\n"
            "| a | b |\n"
            "|---|---|\n"
            "| 1 | 2 |\n"
            "```"
        )
        wrapped = _wrap_markdown_tables(source)
        assert wrapped == source

    def test_no_pipe_character_short_circuits(self):
        """Text without any pipe character comes back unchanged."""
        source = "Plain **bold** text with no table."
        wrapped = _wrap_markdown_tables(source)
        assert wrapped == source

    def test_no_dash_short_circuits(self):
        """Pipes without a dash separator row come back unchanged."""
        source = "a | b\nc | d"  # has pipes but no '-' separator row
        wrapped = _wrap_markdown_tables(source)
        assert wrapped == source

    def test_single_column_separator_not_matched(self):
        """Single-column tables (rare) are not detected — we require at
        least one internal pipe in the separator row to avoid false
        positives on formatting rules."""
        source = "| a |\n| - |\n| b |"
        wrapped = _wrap_markdown_tables(source)
        assert wrapped == source
|
||||
|
||||
|
||||
class TestFormatMessageTables:
    """End-to-end: format_message leaves a pipe table's pipes and dashes
    alone inside the fence rather than applying MarkdownV2 escaping."""

    def test_table_rendered_as_code_block(self, adapter):
        """The fenced table segment contains no escaped pipes or dashes."""
        source = (
            "Data:\n\n"
            "| Col1 | Col2 |\n"
            "|------|------|\n"
            "| A | B |\n"
        )
        formatted = adapter.format_message(source)
        # Pipes inside the fenced block are NOT escaped
        assert "```\n| Col1 | Col2 |" in formatted
        fenced_segment = formatted.split("```")[1]
        assert "\\|" not in fenced_segment
        # Dashes in separator not escaped inside fence
        assert "\\-" not in fenced_segment

    def test_text_after_table_still_formatted(self, adapter):
        """Prose following a table still gets bold conversion and escaping."""
        source = (
            "| A | B |\n"
            "|---|---|\n"
            "| 1 | 2 |\n"
            "\n"
            "Nice **work** team!"
        )
        formatted = adapter.format_message(source)
        # MarkdownV2 bold conversion still happens outside the table
        assert "*work*" in formatted
        # Exclamation outside fence is escaped
        assert "\\!" in formatted

    def test_multiple_tables_in_single_message(self, adapter):
        """Two labelled tables in one message produce two fenced blocks."""
        source = (
            "First:\n"
            "| A | B |\n"
            "|---|---|\n"
            "| 1 | 2 |\n"
            "\n"
            "Second:\n"
            "| X | Y |\n"
            "|---|---|\n"
            "| 9 | 8 |\n"
        )
        formatted = adapter.format_message(source)
        # Two separate fenced blocks in the output
        fence_count = formatted.count("```")
        assert fence_count == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_escapes_chunk_indicator_for_markdownv2(adapter):
|
||||
adapter.MAX_MESSAGE_LENGTH = 80
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ class TestWeComConnect:
|
|||
|
||||
class TestWeComReplyMode:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_passive_reply_stream_when_reply_context_exists(self):
|
||||
async def test_send_uses_passive_reply_markdown_when_reply_context_exists(self):
|
||||
from gateway.platforms.wecom import WeComAdapter
|
||||
|
||||
adapter = WeComAdapter(PlatformConfig(enabled=True))
|
||||
|
|
@ -134,9 +134,10 @@ class TestWeComReplyMode:
|
|||
adapter._send_reply_request.assert_awaited_once()
|
||||
args = adapter._send_reply_request.await_args.args
|
||||
assert args[0] == "req-1"
|
||||
assert args[1]["msgtype"] == "stream"
|
||||
assert args[1]["stream"]["finish"] is True
|
||||
assert args[1]["stream"]["content"] == "hello from reply"
|
||||
# msgtype: stream triggers WeCom errcode 600039 on many mobile clients
|
||||
# (unsupported type). Markdown renders everywhere.
|
||||
assert args[1]["msgtype"] == "markdown"
|
||||
assert args[1]["markdown"]["content"] == "hello from reply"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_image_file_uses_passive_reply_media_when_reply_context_exists(self):
|
||||
|
|
@ -593,3 +594,193 @@ class TestInboundMessages:
|
|||
await adapter._on_message(payload)
|
||||
adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
class TestWeComZombieSessionFix:
|
||||
"""Tests for PR #11572 — device_id, markdown reply, group req_id fallback."""
|
||||
|
||||
def test_adapter_generates_stable_device_id_per_instance(self):
    """One adapter exposes a non-empty string device id that is the same on every read."""
    from gateway.platforms.wecom import WeComAdapter

    adapter = WeComAdapter(PlatformConfig(enabled=True))
    first_read = adapter._device_id
    assert isinstance(first_read, str)
    assert len(first_read) > 0
    # Second snapshot on the same adapter must be identical — only a fresh
    # adapter instance should get a new device_id (one-per-reconnect is the
    # zombie-session footgun we're fixing).
    second_read = adapter._device_id
    assert first_read == second_read
|
||||
|
||||
def test_different_adapter_instances_get_distinct_device_ids(self):
    """Two separately constructed adapters do not share a device id."""
    from gateway.platforms.wecom import WeComAdapter

    first_adapter, second_adapter = (
        WeComAdapter(PlatformConfig(enabled=True)) for _ in range(2)
    )
    assert first_adapter._device_id != second_adapter._device_id
|
||||
|
||||
@pytest.mark.asyncio
async def test_open_connection_includes_device_id_in_subscribe(self):
    """The subscribe frame sent by _open_connection carries bot id, secret and device id."""
    from gateway.platforms.wecom import APP_CMD_SUBSCRIBE, WeComAdapter

    adapter = WeComAdapter(PlatformConfig(enabled=True))
    adapter._bot_id = "test-bot"
    adapter._secret = "test-secret"

    captured_frames = []

    class _StubWebSocket:
        """Websocket stand-in that records every JSON frame sent."""

        closed = False

        async def send_json(self, payload):
            captured_frames.append(payload)

        async def close(self):
            return None

    class _StubSession:
        """Client session stand-in that always hands out a stub websocket."""

        def __init__(self, *args, **kwargs):
            pass

        async def ws_connect(self, *args, **kwargs):
            return _StubWebSocket()

        async def close(self):
            return None

    async def _noop_cleanup():
        return None

    async def _accept_handshake(req_id):
        # Echo the request id back with a success errcode.
        return {"errcode": 0, "headers": {"req_id": req_id}}

    adapter._cleanup_ws = _noop_cleanup
    adapter._wait_for_handshake = _accept_handshake

    with patch("gateway.platforms.wecom.aiohttp.ClientSession", _StubSession):
        await adapter._open_connection()

    assert len(captured_frames) == 1
    subscribe_frame = captured_frames[0]
    assert subscribe_frame["cmd"] == APP_CMD_SUBSCRIBE
    subscribe_body = subscribe_frame["body"]
    assert subscribe_body["bot_id"] == "test-bot"
    assert subscribe_body["secret"] == "test-secret"
    assert subscribe_body["device_id"] == adapter._device_id
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_message_caches_last_req_id_per_chat(self):
    """An inbound group message records its req_id under the chat id."""
    from gateway.platforms.wecom import WeComAdapter

    adapter = WeComAdapter(PlatformConfig(enabled=True))
    adapter._text_batch_delay_seconds = 0
    adapter.handle_message = AsyncMock()
    adapter._extract_media = AsyncMock(return_value=([], []))

    message_body = {
        "msgid": "msg-1",
        "chatid": "group-1",
        "chattype": "group",
        "from": {"userid": "user-1"},
        "msgtype": "text",
        "text": {"content": "hi"},
    }
    callback = {
        "cmd": "aibot_msg_callback",
        "headers": {"req_id": "req-abc"},
        "body": message_body,
    }

    await adapter._on_message(callback)

    cached_req_id = adapter._last_chat_req_ids["group-1"]
    assert cached_req_id == "req-abc"
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_message_does_not_cache_blocked_sender_req_id(self):
    """A message from a group outside the allowlist leaves the req_id cache untouched."""
    from gateway.platforms.wecom import WeComAdapter

    # Only "group-ok" is allowed; the message below comes from another group.
    allowlist_config = PlatformConfig(
        enabled=True,
        extra={"group_policy": "allowlist", "group_allow_from": ["group-ok"]},
    )
    adapter = WeComAdapter(allowlist_config)
    adapter.handle_message = AsyncMock()
    adapter._extract_media = AsyncMock(return_value=([], []))

    message_body = {
        "msgid": "msg-1",
        "chatid": "group-blocked",
        "chattype": "group",
        "from": {"userid": "user-1"},
        "msgtype": "text",
        "text": {"content": "hi"},
    }
    callback = {
        "cmd": "aibot_msg_callback",
        "headers": {"req_id": "req-abc"},
        "body": message_body,
    }

    await adapter._on_message(callback)

    adapter.handle_message.assert_not_awaited()
    assert "group-blocked" not in adapter._last_chat_req_ids
|
||||
|
||||
def test_remember_chat_req_id_is_bounded(self):
    """The req_id cache never grows past ``DEDUP_MAX_SIZE`` and keeps the newest entry."""
    from gateway.platforms.wecom import DEDUP_MAX_SIZE, WeComAdapter

    adapter = WeComAdapter(PlatformConfig(enabled=True))

    # Insert 50 more distinct chats than the cache is allowed to hold.
    total = DEDUP_MAX_SIZE + 50
    for idx in range(total):
        adapter._remember_chat_req_id(f"chat-{idx}", f"req-{idx}")

    assert len(adapter._last_chat_req_ids) <= DEDUP_MAX_SIZE

    # The most recently remembered chat must still be present.
    newest = total - 1
    assert adapter._last_chat_req_ids[f"chat-{newest}"] == f"req-{newest}"
|
||||
|
||||
def test_remember_chat_req_id_ignores_empty_values(self):
    """Empty or whitespace-only chat ids / req ids are never cached."""
    from gateway.platforms.wecom import WeComAdapter

    adapter = WeComAdapter(PlatformConfig(enabled=True))

    rejected_pairs = (
        ("", "req-1"),    # missing chat id
        ("chat-1", ""),   # missing req id
        (" ", " "),       # whitespace only
    )
    for chat_id, req_id in rejected_pairs:
        adapter._remember_chat_req_id(chat_id, req_id)

    assert adapter._last_chat_req_ids == {}
|
||||
|
||||
@pytest.mark.asyncio
async def test_proactive_group_send_falls_back_to_cached_req_id(self):
    """A group send without ``reply_to`` reuses the last cached req_id.

    The message must go out through the reply path (APP_CMD_RESPONSE) because
    WeCom AI Bots cannot initiate APP_CMD_SEND in group chats (errcode 600039).
    """
    from gateway.platforms.wecom import WeComAdapter

    adapter = WeComAdapter(PlatformConfig(enabled=True))
    adapter._last_chat_req_ids["group-1"] = "inbound-req-42"

    reply_mock = AsyncMock(
        return_value={"headers": {"req_id": "inbound-req-42"}, "errcode": 0}
    )
    send_mock = AsyncMock(
        return_value={"headers": {"req_id": "new"}, "errcode": 0}
    )
    adapter._send_reply_request = reply_mock
    adapter._send_request = send_mock

    outcome = await adapter.send("group-1", "ping", reply_to=None)

    assert outcome.success is True
    # Must route through reply (APP_CMD_RESPONSE), not proactive send.
    reply_mock.assert_awaited_once()
    send_mock.assert_not_awaited()

    positional = reply_mock.await_args.args
    assert positional[0] == "inbound-req-42"
    sent_body = positional[1]
    assert sent_body["msgtype"] == "markdown"
    assert sent_body["markdown"]["content"] == "ping"
|
||||
|
||||
@pytest.mark.asyncio
async def test_proactive_send_without_cached_req_id_uses_app_cmd_send(self):
    """With no cached req_id (fresh DM target) the send goes out as APP_CMD_SEND."""
    from gateway.platforms.wecom import APP_CMD_SEND, WeComAdapter

    adapter = WeComAdapter(PlatformConfig(enabled=True))
    send_mock = AsyncMock(
        return_value={"headers": {"req_id": "new"}, "errcode": 0}
    )
    adapter._send_request = send_mock

    outcome = await adapter.send("fresh-dm-chat", "ping", reply_to=None)

    assert outcome.success is True
    send_mock.assert_awaited_once()

    # The command is the first positional argument of the awaited call.
    issued_cmd = send_mock.await_args.args[0]
    assert issued_cmd == APP_CMD_SEND
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ class TestProviderRegistry:
|
|||
("huggingface", "Hugging Face", "api_key"),
|
||||
("zai", "Z.AI / GLM", "api_key"),
|
||||
("xai", "xAI", "api_key"),
|
||||
("nvidia", "NVIDIA NIM", "api_key"),
|
||||
("kimi-coding", "Kimi / Moonshot", "api_key"),
|
||||
("minimax", "MiniMax", "api_key"),
|
||||
("minimax-cn", "MiniMax (China)", "api_key"),
|
||||
|
|
@ -57,6 +58,12 @@ class TestProviderRegistry:
|
|||
assert pconfig.base_url_env_var == "XAI_BASE_URL"
|
||||
assert pconfig.inference_base_url == "https://api.x.ai/v1"
|
||||
|
||||
def test_nvidia_env_vars(self):
    """The ``nvidia`` registry entry exposes the expected env vars and base URL."""
    nvidia = PROVIDER_REGISTRY["nvidia"]

    assert nvidia.api_key_env_vars == ("NVIDIA_API_KEY",)
    assert nvidia.base_url_env_var == "NVIDIA_BASE_URL"
    assert nvidia.inference_base_url == "https://integrate.api.nvidia.com/v1"
|
||||
|
||||
def test_copilot_env_vars(self):
|
||||
pconfig = PROVIDER_REGISTRY["copilot"]
|
||||
assert pconfig.api_key_env_vars == ("COPILOT_GITHUB_TOKEN", "GH_TOKEN", "GITHUB_TOKEN")
|
||||
|
|
|
|||
|
|
@ -141,13 +141,93 @@ def test_auth_add_nous_oauth_persists_pool_entry(tmp_path, monkeypatch):
|
|||
auth_add_command(_Args())
|
||||
|
||||
payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
|
||||
# Pool has exactly one canonical `device_code` entry — not a duplicate
|
||||
# pair of `manual:device_code` + `device_code` (the latter would be
|
||||
# materialised by _seed_from_singletons on every load_pool).
|
||||
entries = payload["credential_pool"]["nous"]
|
||||
entry = next(item for item in entries if item["source"] == "manual:device_code")
|
||||
assert entry["label"] == "nous@example.com"
|
||||
assert entry["source"] == "manual:device_code"
|
||||
device_code_entries = [
|
||||
item for item in entries if item["source"] == "device_code"
|
||||
]
|
||||
assert len(device_code_entries) == 1, entries
|
||||
assert not any(item["source"] == "manual:device_code" for item in entries)
|
||||
entry = device_code_entries[0]
|
||||
assert entry["source"] == "device_code"
|
||||
assert entry["agent_key"] == "ak-test"
|
||||
assert entry["portal_base_url"] == "https://portal.example.com"
|
||||
|
||||
# `hermes auth add nous` must also populate providers.nous so the
|
||||
# 401-recovery path (resolve_nous_runtime_credentials) can mint a fresh
|
||||
# agent_key when the 24h TTL expires. If this mirror is missing, recovery
|
||||
# raises "Hermes is not logged into Nous Portal" and the agent dies.
|
||||
singleton = payload["providers"]["nous"]
|
||||
assert singleton["access_token"] == token
|
||||
assert singleton["refresh_token"] == "refresh-token"
|
||||
assert singleton["agent_key"] == "ak-test"
|
||||
assert singleton["portal_base_url"] == "https://portal.example.com"
|
||||
assert singleton["inference_base_url"] == "https://inference.example.com/v1"
|
||||
|
||||
|
||||
def test_auth_add_nous_oauth_honors_custom_label(tmp_path, monkeypatch):
    """``hermes auth add nous --type oauth --label <name>`` must keep the label.

    The custom label has to reach both the pool entry and ``providers.nous``;
    the first cut of the persist_nous_credentials helper silently dropped it
    because ``--label`` wasn't threaded through.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
    _write_auth_store(tmp_path, {"version": 1, "providers": {}})
    token = _jwt_with_email("nous@example.com")

    def _fake_device_code_login(**kwargs):
        # Canned login result standing in for the real device-code flow.
        return {
            "portal_base_url": "https://portal.example.com",
            "inference_base_url": "https://inference.example.com/v1",
            "client_id": "hermes-cli",
            "scope": "inference:mint_agent_key",
            "token_type": "Bearer",
            "access_token": token,
            "refresh_token": "refresh-token",
            "obtained_at": "2026-03-23T10:00:00+00:00",
            "expires_at": "2026-03-23T11:00:00+00:00",
            "expires_in": 3600,
            "agent_key": "ak-test",
            "agent_key_id": "ak-id",
            "agent_key_expires_at": "2026-03-23T10:30:00+00:00",
            "agent_key_expires_in": 1800,
            "agent_key_reused": False,
            "agent_key_obtained_at": "2026-03-23T10:00:10+00:00",
            "tls": {"insecure": False, "ca_bundle": None},
        }

    monkeypatch.setattr(
        "hermes_cli.auth._nous_device_code_login",
        _fake_device_code_login,
    )

    from hermes_cli.auth_commands import auth_add_command

    class _Args:
        """Minimal argparse-namespace stand-in; only ``label`` is non-default."""

        provider = "nous"
        auth_type = "oauth"
        api_key = None
        label = "my-nous"
        portal_url = None
        inference_url = None
        client_id = None
        scope = None
        no_browser = False
        timeout = None
        insecure = False
        ca_bundle = None

    auth_add_command(_Args())

    auth_file = tmp_path / "hermes" / "auth.json"
    stored = json.loads(auth_file.read_text())

    # Custom label reaches the pool entry …
    first_pool_entry = stored["credential_pool"]["nous"][0]
    assert first_pool_entry["source"] == "device_code"
    assert first_pool_entry["label"] == "my-nous"

    # … and survives in providers.nous so a subsequent load_pool() re-seeds
    # it without reverting to the auto-derived fingerprint.
    assert stored["providers"]["nous"]["label"] == "my-nous"
|
||||
|
||||
|
||||
def test_auth_add_codex_oauth_persists_pool_entry(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
|
|
|
|||
|
|
@ -456,3 +456,258 @@ class TestLoginNousSkipKeepsCurrent:
|
|||
assert "nous" in auth_after.get("providers", {})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# persist_nous_credentials: shared helper for CLI + web dashboard login paths
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _full_state_fixture() -> dict:
|
||||
"""Shape of the dict returned by _nous_device_code_login /
|
||||
refresh_nous_oauth_from_state. Used as helper input."""
|
||||
return {
|
||||
"portal_base_url": "https://portal.example.com",
|
||||
"inference_base_url": "https://inference.example.com/v1",
|
||||
"client_id": "hermes-cli",
|
||||
"scope": "inference:mint_agent_key",
|
||||
"token_type": "Bearer",
|
||||
"access_token": "access-tok",
|
||||
"refresh_token": "refresh-tok",
|
||||
"obtained_at": "2026-04-17T22:00:00+00:00",
|
||||
"expires_at": "2026-04-17T22:15:00+00:00",
|
||||
"expires_in": 900,
|
||||
"agent_key": "agent-key-value",
|
||||
"agent_key_id": "ak-id",
|
||||
"agent_key_expires_at": "2026-04-18T22:00:00+00:00",
|
||||
"agent_key_expires_in": 86400,
|
||||
"agent_key_reused": False,
|
||||
"agent_key_obtained_at": "2026-04-17T22:00:10+00:00",
|
||||
"tls": {"insecure": False, "ca_bundle": None},
|
||||
}
|
||||
|
||||
|
||||
def test_persist_nous_credentials_writes_both_pool_and_providers(tmp_path, monkeypatch):
    """The helper must fill BOTH ``credential_pool.nous`` AND ``providers.nous``.

    Regression guard: before this helper existed, `hermes auth add nous`
    wrote only the pool. Once the Nous agent_key's 24h TTL expired, the
    401-recovery path in run_agent.py called resolve_nous_runtime_credentials,
    which reads providers.nous, found it empty, raised AuthError, and the
    agent failed with "Non-retryable client error". Both stores must stay
    in sync at write time.
    """
    from hermes_cli.auth import persist_nous_credentials, NOUS_DEVICE_CODE_SOURCE

    hermes_home = tmp_path / "hermes"
    hermes_home.mkdir(parents=True, exist_ok=True)
    auth_file = hermes_home / "auth.json"
    auth_file.write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    entry = persist_nous_credentials(_full_state_fixture())

    assert entry is not None
    assert entry.provider == "nous"
    assert entry.source == NOUS_DEVICE_CODE_SOURCE

    stored = json.loads(auth_file.read_text())

    # providers.nous populated with the full state (new behaviour)
    provider_state = stored["providers"]["nous"]
    assert provider_state["access_token"] == "access-tok"
    assert provider_state["refresh_token"] == "refresh-tok"
    assert provider_state["agent_key"] == "agent-key-value"
    assert provider_state["agent_key_expires_at"] == "2026-04-18T22:00:00+00:00"

    # credential_pool.nous has exactly one canonical device_code entry
    pool_entries = stored["credential_pool"]["nous"]
    assert len(pool_entries) == 1, pool_entries
    only_entry = pool_entries[0]
    assert only_entry["source"] == NOUS_DEVICE_CODE_SOURCE
    assert only_entry["agent_key"] == "agent-key-value"
    assert only_entry["inference_base_url"] == "https://inference.example.com/v1"
|
||||
|
||||
|
||||
def test_persist_nous_credentials_allows_recovery_from_401(tmp_path, monkeypatch):
    """After persisting, ``resolve_nous_runtime_credentials`` must succeed.

    This is the path run_agent.py's `_try_refresh_nous_client_credentials`
    takes after a Nous 401 — before the fix it raised AuthError ("Hermes is
    not logged into Nous Portal") because providers.nous was empty.
    """
    from hermes_cli.auth import persist_nous_credentials, resolve_nous_runtime_credentials

    hermes_home = tmp_path / "hermes"
    hermes_home.mkdir(parents=True, exist_ok=True)
    auth_file = hermes_home / "auth.json"
    auth_file.write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    persist_nous_credentials(_full_state_fixture())

    # Stub the network-touching steps so we don't actually contact the
    # portal — the point of this test is that state lookup succeeds and
    # doesn't raise "Hermes is not logged into Nous Portal".
    def _fake_refresh_access_token(*, client, portal_base_url, client_id, refresh_token):
        refreshed = {
            "access_token": "access-new",
            "refresh_token": "refresh-new",
            "expires_in": 900,
            "token_type": "Bearer",
        }
        return refreshed

    def _fake_mint_agent_key(*, client, portal_base_url, access_token, min_ttl_seconds):
        return _mint_payload(api_key="new-agent-key")

    for target, replacement in (
        ("hermes_cli.auth._refresh_access_token", _fake_refresh_access_token),
        ("hermes_cli.auth._mint_agent_key", _fake_mint_agent_key),
    ):
        monkeypatch.setattr(target, replacement)

    resolved = resolve_nous_runtime_credentials(min_key_ttl_seconds=300, force_mint=True)
    assert resolved["api_key"] == "new-agent-key"
|
||||
|
||||
|
||||
def test_persist_nous_credentials_idempotent_no_duplicate_pool_entries(tmp_path, monkeypatch):
    """Persisting twice must upsert, not pile up duplicate device_code rows.

    Regression guard for the review comment on PR #11858: before normalisation,
    the helper wrote `manual:device_code` while `_seed_from_singletons` wrote
    `device_code`, so the pool grew a second duplicate entry on every
    ``load_pool()``. The helper now writes providers.nous and lets seeding
    materialise the pool entry under the canonical ``device_code`` source, so
    two persists still leave the pool with exactly one row.
    """
    from hermes_cli.auth import persist_nous_credentials, NOUS_DEVICE_CODE_SOURCE

    hermes_home = tmp_path / "hermes"
    hermes_home.mkdir(parents=True, exist_ok=True)
    auth_file = hermes_home / "auth.json"
    auth_file.write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    persist_nous_credentials(_full_state_fixture())

    # Second write carries different token / agent-key values.
    updated_state = _full_state_fixture()
    updated_state["access_token"] = "access-second"
    updated_state["agent_key"] = "agent-key-second"
    persist_nous_credentials(updated_state)

    stored = json.loads(auth_file.read_text())

    # providers.nous reflects the latest write (singleton semantics)
    provider_state = stored["providers"]["nous"]
    assert provider_state["access_token"] == "access-second"
    assert provider_state["agent_key"] == "agent-key-second"

    # credential_pool.nous has exactly one entry, carrying the latest agent_key
    pool_entries = stored["credential_pool"]["nous"]
    assert len(pool_entries) == 1, pool_entries
    assert pool_entries[0]["source"] == NOUS_DEVICE_CODE_SOURCE
    assert pool_entries[0]["agent_key"] == "agent-key-second"

    # And no stray `manual:device_code` / `manual:dashboard_device_code` rows
    assert all(
        not row["source"].startswith("manual:") for row in pool_entries
    )
|
||||
|
||||
|
||||
def test_persist_nous_credentials_reloads_pool_after_singleton_write(tmp_path, monkeypatch):
    """The helper's return value must be a real, seeded pool entry.

    The returned entry has to come from a fresh ``load_pool`` so callers
    observe the canonical seeded state, including any legacy entries that
    ``_seed_from_singletons`` pruned or upserted.
    """
    from hermes_cli.auth import persist_nous_credentials, NOUS_DEVICE_CODE_SOURCE

    hermes_home = tmp_path / "hermes"
    hermes_home.mkdir(parents=True, exist_ok=True)
    auth_file = hermes_home / "auth.json"
    auth_file.write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    seeded = persist_nous_credentials(_full_state_fixture())

    assert seeded is not None
    assert seeded.source == NOUS_DEVICE_CODE_SOURCE
    # Label derived by _seed_from_singletons via label_from_token; we don't
    # assert its exact value, just that the helper returned a real entry.
    assert seeded.access_token == "access-tok"
    assert seeded.agent_key == "agent-key-value"
|
||||
|
||||
|
||||
def test_persist_nous_credentials_embeds_custom_label(tmp_path, monkeypatch):
    """A user-supplied ``--label`` round-trips through providers.nous and the pool.

    Previously `hermes auth add nous --type oauth --label <name>` silently
    dropped the label because persist_nous_credentials() ignored it and
    _seed_from_singletons always auto-derived via label_from_token(). The
    fix stashes the label inside providers.nous so seeding prefers it.
    """
    from hermes_cli.auth import persist_nous_credentials, NOUS_DEVICE_CODE_SOURCE

    hermes_home = tmp_path / "hermes"
    hermes_home.mkdir(parents=True, exist_ok=True)
    auth_file = hermes_home / "auth.json"
    auth_file.write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    labelled = persist_nous_credentials(_full_state_fixture(), label="my-personal")

    assert labelled is not None
    assert labelled.source == NOUS_DEVICE_CODE_SOURCE
    assert labelled.label == "my-personal"

    # providers.nous carries the label so re-seeding on the next load_pool
    # doesn't overwrite it with the auto-derived fingerprint.
    stored = json.loads(auth_file.read_text())
    assert stored["providers"]["nous"]["label"] == "my-personal"
|
||||
|
||||
|
||||
def test_persist_nous_credentials_custom_label_survives_reseed(tmp_path, monkeypatch):
    """Reopening the pool must keep the user-chosen label.

    ``load_pool`` re-runs _seed_from_singletons; that pass must not clobber
    the label with label_from_token output.
    """
    from hermes_cli.auth import persist_nous_credentials
    from agent.credential_pool import load_pool

    hermes_home = tmp_path / "hermes"
    hermes_home.mkdir(parents=True, exist_ok=True)
    auth_file = hermes_home / "auth.json"
    auth_file.write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    persist_nous_credentials(_full_state_fixture(), label="work-acct")

    # Second load_pool triggers _seed_from_singletons again. Without the
    # fix, this call overwrote the label with label_from_token(access_token).
    reloaded_entries = load_pool("nous").entries()

    assert len(reloaded_entries) == 1
    assert reloaded_entries[0].label == "work-acct"
|
||||
|
||||
|
||||
def test_persist_nous_credentials_no_label_uses_auto_derived(tmp_path, monkeypatch):
    """Without a ``label`` argument the auto-derived fingerprint is used.

    Regression guard for the unchanged default behaviour.
    """
    from hermes_cli.auth import persist_nous_credentials

    hermes_home = tmp_path / "hermes"
    hermes_home.mkdir(parents=True, exist_ok=True)
    auth_file = hermes_home / "auth.json"
    auth_file.write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    unlabelled = persist_nous_credentials(_full_state_fixture())
    assert unlabelled is not None

    # label_from_token derives from the access_token; exact value depends on
    # the fingerprinter but it must not be empty and must not equal an
    # arbitrary user string we never passed.
    assert unlabelled.label
    assert unlabelled.label != "my-personal"

    # No "label" key embedded in providers.nous when the caller didn't supply one.
    stored = json.loads(auth_file.read_text())
    assert "label" not in stored["providers"]["nous"]
|
||||
|
|
|
|||
294
tests/hermes_cli/test_aux_config.py
Normal file
294
tests/hermes_cli/test_aux_config.py
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
"""Tests for the auxiliary-model configuration UI in ``hermes model``.
|
||||
|
||||
Covers the helper functions:
|
||||
- ``_save_aux_choice`` writes to config.yaml without touching main model config
|
||||
- ``_reset_aux_to_auto`` clears routing fields but preserves timeouts
|
||||
- ``_format_aux_current`` renders current task config for the menu
|
||||
- ``_AUX_TASKS`` stays in sync with ``DEFAULT_CONFIG["auxiliary"]``
|
||||
|
||||
These are pure-function tests — the interactive menu loops are not covered
|
||||
here (they're stdin-driven curses prompts).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.config import DEFAULT_CONFIG, load_config
|
||||
from hermes_cli.main import (
|
||||
_AUX_TASKS,
|
||||
_format_aux_current,
|
||||
_reset_aux_to_auto,
|
||||
_save_aux_choice,
|
||||
)
|
||||
|
||||
|
||||
# ── Default config ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_title_generation_present_in_default_config():
    """The `title_generation` task must be defined in DEFAULT_CONFIG.

    Regression for an existing gap: title_generator.py calls
    ``call_llm(task="title_generation", ...)`` but the task was missing
    from DEFAULT_CONFIG["auxiliary"], so the config-backed timeout/provider
    overrides never worked for that task.
    """
    aux_defaults = DEFAULT_CONFIG["auxiliary"]
    assert "title_generation" in aux_defaults

    title_cfg = aux_defaults["title_generation"]
    assert title_cfg["provider"] == "auto"
    assert title_cfg["model"] == ""
    assert title_cfg["timeout"] > 0
|
||||
|
||||
|
||||
def test_aux_tasks_keys_all_exist_in_default_config():
    """Every task offered by the menu must have a DEFAULT_CONFIG entry."""
    # Each _AUX_TASKS row is unpacked as (key, name, description); only the key matters.
    menu_keys = {task_key for task_key, _name, _desc in _AUX_TASKS}
    configured_keys = set(DEFAULT_CONFIG["auxiliary"])

    missing = menu_keys.difference(configured_keys)
    assert not missing, (
        f"_AUX_TASKS references tasks not in DEFAULT_CONFIG.auxiliary: {missing}"
    )
|
||||
|
||||
|
||||
# ── _format_aux_current ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "task_cfg,expected",
    [
        # Unset / empty / explicit "auto" provider.
        ({}, "auto"),
        ({"provider": "", "model": ""}, "auto"),
        ({"provider": "auto", "model": ""}, "auto"),
        ({"provider": "auto", "model": "gpt-4o"}, "auto · gpt-4o"),
        # Named providers, with and without a model.
        ({"provider": "openrouter", "model": ""}, "openrouter"),
        (
            {"provider": "openrouter", "model": "google/gemini-2.5-flash"},
            "openrouter · google/gemini-2.5-flash",
        ),
        ({"provider": "nous", "model": "gemini-3-flash"}, "nous · gemini-3-flash"),
        # Custom endpoints: scheme and trailing slash are absent from the output.
        (
            {"provider": "custom", "base_url": "http://localhost:11434/v1", "model": ""},
            "custom (localhost:11434/v1)",
        ),
        (
            {
                "provider": "custom",
                "base_url": "http://localhost:11434/v1/",
                "model": "qwen2.5:32b",
            },
            "custom (localhost:11434/v1) · qwen2.5:32b",
        ),
    ],
)
def test_format_aux_current(task_cfg, expected):
    """``_format_aux_current`` renders each task config as the expected menu text."""
    rendered = _format_aux_current(task_cfg)
    assert rendered == expected
|
||||
|
||||
|
||||
def test_format_aux_current_handles_non_dict():
    """Non-dict task configs are rendered as ``"auto"``."""
    for bogus_cfg in (None, "string"):
        assert _format_aux_current(bogus_cfg) == "auto"
|
||||
|
||||
|
||||
# ── _save_aux_choice ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_save_aux_choice_persists_to_config_yaml(tmp_path, monkeypatch):
    """Saving a task writes provider/model/base_url/api_key to auxiliary.<task>."""
    from pathlib import Path

    # Point both HERMES_HOME and Path.home() at the temp directory.
    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    _save_aux_choice(
        "vision",
        provider="openrouter",
        model="google/gemini-2.5-flash",
    )

    vision_cfg = load_config()["auxiliary"]["vision"]
    assert vision_cfg["provider"] == "openrouter"
    assert vision_cfg["model"] == "google/gemini-2.5-flash"
    assert vision_cfg["base_url"] == ""
    assert vision_cfg["api_key"] == ""
|
||||
|
||||
|
||||
def test_save_aux_choice_preserves_timeout(tmp_path, monkeypatch):
    """Saving must NOT clobber user-tuned timeout values."""
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    # Default vision timeout is 120
    timeout_before = load_config()["auxiliary"]["vision"]["timeout"]
    assert timeout_before == 120

    _save_aux_choice("vision", provider="nous", model="gemini-3-flash")

    vision_after = load_config()["auxiliary"]["vision"]
    assert vision_after["timeout"] == timeout_before
    # download_timeout also preserved for vision
    assert vision_after.get("download_timeout") == 30
|
||||
|
||||
|
||||
def test_save_aux_choice_does_not_touch_main_model(tmp_path, monkeypatch):
    """Aux config must never mutate model.default / model.provider / model.base_url."""
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    # Simulate a configured main model
    from hermes_cli.config import save_config

    seeded = load_config()
    seeded["model"] = {
        "default": "claude-sonnet-4.6",
        "provider": "anthropic",
        "base_url": "",
    }
    save_config(seeded)

    _save_aux_choice(
        "compression",
        provider="custom",
        base_url="http://localhost:11434/v1",
        model="qwen2.5:32b",
    )

    reloaded = load_config()

    # Main model untouched
    main_model = reloaded["model"]
    assert main_model["default"] == "claude-sonnet-4.6"
    assert main_model["provider"] == "anthropic"

    # Aux saved correctly
    compression_cfg = reloaded["auxiliary"]["compression"]
    assert compression_cfg["provider"] == "custom"
    assert compression_cfg["model"] == "qwen2.5:32b"
    assert compression_cfg["base_url"] == "http://localhost:11434/v1"
|
||||
|
||||
|
||||
def test_save_aux_choice_creates_missing_task_entry(tmp_path, monkeypatch):
    """Saving a task that was wiped from config.yaml should recreate it."""
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    # Remove vision from config entirely
    from hermes_cli.config import save_config

    stripped = load_config()
    aux_section = stripped.setdefault("auxiliary", {})
    aux_section.pop("vision", None)
    save_config(stripped)

    _save_aux_choice("vision", provider="nous", model="gemini-3-flash")

    vision_cfg = load_config()["auxiliary"]["vision"]
    assert vision_cfg["provider"] == "nous"
    assert vision_cfg["model"] == "gemini-3-flash"
|
||||
|
||||
|
||||
# ── _reset_aux_to_auto ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_reset_aux_to_auto_clears_routing_preserves_timeouts(tmp_path, monkeypatch):
    """Reset puts routing fields back to auto/empty and leaves timeouts alone."""
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    # Configure two tasks non-auto, and bump a timeout
    _save_aux_choice("vision", provider="openrouter", model="gpt-4o")
    _save_aux_choice("compression", provider="nous", model="gemini-3-flash")
    from hermes_cli.config import save_config

    tuned = load_config()
    tuned["auxiliary"]["vision"]["timeout"] = 300  # user-tuned
    save_config(tuned)

    changed_count = _reset_aux_to_auto()
    assert changed_count == 2  # both changed

    aux_after = load_config()["auxiliary"]
    for task_name in ("vision", "compression"):
        task_cfg = aux_after[task_name]
        assert task_cfg["provider"] == "auto"
        assert task_cfg["model"] == ""
        assert task_cfg["base_url"] == ""
        assert task_cfg["api_key"] == ""

    # User-tuned timeout survives reset
    assert aux_after["vision"]["timeout"] == 300
    # Default compression timeout preserved
    assert aux_after["compression"]["timeout"] == 120
|
||||
|
||||
|
||||
def test_reset_aux_to_auto_idempotent(tmp_path, monkeypatch):
    """Second reset on already-auto config returns 0 without errors."""
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    # Nothing to reset on a fresh config.
    assert _reset_aux_to_auto() == 0

    # One configured task -> exactly one change, then nothing left to do.
    _save_aux_choice("vision", provider="nous", model="gemini-3-flash")
    assert _reset_aux_to_auto() == 1
    assert _reset_aux_to_auto() == 0
|
||||
|
||||
|
||||
# ── Menu dispatch ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_select_provider_and_model_dispatches_to_aux_menu(tmp_path, monkeypatch):
    """Picking 'Configure auxiliary models...' in the provider list calls _aux_config_menu."""
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    from hermes_cli import main as main_mod

    called = {"aux": 0, "flow": 0}

    def fake_prompt(choices, *, default=0):
        # Find the aux-config entry by its label text and return its index
        position = next(
            (
                idx
                for idx, label in enumerate(choices)
                if "Configure auxiliary models" in label
            ),
            None,
        )
        if position is None:
            raise AssertionError("aux entry not in provider list")
        return position

    def _count_aux_menu():
        called["aux"] += 1

    def _count_main_flow(*a, **kw):
        called["flow"] += 1

    monkeypatch.setattr(main_mod, "_prompt_provider_choice", fake_prompt)
    monkeypatch.setattr(main_mod, "_aux_config_menu", _count_aux_menu)
    # Guard against any main flow accidentally running
    monkeypatch.setattr(main_mod, "_model_flow_openrouter", _count_main_flow)

    main_mod.select_provider_and_model()

    assert called["aux"] == 1, "aux menu not invoked"
    assert called["flow"] == 0, "main provider flow should not run"
|
||||
|
||||
|
||||
def test_leave_unchanged_replaces_cancel_label(tmp_path, monkeypatch):
    """The bottom cancel entry now reads 'Leave unchanged' (UX polish)."""
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    from hermes_cli import main as main_mod

    # One snapshot of the menu labels per time the prompt is shown.
    captured: list[list[str]] = []

    def fake_prompt(choices, *, default=0):
        shown = list(choices)
        captured.append(shown)
        # Pick 'Leave unchanged' to exit cleanly
        if "Leave unchanged" not in shown:
            raise AssertionError("Leave unchanged not in provider list")
        return shown.index("Leave unchanged")

    monkeypatch.setattr(main_mod, "_prompt_provider_choice", fake_prompt)

    main_mod.select_provider_and_model()

    assert captured, "provider menu never rendered"
    labels = captured[0]
    assert "Leave unchanged" in labels
    assert "Cancel" not in labels, "Cancel label should be replaced"
    has_aux_entry = any("Configure auxiliary models" in label for label in labels)
    assert has_aux_entry
|
||||
|
|
@ -106,6 +106,49 @@ class TestCmdUpdateBranchFallback:
|
|||
pull_cmds = [c for c in commands if "pull" in c]
|
||||
assert len(pull_cmds) == 0
|
||||
|
||||
@patch("shutil.which")
@patch("subprocess.run")
def test_update_refreshes_repo_and_tui_node_dependencies(
    self, mock_run, mock_which, mock_args
):
    """``cmd_update`` runs a quiet ``npm install`` in the repo root, then in ``ui-tui``."""
    tool_paths = {"uv": "/usr/bin/uv", "npm": "/usr/bin/npm"}
    mock_which.side_effect = tool_paths.get
    mock_run.side_effect = _make_run_side_effect(
        branch="main", verify_ok=True, commit_count="1"
    )

    cmd_update(mock_args)

    # Keep only the subprocess calls whose argv starts with the npm binary,
    # paired with the working directory they were given.
    npm_calls = []
    for call in mock_run.call_args_list:
        if call.args and call.args[0][0] == "/usr/bin/npm":
            npm_calls.append((call.args[0], call.kwargs.get("cwd")))

    quiet_install = [
        "/usr/bin/npm",
        "install",
        "--silent",
        "--no-fund",
        "--no-audit",
        "--progress=false",
    ]
    assert npm_calls == [
        (quiet_install, PROJECT_ROOT),
        (quiet_install, PROJECT_ROOT / "ui-tui"),
    ]
|
||||
|
||||
def test_update_non_interactive_skips_migration_prompt(self, mock_args, capsys):
|
||||
"""When stdin/stdout aren't TTYs, config migration prompt is skipped."""
|
||||
with patch("shutil.which", return_value=None), patch(
|
||||
|
|
|
|||
|
|
@ -93,6 +93,8 @@ class TestResolveCommand:
|
|||
def test_canonical_name_resolves(self):
    """Each canonical command name resolves to a command carrying that same name."""
    for canonical in ("help", "background", "copy", "agents"):
        assert resolve_command(canonical).name == canonical
|
||||
|
||||
def test_alias_resolves_to_canonical(self):
|
||||
assert resolve_command("bg").name == "background"
|
||||
|
|
@ -102,6 +104,7 @@ class TestResolveCommand:
|
|||
assert resolve_command("gateway").name == "platforms"
|
||||
assert resolve_command("set-home").name == "sethome"
|
||||
assert resolve_command("reload_mcp").name == "reload-mcp"
|
||||
assert resolve_command("tasks").name == "agents"
|
||||
|
||||
def test_leading_slash_stripped(self):
|
||||
assert resolve_command("/help").name == "help"
|
||||
|
|
|
|||
169
tests/hermes_cli/test_config_env_refs.py
Normal file
169
tests/hermes_cli/test_config_env_refs.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
import textwrap
|
||||
|
||||
from hermes_cli.config import load_config, save_config
|
||||
|
||||
|
||||
def _write_config(tmp_path, body: str):
|
||||
(tmp_path / "config.yaml").write_text(textwrap.dedent(body), encoding="utf-8")
|
||||
|
||||
|
||||
def _read_config(tmp_path) -> str:
|
||||
return (tmp_path / "config.yaml").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_save_config_preserves_env_refs_on_unrelated_change(monkeypatch, tmp_path):
    """Saving after an unrelated edit keeps ``${VAR}`` templates, not their values.

    Both a whole-value reference (``api_key``) and a reference embedded in a
    longer string (``Authorization``) must be written back as templates, and
    neither environment value may appear in the saved file.
    """
    environment = (
        ("HERMES_HOME", str(tmp_path)),
        ("TU_ZI_API_KEY", "sk-realsecret"),
        ("ALT_SECRET", "alt-secret"),
    )
    for variable, value in environment:
        monkeypatch.setenv(variable, value)

    yaml_body = """\
    custom_providers:
      - name: tuzi
        base_url: https://api.tu-zi.com
        api_key: ${TU_ZI_API_KEY}
        headers:
          Authorization: Bearer ${ALT_SECRET}
        model: claude-opus-4-6
    model:
      default: claude-opus-4-6
    """
    _write_config(tmp_path, yaml_body)

    # Change something that has nothing to do with the secrets.
    cfg = load_config()
    cfg["model"]["default"] = "doubao-pro"
    save_config(cfg)

    on_disk = _read_config(tmp_path)
    for template in ("api_key: ${TU_ZI_API_KEY}", "Authorization: Bearer ${ALT_SECRET}"):
        assert template in on_disk
    for secret in ("sk-realsecret", "alt-secret"):
        assert secret not in on_disk
|
||||
|
||||
|
||||
def test_save_config_preserves_unresolved_env_refs(monkeypatch, tmp_path):
    """A ``${VAR}`` whose variable is unset is still written back verbatim."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # Make sure the referenced variable really is absent from the environment.
    monkeypatch.delenv("MISSING_SECRET", raising=False)

    yaml_body = """\
    custom_providers:
      - name: unresolved
        api_key: ${MISSING_SECRET}
        model: claude-opus-4-6
    model:
      default: claude-opus-4-6
    """
    _write_config(tmp_path, yaml_body)

    cfg = load_config()
    cfg["display"]["compact"] = True
    save_config(cfg)

    on_disk = _read_config(tmp_path)
    assert "api_key: ${MISSING_SECRET}" in on_disk
|
||||
|
||||
|
||||
def test_save_config_allows_intentional_secret_value_change(monkeypatch, tmp_path):
    """Explicitly overwriting a templated value saves the new literal, not the template."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    monkeypatch.setenv("TU_ZI_API_KEY", "sk-old-secret")

    yaml_body = """\
    custom_providers:
      - name: tuzi
        api_key: ${TU_ZI_API_KEY}
        model: claude-opus-4-6
    model:
      default: claude-opus-4-6
    """
    _write_config(tmp_path, yaml_body)

    cfg = load_config()
    first_provider = cfg["custom_providers"][0]
    first_provider["api_key"] = "sk-new-secret"
    save_config(cfg)

    on_disk = _read_config(tmp_path)
    assert "api_key: sk-new-secret" in on_disk
    assert "${TU_ZI_API_KEY}" not in on_disk
|
||||
|
||||
|
||||
def test_save_config_preserves_template_when_env_rotates_after_load(monkeypatch, tmp_path):
    """The template survives even if the variable's value changes between load and save."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    monkeypatch.setenv("TU_ZI_API_KEY", "sk-old-secret")

    yaml_body = """\
    custom_providers:
      - name: tuzi
        api_key: ${TU_ZI_API_KEY}
        model: claude-opus-4-6
    model:
      default: claude-opus-4-6
    """
    _write_config(tmp_path, yaml_body)

    cfg = load_config()
    # Rotate the secret only after the config has been loaded.
    monkeypatch.setenv("TU_ZI_API_KEY", "sk-rotated-secret")
    cfg["model"]["default"] = "doubao-pro"
    save_config(cfg)

    on_disk = _read_config(tmp_path)
    assert "api_key: ${TU_ZI_API_KEY}" in on_disk
    for leaked in ("sk-old-secret", "sk-rotated-secret"):
        assert leaked not in on_disk
|
||||
|
||||
|
||||
def test_save_config_keeps_edited_partial_template_strings_literal(monkeypatch, tmp_path):
    """An edited header is saved as typed, even though it contains the env value.

    The new ``Authorization`` string embeds the same text that
    ``${ALT_SECRET}`` resolves to; the saved file must hold the edited
    literal and must not revert to the original template.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    monkeypatch.setenv("ALT_SECRET", "alt-secret")

    yaml_body = """\
    custom_providers:
      - name: tuzi
        headers:
          Authorization: Bearer ${ALT_SECRET}
        model: claude-opus-4-6
    model:
      default: claude-opus-4-6
    """
    _write_config(tmp_path, yaml_body)

    cfg = load_config()
    headers = cfg["custom_providers"][0]["headers"]
    headers["Authorization"] = "Token alt-secret"
    save_config(cfg)

    on_disk = _read_config(tmp_path)
    assert "Authorization: Token alt-secret" in on_disk
    assert "Authorization: Bearer ${ALT_SECRET}" not in on_disk
|
||||
|
||||
|
||||
def test_save_config_falls_back_to_positional_matching_for_duplicate_names(monkeypatch, tmp_path):
    """Two providers sharing one name each keep their own ``${VAR}`` template on save."""
    environment = (
        ("HERMES_HOME", str(tmp_path)),
        ("FIRST_SECRET", "first-secret"),
        ("SECOND_SECRET", "second-secret"),
    )
    for variable, value in environment:
        monkeypatch.setenv(variable, value)

    yaml_body = """\
    custom_providers:
      - name: duplicate
        api_key: ${FIRST_SECRET}
        model: claude-opus-4-6
      - name: duplicate
        api_key: ${SECOND_SECRET}
        model: doubao-pro
    model:
      default: claude-opus-4-6
    """
    _write_config(tmp_path, yaml_body)

    cfg = load_config()
    cfg["display"]["compact"] = True
    save_config(cfg)

    on_disk = _read_config(tmp_path)
    assert on_disk.count("name: duplicate") == 2
    for template in ("api_key: ${FIRST_SECRET}", "api_key: ${SECOND_SECRET}"):
        assert template in on_disk
    for secret in ("first-secret", "second-secret"):
        assert secret not in on_disk
|
||||
|
|
@ -501,40 +501,272 @@ class TestDeletePaste:
|
|||
|
||||
|
||||
class TestScheduleAutoDelete:
|
||||
def test_spawns_detached_process(self):
|
||||
"""``_schedule_auto_delete`` used to spawn a detached Python subprocess
|
||||
per call (one per paste URL batch). Those subprocesses slept 6 hours
|
||||
and accumulated forever under repeated use — 15+ orphaned interpreters
|
||||
were observed in production.
|
||||
|
||||
The new implementation is stateless: it records pending deletions to
|
||||
``~/.hermes/pastes/pending.json`` and lets ``_sweep_expired_pastes``
|
||||
handle the DELETE requests synchronously on the next ``hermes debug``
|
||||
invocation.
|
||||
"""
|
||||
|
||||
def test_does_not_spawn_subprocess(self, hermes_home):
|
||||
"""Regression guard: _schedule_auto_delete must NEVER spawn subprocesses.
|
||||
|
||||
We assert this structurally rather than by mocking Popen: the new
|
||||
implementation doesn't even import ``subprocess`` at module scope,
|
||||
so a mock patch wouldn't find it.
|
||||
"""
|
||||
import ast
|
||||
import inspect
|
||||
from hermes_cli.debug import _schedule_auto_delete
|
||||
|
||||
with patch("subprocess.Popen") as mock_popen:
|
||||
_schedule_auto_delete(
|
||||
["https://paste.rs/abc", "https://paste.rs/def"],
|
||||
delay_seconds=10,
|
||||
)
|
||||
# Strip the docstring before scanning so the regression-rationale
|
||||
# prose inside it doesn't trigger our banned-word checks.
|
||||
source = inspect.getsource(_schedule_auto_delete)
|
||||
tree = ast.parse(source)
|
||||
func_node = tree.body[0]
|
||||
if (
|
||||
func_node.body
|
||||
and isinstance(func_node.body[0], ast.Expr)
|
||||
and isinstance(func_node.body[0].value, ast.Constant)
|
||||
and isinstance(func_node.body[0].value.value, str)
|
||||
):
|
||||
func_node.body = func_node.body[1:]
|
||||
code_only = ast.unparse(func_node)
|
||||
|
||||
mock_popen.assert_called_once()
|
||||
call_args = mock_popen.call_args
|
||||
# Verify detached
|
||||
assert call_args[1]["start_new_session"] is True
|
||||
# Verify the script references both URLs
|
||||
script = call_args[0][0][2] # [python, -c, script]
|
||||
assert "paste.rs/abc" in script
|
||||
assert "paste.rs/def" in script
|
||||
assert "time.sleep(10)" in script
|
||||
assert "Popen" not in code_only, (
|
||||
"_schedule_auto_delete must not spawn subprocesses — "
|
||||
"use pending.json + _sweep_expired_pastes instead"
|
||||
)
|
||||
assert "subprocess" not in code_only, (
|
||||
"_schedule_auto_delete must not reference subprocess at all"
|
||||
)
|
||||
assert "time.sleep" not in code_only, (
|
||||
"Regression: sleeping in _schedule_auto_delete is the bug being fixed"
|
||||
)
|
||||
|
||||
def test_skips_non_paste_rs_urls(self):
|
||||
from hermes_cli.debug import _schedule_auto_delete
|
||||
# And verify that calling it doesn't produce any orphaned children
|
||||
# (it should just write pending.json synchronously).
|
||||
import os as _os
|
||||
before = set(_os.listdir("/proc")) if _os.path.exists("/proc") else None
|
||||
_schedule_auto_delete(
|
||||
["https://paste.rs/abc", "https://paste.rs/def"],
|
||||
delay_seconds=10,
|
||||
)
|
||||
if before is not None:
|
||||
after = set(_os.listdir("/proc"))
|
||||
new = after - before
|
||||
# Filter to only integer-named entries (process PIDs)
|
||||
new_pids = [p for p in new if p.isdigit()]
|
||||
# It's fine if unrelated processes appeared — we just need to make
|
||||
# sure we didn't spawn a long-sleeping one. The old bug spawned
|
||||
# a python interpreter whose cmdline contained "time.sleep".
|
||||
for pid in new_pids:
|
||||
try:
|
||||
with open(f"/proc/{pid}/cmdline", "rb") as f:
|
||||
cmdline = f.read().decode("utf-8", errors="replace")
|
||||
assert "time.sleep" not in cmdline, (
|
||||
f"Leaked sleeper subprocess PID {pid}: {cmdline}"
|
||||
)
|
||||
except OSError:
|
||||
pass # process exited already
|
||||
|
||||
with patch("subprocess.Popen") as mock_popen:
|
||||
_schedule_auto_delete(["https://dpaste.com/something"])
|
||||
def test_records_pending_to_json(self, hermes_home):
|
||||
"""Scheduled URLs are persisted to pending.json with expiration."""
|
||||
from hermes_cli.debug import _schedule_auto_delete, _pending_file
|
||||
import json
|
||||
|
||||
mock_popen.assert_not_called()
|
||||
_schedule_auto_delete(
|
||||
["https://paste.rs/abc", "https://paste.rs/def"],
|
||||
delay_seconds=10,
|
||||
)
|
||||
|
||||
def test_handles_popen_failure_gracefully(self):
|
||||
from hermes_cli.debug import _schedule_auto_delete
|
||||
pending_path = _pending_file()
|
||||
assert pending_path.exists()
|
||||
|
||||
with patch("subprocess.Popen",
|
||||
side_effect=OSError("no such file")):
|
||||
# Should not raise
|
||||
_schedule_auto_delete(["https://paste.rs/abc"])
|
||||
entries = json.loads(pending_path.read_text())
|
||||
assert len(entries) == 2
|
||||
urls = {e["url"] for e in entries}
|
||||
assert urls == {"https://paste.rs/abc", "https://paste.rs/def"}
|
||||
|
||||
# expire_at is ~now + delay_seconds
|
||||
import time
|
||||
for e in entries:
|
||||
assert e["expire_at"] > time.time()
|
||||
assert e["expire_at"] <= time.time() + 15
|
||||
|
||||
def test_skips_non_paste_rs_urls(self, hermes_home):
    """dpaste.com URLs auto-expire — don't track them."""
    from hermes_cli.debug import _schedule_auto_delete, _pending_file

    foreign_urls = ["https://dpaste.com/something"]
    _schedule_auto_delete(foreign_urls)

    # pending.json should not be created for non-paste.rs URLs
    pending_path = _pending_file()
    assert not pending_path.exists()
|
||||
|
||||
def test_merges_with_existing_pending(self, hermes_home):
    """Subsequent calls merge into existing pending.json."""
    from hermes_cli.debug import _schedule_auto_delete, _load_pending

    for paste_url in ("https://paste.rs/first", "https://paste.rs/second"):
        _schedule_auto_delete([paste_url], delay_seconds=10)

    recorded = {entry["url"] for entry in _load_pending()}
    assert recorded == {"https://paste.rs/first", "https://paste.rs/second"}
|
||||
|
||||
def test_dedupes_same_url(self, hermes_home):
    """Same URL recorded twice → one entry with the later expire_at."""
    from hermes_cli.debug import _schedule_auto_delete, _load_pending

    duplicate = "https://paste.rs/dup"
    for delay in (10, 100):
        _schedule_auto_delete([duplicate], delay_seconds=delay)

    # NOTE(review): only the entry count and URL are checked here; the
    # "later expire_at" half of the contract is not asserted.
    entries = _load_pending()
    assert len(entries) == 1
    assert entries[0]["url"] == "https://paste.rs/dup"
|
||||
|
||||
|
||||
class TestSweepExpiredPastes:
    """Test the opportunistic sweep that replaces the sleeping subprocess."""

    def test_sweep_empty_is_noop(self, hermes_home):
        """With nothing seeded, the sweep reports zero deleted and zero remaining."""
        from hermes_cli.debug import _sweep_expired_pastes

        deleted, remaining = _sweep_expired_pastes()

        assert deleted == 0
        assert remaining == 0

    def test_sweep_deletes_expired_entries(self, hermes_home):
        """Only the expired entry is deleted; the future one stays pending."""
        import time
        from hermes_cli.debug import (
            _load_pending,
            _save_pending,
            _sweep_expired_pastes,
        )

        # Seed pending.json with one expired + one future entry
        now = time.time()
        _save_pending([
            {"url": "https://paste.rs/expired", "expire_at": now - 100},
            {"url": "https://paste.rs/future", "expire_at": now + 3600},
        ])

        attempted = []

        def record_delete(url):
            # Successful DELETE stand-in that remembers what it was asked for.
            attempted.append(url)
            return True

        with patch("hermes_cli.debug.delete_paste", side_effect=record_delete):
            deleted, remaining = _sweep_expired_pastes()

        assert attempted == ["https://paste.rs/expired"]
        assert deleted == 1
        assert remaining == 1

        still_pending = {entry["url"] for entry in _load_pending()}
        assert still_pending == {"https://paste.rs/future"}

    def test_sweep_leaves_future_entries_alone(self, hermes_home):
        """Entries that have not expired trigger no DELETE and are all kept."""
        import time
        from hermes_cli.debug import _save_pending, _sweep_expired_pastes

        now = time.time()
        _save_pending([
            {"url": "https://paste.rs/future1", "expire_at": now + 3600},
            {"url": "https://paste.rs/future2", "expire_at": now + 7200},
        ])

        with patch("hermes_cli.debug.delete_paste") as mock_delete:
            deleted, remaining = _sweep_expired_pastes()

        mock_delete.assert_not_called()
        assert deleted == 0
        assert remaining == 2

    def test_sweep_survives_network_failure(self, hermes_home):
        """Failed DELETEs stay in pending.json until the 24h grace window."""
        import time
        from hermes_cli.debug import (
            _load_pending,
            _save_pending,
            _sweep_expired_pastes,
        )

        recently_expired = time.time() - 100
        _save_pending([
            {"url": "https://paste.rs/flaky", "expire_at": recently_expired},
        ])

        failing_delete = patch(
            "hermes_cli.debug.delete_paste",
            side_effect=Exception("network down"),
        )
        with failing_delete:
            deleted, remaining = _sweep_expired_pastes()

        # Failure within 24h grace → kept for retry
        assert deleted == 0
        assert remaining == 1
        assert len(_load_pending()) == 1

    def test_sweep_drops_entries_past_grace_window(self, hermes_home):
        """After 24h past expiration, give up even on network failures."""
        import time
        from hermes_cli.debug import (
            _load_pending,
            _save_pending,
            _sweep_expired_pastes,
        )

        # Expired 25 hours ago → past the 24h grace window
        very_old = time.time() - (25 * 3600)
        _save_pending([
            {"url": "https://paste.rs/ancient", "expire_at": very_old},
        ])

        failing_delete = patch(
            "hermes_cli.debug.delete_paste",
            side_effect=Exception("network down"),
        )
        with failing_delete:
            deleted, remaining = _sweep_expired_pastes()

        assert deleted == 1
        assert remaining == 0
        assert _load_pending() == []
|
||||
|
||||
|
||||
class TestRunDebugSweepsOnInvocation:
    """``run_debug`` must sweep expired pastes on every invocation."""

    def test_run_debug_calls_sweep(self, hermes_home):
        """Running with no subcommand still calls the sweep exactly once."""
        from hermes_cli.debug import run_debug

        # debug_command=None is the default → prints help
        args = MagicMock(debug_command=None)

        with patch("hermes_cli.debug._sweep_expired_pastes") as mock_sweep:
            run_debug(args)

        mock_sweep.assert_called_once()

    def test_run_debug_survives_sweep_failure(self, hermes_home, capsys):
        """If the sweep throws, the subcommand still runs."""
        from hermes_cli.debug import run_debug

        args = MagicMock(debug_command=None)

        exploding_sweep = patch(
            "hermes_cli.debug._sweep_expired_pastes",
            side_effect=RuntimeError("boom"),
        )
        with exploding_sweep:
            run_debug(args)  # must not raise

        # Default subcommand still printed help
        printed = capsys.readouterr().out
        assert "Usage: hermes debug" in printed
|
||||
|
||||
|
||||
class TestRunDebugDelete:
|
||||
|
|
|
|||
|
|
@ -39,6 +39,76 @@ class TestSystemdLingerStatus:
|
|||
assert gateway.get_systemd_linger_status() == (None, "not supported in Termux")
|
||||
|
||||
|
||||
class TestContainerSystemdSupport:
    """``supports_systemd_services`` inside a container depends on an operational manager.

    Each test patches the platform probes to describe a Linux container with
    ``systemctl`` on PATH and varies only what ``_systemd_operational``
    reports.
    """

    @staticmethod
    def _pretend_linux_container(monkeypatch, operational):
        """Patch the platform probes; *operational* replaces ``_systemd_operational``."""

        def always(answer):
            return lambda: answer

        probes = (
            ("is_linux", True),
            ("is_termux", False),
            ("is_wsl", False),
            ("is_container", True),
        )
        for probe_name, answer in probes:
            monkeypatch.setattr(gateway, probe_name, always(answer))
        monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/systemctl")
        monkeypatch.setattr(gateway, "_systemd_operational", operational)

    def test_supports_systemd_services_in_container_with_user_manager(self, monkeypatch):
        # Only the user-scope manager (system=False) reports operational.
        self._pretend_linux_container(monkeypatch, lambda system=False: not system)

        assert gateway.supports_systemd_services() is True

    def test_supports_systemd_services_in_container_with_system_manager(self, monkeypatch):
        # Only the system-scope manager (system=True) reports operational.
        self._pretend_linux_container(monkeypatch, lambda system=False: system)

        assert gateway.supports_systemd_services() is True

    def test_supports_systemd_services_in_container_without_systemd(self, monkeypatch):
        # Neither scope reports operational.
        self._pretend_linux_container(monkeypatch, lambda system=False: False)

        assert gateway.supports_systemd_services() is False
|
||||
|
||||
|
||||
def test_gateway_install_in_container_with_operational_systemd_uses_systemd(monkeypatch):
    """``gateway install`` calls ``systemd_install`` once when systemd services are supported."""
    probes = (
        ("supports_systemd_services", True),
        ("is_wsl", False),
        ("is_macos", False),
        ("is_managed", False),
    )
    for probe_name, answer in probes:
        monkeypatch.setattr(gateway, probe_name, (lambda value: lambda: value)(answer))

    calls = []

    def fake_systemd_install(force=False, system=False, run_as_user=None):
        # Record the arguments instead of touching systemd.
        calls.append((force, system, run_as_user))

    monkeypatch.setattr(gateway, "systemd_install", fake_systemd_install)

    args = SimpleNamespace(
        gateway_command="install",
        force=False,
        system=False,
        run_as_user=None,
    )
    gateway.gateway_command(args)

    assert calls == [(False, False, None)]
|
||||
|
||||
|
||||
def test_gateway_start_in_container_with_operational_systemd_uses_systemd(monkeypatch):
    """``gateway start`` calls ``systemd_start`` once when systemd services are supported."""
    probes = (
        ("supports_systemd_services", True),
        ("is_wsl", False),
        ("is_macos", False),
    )
    for probe_name, answer in probes:
        monkeypatch.setattr(gateway, probe_name, (lambda value: lambda: value)(answer))

    calls = []

    def fake_systemd_start(system=False):
        # Record the requested scope instead of touching systemd.
        calls.append(system)

    monkeypatch.setattr(gateway, "systemd_start", fake_systemd_start)

    args = SimpleNamespace(gateway_command="start", system=False, all=False)
    gateway.gateway_command(args)

    assert calls == [False]
|
||||
|
||||
|
||||
def test_systemd_status_warns_when_linger_disabled(monkeypatch, tmp_path, capsys):
|
||||
unit_path = tmp_path / "hermes-gateway.service"
|
||||
unit_path.write_text("[Unit]\n")
|
||||
|
|
@ -179,6 +249,21 @@ def test_install_linux_gateway_from_setup_system_choice_as_root_installs(monkeyp
|
|||
assert calls == [(True, True, "alice")]
|
||||
|
||||
|
||||
def test_find_gateway_pids_falls_back_to_pid_file_when_process_scan_fails(monkeypatch):
    """When the ``ps`` scan fails, the PID from ``get_running_pid`` is returned."""

    def no_service_pids():
        return set()

    monkeypatch.setattr(gateway, "_get_service_pids", no_service_pids)
    monkeypatch.setattr(gateway, "is_windows", lambda: False)
    monkeypatch.setattr("gateway.status.get_running_pid", lambda: 321)

    def fake_run(cmd, **kwargs):
        # Any command other than the expected ps scan is a test failure.
        if cmd[:4] != ["ps", "-A", "eww", "-o"]:
            raise AssertionError(f"Unexpected command: {cmd}")
        # Simulate the process scan failing.
        return SimpleNamespace(returncode=1, stdout="", stderr="ps failed")

    monkeypatch.setattr(gateway.subprocess, "run", fake_run)

    found = gateway.find_gateway_pids()
    assert found == [321]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _wait_for_gateway_exit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -450,7 +450,6 @@ class TestGatewayServiceDetection:
|
|||
|
||||
assert gateway_cli._is_service_running() is False
|
||||
|
||||
|
||||
class TestGatewaySystemServiceRouting:
|
||||
def test_systemd_restart_self_requests_graceful_restart_and_waits(self, monkeypatch, capsys):
|
||||
calls = []
|
||||
|
|
@ -554,6 +553,38 @@ class TestGatewaySystemServiceRouting:
|
|||
|
||||
assert calls == [(False, False)]
|
||||
|
||||
def test_gateway_status_reports_manual_process_when_service_is_stopped(self, monkeypatch, capsys):
    """``gateway status`` prints the service status and the running gateway PID.

    The stubbed snapshot describes an installed user service that is not
    running while PID 4321 is; the output must contain both the stubbed
    service status line and the process notice with that PID.
    """
    user_unit = SimpleNamespace(exists=lambda: True)
    system_unit = SimpleNamespace(exists=lambda: False)

    def fake_unit_path(system=False):
        return system_unit if system else user_unit

    def fake_systemd_status(deep=False, system=False):
        print("service stopped")

    def fake_snapshot(system=False):
        return gateway_cli.GatewayRuntimeSnapshot(
            manager="systemd (user)",
            service_installed=True,
            service_running=False,
            gateway_pids=(4321,),
            service_scope="user",
        )

    monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)
    monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
    monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
    monkeypatch.setattr(gateway_cli, "get_systemd_unit_path", fake_unit_path)
    monkeypatch.setattr(gateway_cli, "systemd_status", fake_systemd_status)
    monkeypatch.setattr(gateway_cli, "get_gateway_runtime_snapshot", fake_snapshot)

    status_args = SimpleNamespace(gateway_command="status", deep=False, system=False)
    gateway_cli.gateway_command(status_args)

    printed = capsys.readouterr().out
    expected_fragments = (
        "service stopped",
        "Gateway process is running for this profile",
        "PID(s): 4321",
    )
    for fragment in expected_fragments:
        assert fragment in printed
|
||||
|
||||
def test_gateway_status_on_termux_shows_manual_guidance(self, monkeypatch, capsys):
|
||||
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_termux", lambda: True)
|
||||
|
|
@ -1146,3 +1177,556 @@ class TestDockerAwareGateway:
|
|||
out = capsys.readouterr().out
|
||||
assert "docker" in out.lower()
|
||||
assert "hermes gateway run" in out
|
||||
|
||||
|
||||
class TestLegacyHermesUnitDetection:
|
||||
"""Tests for _find_legacy_hermes_units / has_legacy_hermes_units.
|
||||
|
||||
These guard against the scenario that tripped Luis in April 2026: an
|
||||
older install left a ``hermes.service`` unit behind when the service was
|
||||
renamed to ``hermes-gateway.service``. After PR #5646 (signal recovery
|
||||
via systemd), the two services began SIGTERM-flapping over the same
|
||||
Telegram bot token in a 30-second cycle.
|
||||
|
||||
The detector must flag ``hermes.service`` ONLY when it actually runs our
|
||||
gateway, and must NEVER flag profile units
|
||||
(``hermes-gateway-<profile>.service``) or unrelated third-party services.
|
||||
"""
|
||||
|
||||
# Minimal ExecStart that looks like our gateway
|
||||
_OUR_UNIT_TEXT = (
|
||||
"[Unit]\nDescription=Hermes Gateway\n[Service]\n"
|
||||
"ExecStart=/usr/bin/python -m hermes_cli.main gateway run --replace\n"
|
||||
)
|
||||
|
||||
@staticmethod
def _setup_search_paths(tmp_path, monkeypatch):
    """Redirect the legacy search to user_dir + system_dir under tmp_path.

    Creates ``tmp_path/user`` and ``tmp_path/system``, patches
    ``gateway_cli._legacy_unit_search_paths`` to yield them as
    ``(is_system, directory)`` pairs, and returns ``(user_dir, system_dir)``.
    """
    user_dir, system_dir = (tmp_path / leaf for leaf in ("user", "system"))
    user_dir.mkdir()
    system_dir.mkdir()

    def fake_search_paths():
        return [(False, user_dir), (True, system_dir)]

    monkeypatch.setattr(gateway_cli, "_legacy_unit_search_paths", fake_search_paths)
    return user_dir, system_dir
|
||||
|
||||
def test_detects_legacy_hermes_service_in_user_scope(self, tmp_path, monkeypatch):
    """A user-scope ``hermes.service`` running our gateway is reported as legacy."""
    user_dir, _ = self._setup_search_paths(tmp_path, monkeypatch)
    legacy = user_dir / "hermes.service"
    legacy.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

    results = gateway_cli._find_legacy_hermes_units()

    assert len(results) == 1
    found_name, found_path, found_is_system = results[0]
    assert found_name == "hermes.service"
    assert found_path == legacy
    assert found_is_system is False
    assert gateway_cli.has_legacy_hermes_units() is True
|
||||
|
||||
def test_detects_legacy_hermes_service_in_system_scope(self, tmp_path, monkeypatch):
    """A system-scope ``hermes.service`` running our gateway is reported with ``is_system`` True."""
    _, system_dir = self._setup_search_paths(tmp_path, monkeypatch)
    legacy = system_dir / "hermes.service"
    legacy.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

    results = gateway_cli._find_legacy_hermes_units()

    assert len(results) == 1
    found_name, found_path, found_is_system = results[0]
    assert found_name == "hermes.service"
    assert found_path == legacy
    assert found_is_system is True
|
||||
|
||||
def test_ignores_profile_unit_hermes_gateway_coder(self, tmp_path, monkeypatch):
    """CRITICAL: profile units must NOT be flagged as legacy.

    Teknium's concern — ``hermes-gateway-coder.service`` is our standard
    naming for the ``coder`` profile. The legacy detector is an explicit
    allowlist, not a glob, so profile units are safe.
    """
    user_dir, system_dir = self._setup_search_paths(tmp_path, monkeypatch)

    # Drop profile units in BOTH scopes with our ExecStart
    current_unit_names = (
        "hermes-gateway-coder.service",
        "hermes-gateway-orcha.service",
        "hermes-gateway.service",
    )
    for base in (user_dir, system_dir):
        for unit_name in current_unit_names:
            (base / unit_name).write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

    results = gateway_cli._find_legacy_hermes_units()

    assert results == []
    assert gateway_cli.has_legacy_hermes_units() is False
|
||||
|
||||
def test_ignores_unrelated_hermes_service(self, tmp_path, monkeypatch):
    """Third-party ``hermes.service`` that isn't ours stays untouched.

    If a user has some other package named ``hermes`` installed as a
    service, we must not flag it.
    """
    user_dir, _ = self._setup_search_paths(tmp_path, monkeypatch)
    # Same file name as the legacy unit, but the ExecStart is someone else's.
    foreign_unit_text = (
        "[Unit]\nDescription=Some Other Hermes\n[Service]\n"
        "ExecStart=/opt/other-hermes/bin/daemon --foreground\n"
    )
    foreign_unit = user_dir / "hermes.service"
    foreign_unit.write_text(foreign_unit_text, encoding="utf-8")

    results = gateway_cli._find_legacy_hermes_units()

    assert results == []
    assert gateway_cli.has_legacy_hermes_units() is False
|
||||
|
||||
def test_returns_empty_when_no_legacy_files_exist(self, tmp_path, monkeypatch):
    """Empty search directories produce no findings."""
    self._setup_search_paths(tmp_path, monkeypatch)

    found = gateway_cli._find_legacy_hermes_units()
    assert found == []

    any_legacy = gateway_cli.has_legacy_hermes_units()
    assert any_legacy is False
|
||||
|
||||
def test_detects_both_scopes_simultaneously(self, tmp_path, monkeypatch):
    """Legacy units in the user AND system scope are both reported.

    Reporting both lets the migration step remove them together.
    """
    user_scope, system_scope = self._setup_search_paths(tmp_path, monkeypatch)
    for scope_dir in (user_scope, system_scope):
        (scope_dir / "hermes.service").write_text(
            self._OUR_UNIT_TEXT, encoding="utf-8"
        )

    found = gateway_cli._find_legacy_hermes_units()

    # Each result is a (name, path, is_system) triple; only the flag matters.
    scope_flags = [is_system for _name, _path, is_system in found]
    scope_flags.sort()
    assert scope_flags == [False, True]
|
||||
|
||||
def test_accepts_alternate_execstart_formats(self, tmp_path, monkeypatch):
    """Older installs may have used different python invocations.

    ExecStart variants we've seen in the wild:
    - python -m hermes_cli.main gateway run
    - python path/to/hermes_cli/main.py gateway run
    - hermes gateway run (direct binary)
    - python path/to/gateway/run.py

    Every variant is written to the same ``hermes.service`` path in the
    user scope, overwriting the previous one, and must be detected.
    """
    user_dir, _ = self._setup_search_paths(tmp_path, monkeypatch)
    legacy_unit = user_dir / "hermes.service"
    variants = [
        "ExecStart=/venv/bin/python -m hermes_cli.main gateway run --replace",
        "ExecStart=/venv/bin/python /opt/hermes/hermes_cli/main.py gateway run",
        "ExecStart=/usr/local/bin/hermes gateway run --replace",
        "ExecStart=/venv/bin/python /opt/hermes/gateway/run.py",
    ]
    for i, execstart in enumerate(variants):
        # write_text truncates, so each variant is tested against a fresh file.
        legacy_unit.write_text(
            f"[Unit]\nDescription=Old Hermes\n[Service]\n{execstart}\n",
            encoding="utf-8",
        )
        results = gateway_cli._find_legacy_hermes_units()
        assert len(results) == 1, f"Variant {i} not detected: {execstart!r}"
|
||||
|
||||
def test_print_legacy_unit_warning_is_noop_when_empty(self, tmp_path, monkeypatch, capsys):
    """With no legacy units present the warning prints nothing to stdout."""
    self._setup_search_paths(tmp_path, monkeypatch)

    gateway_cli.print_legacy_unit_warning()

    captured = capsys.readouterr()
    assert captured.out == ""
|
||||
|
||||
def test_print_legacy_unit_warning_shows_migration_hint(self, tmp_path, monkeypatch, capsys):
    """With a legacy user unit present, the warning names it and the fix command."""
    user_scope, _system_scope = self._setup_search_paths(tmp_path, monkeypatch)
    (user_scope / "hermes.service").write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

    gateway_cli.print_legacy_unit_warning()
    printed = capsys.readouterr().out

    for expected in ("Legacy", "hermes.service", "hermes gateway migrate-legacy"):
        assert expected in printed
|
||||
|
||||
def test_handles_unreadable_unit_file_gracefully(self, tmp_path, monkeypatch):
    """A PermissionError while reading a unit file must not crash detection."""
    user_scope, _system_scope = self._setup_search_paths(tmp_path, monkeypatch)
    locked_unit = user_scope / "hermes.service"
    locked_unit.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

    # Keep the genuine implementation so every other path still reads normally.
    genuine_read_text = gateway_cli.Path.read_text

    def _guarded_read_text(path, *args, **kwargs):
        if path == locked_unit:
            raise PermissionError("simulated")
        return genuine_read_text(path, *args, **kwargs)

    monkeypatch.setattr(gateway_cli.Path, "read_text", _guarded_read_text)

    # Reaching the comparison at all proves the error did not propagate.
    found = gateway_cli._find_legacy_hermes_units()
    assert found == []
|
||||
|
||||
|
||||
class TestRemoveLegacyHermesUnits:
    """Exercise ``remove_legacy_hermes_units``, the migration action."""

    _OUR_UNIT_TEXT = (
        "[Unit]\nDescription=Hermes Gateway\n[Service]\n"
        "ExecStart=/usr/bin/python -m hermes_cli.main gateway run --replace\n"
    )

    @staticmethod
    def _setup(tmp_path, monkeypatch, as_root=False):
        """Build isolated user/system unit dirs and stub out the host.

        Patches the legacy search paths to the two temp directories,
        replaces ``subprocess.run`` with a recorder that always succeeds,
        and fakes ``os.geteuid`` (0 when ``as_root`` else 1000).

        Returns ``(user_dir, system_dir, recorded_commands)``.
        """
        user_dir = tmp_path / "user"
        system_dir = tmp_path / "system"
        for scope_dir in (user_dir, system_dir):
            scope_dir.mkdir()

        def _search_paths():
            return [(False, user_dir), (True, system_dir)]

        monkeypatch.setattr(gateway_cli, "_legacy_unit_search_paths", _search_paths)

        # Every systemctl invocation is captured and reported as successful.
        recorded: list[list[str]] = []

        def _record_run(cmd, **kwargs):
            recorded.append(cmd)
            return SimpleNamespace(returncode=0, stdout="", stderr="")

        effective_uid = 0 if as_root else 1000
        monkeypatch.setattr(gateway_cli.subprocess, "run", _record_run)
        monkeypatch.setattr(gateway_cli.os, "geteuid", lambda: effective_uid)
        return user_dir, system_dir, recorded

    def test_returns_zero_when_no_legacy_units(self, tmp_path, monkeypatch, capsys):
        """Nothing to migrate: zero removed, nothing remaining, a notice printed."""
        self._setup(tmp_path, monkeypatch)

        outcome = gateway_cli.remove_legacy_hermes_units(interactive=False)
        removed, remaining = outcome

        assert removed == 0
        assert remaining == []
        assert "No legacy" in capsys.readouterr().out

    def test_dry_run_lists_without_removing(self, tmp_path, monkeypatch, capsys):
        """``dry_run=True`` reports the unit but neither deletes it nor runs systemctl."""
        user_scope, _system_scope, recorded = self._setup(tmp_path, monkeypatch)
        stale_unit = user_scope / "hermes.service"
        stale_unit.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

        removed, remaining = gateway_cli.remove_legacy_hermes_units(
            interactive=False, dry_run=True
        )

        assert removed == 0
        assert remaining == [stale_unit]
        assert stale_unit.exists()  # file left in place
        assert recorded == []  # systemctl never invoked
        printed = capsys.readouterr().out
        assert "dry-run" in printed

    def test_removes_user_scope_legacy_unit(self, tmp_path, monkeypatch, capsys):
        """A user-scope unit is stopped, disabled, reloaded and deleted."""
        user_scope, _system_scope, recorded = self._setup(tmp_path, monkeypatch)
        stale_unit = user_scope / "hermes.service"
        stale_unit.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

        removed, remaining = gateway_cli.remove_legacy_hermes_units(interactive=False)

        assert removed == 1
        assert remaining == []
        assert not stale_unit.exists()
        # stop, disable and daemon-reload must all have targeted the user scope.
        command_lines = [" ".join(argv) for argv in recorded]
        assert any("--user stop hermes.service" in line for line in command_lines)
        assert any("--user disable hermes.service" in line for line in command_lines)
        assert any("--user daemon-reload" in line for line in command_lines)

    def test_system_scope_without_root_defers_removal(self, tmp_path, monkeypatch, capsys):
        """Without root a system-scope unit is left alone and a sudo hint is printed."""
        _user_scope, system_scope, _recorded = self._setup(
            tmp_path, monkeypatch, as_root=False
        )
        stale_unit = system_scope / "hermes.service"
        stale_unit.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

        removed, remaining = gateway_cli.remove_legacy_hermes_units(interactive=False)

        assert removed == 0
        assert remaining == [stale_unit]
        assert stale_unit.exists()  # untouched, removal needs sudo
        printed = capsys.readouterr().out
        assert "sudo hermes gateway migrate-legacy" in printed

    def test_system_scope_with_root_removes(self, tmp_path, monkeypatch, capsys):
        """As root a system-scope unit is removed via plain ``systemctl``."""
        _user_scope, system_scope, recorded = self._setup(
            tmp_path, monkeypatch, as_root=True
        )
        stale_unit = system_scope / "hermes.service"
        stale_unit.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

        removed, remaining = gateway_cli.remove_legacy_hermes_units(interactive=False)

        assert removed == 1
        assert remaining == []
        assert not stale_unit.exists()
        command_lines = [" ".join(argv) for argv in recorded]
        # System scope means no ``--user`` flag between systemctl and the verb.
        for prefix in (
            "systemctl stop hermes.service",
            "systemctl disable hermes.service",
        ):
            assert any(line.startswith(prefix) for line in command_lines)

    def test_removes_both_scopes_with_root(self, tmp_path, monkeypatch, capsys):
        """As root, user-scope and system-scope legacy units go in one call."""
        user_scope, system_scope, _recorded = self._setup(
            tmp_path, monkeypatch, as_root=True
        )
        user_unit = user_scope / "hermes.service"
        system_unit = system_scope / "hermes.service"
        for unit in (user_unit, system_unit):
            unit.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

        removed, remaining = gateway_cli.remove_legacy_hermes_units(interactive=False)

        assert removed == 2
        assert remaining == []
        assert not user_unit.exists()
        assert not system_unit.exists()

    def test_does_not_touch_profile_units_during_migration(
        self, tmp_path, monkeypatch, capsys
    ):
        """Profile units and the current default unit survive a migration call.

        Both ``hermes-gateway-coder.service`` and ``hermes-gateway.service``
        sit in the searched directory; neither may be removed.
        """
        user_scope, _system_scope, _recorded = self._setup(
            tmp_path, monkeypatch, as_root=True
        )
        coder_unit = user_scope / "hermes-gateway-coder.service"
        coder_unit.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")
        current_unit = user_scope / "hermes-gateway.service"
        current_unit.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

        removed, remaining = gateway_cli.remove_legacy_hermes_units(interactive=False)

        assert removed == 0
        assert remaining == []
        # Neither current-style unit may have been deleted.
        assert coder_unit.exists()
        assert current_unit.exists()

    def test_interactive_prompt_no_skips_removal(self, tmp_path, monkeypatch, capsys):
        """In interactive mode a "no" answer leaves the legacy unit in place."""
        user_scope, _system_scope, _recorded = self._setup(tmp_path, monkeypatch)
        stale_unit = user_scope / "hermes.service"
        stale_unit.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

        def _always_decline(*args, **kwargs):
            return False

        monkeypatch.setattr(gateway_cli, "prompt_yes_no", _always_decline)

        removed, remaining = gateway_cli.remove_legacy_hermes_units(interactive=True)

        assert removed == 0
        assert remaining == [stale_unit]
        assert stale_unit.exists()
|
||||
|
||||
|
||||
class TestMigrateLegacyCommand:
    """Tests for the `hermes gateway migrate-legacy` subcommand dispatch."""

    def test_migrate_legacy_subparser_accepts_dry_run_and_yes(self):
        """Verify the ``migrate-legacy`` subparser is registered.

        The check is done end to end: ``hermes gateway --help`` is run in a
        subprocess from the project root and its output must mention the
        subcommand. No argparse internals are inspected.
        """
        import subprocess
        import sys

        import hermes_cli.main as cli_main

        # PROJECT_ROOT may live on either module; prefer main, fall back to gateway.
        project_root = getattr(cli_main, "PROJECT_ROOT", None)
        if project_root is None:
            import hermes_cli.gateway as gw

            project_root = gw.PROJECT_ROOT

        result = subprocess.run(
            [sys.executable, "-m", "hermes_cli.main", "gateway", "--help"],
            cwd=str(project_root),
            capture_output=True,
            text=True,
            timeout=15,
        )
        assert result.returncode == 0, result.stderr
        assert "migrate-legacy" in result.stdout

    def test_gateway_command_migrate_legacy_dispatches(
        self, tmp_path, monkeypatch, capsys
    ):
        """gateway_command(args) with subcmd='migrate-legacy' calls the helper.

        ``yes=True`` must translate into ``interactive=False``.
        """
        called = {}

        def fake_remove(interactive=True, dry_run=False):
            called["interactive"] = interactive
            called["dry_run"] = dry_run
            return 0, []

        monkeypatch.setattr(gateway_cli, "remove_legacy_hermes_units", fake_remove)
        monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)
        monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)

        args = SimpleNamespace(
            gateway_command="migrate-legacy", dry_run=False, yes=True
        )
        gateway_cli.gateway_command(args)

        assert called == {"interactive": False, "dry_run": False}

    def test_gateway_command_migrate_legacy_dry_run_passes_through(
        self, monkeypatch
    ):
        """``dry_run=True`` is forwarded; without ``yes`` the call stays interactive."""
        called = {}

        def fake_remove(interactive=True, dry_run=False):
            called["interactive"] = interactive
            called["dry_run"] = dry_run
            return 0, []

        monkeypatch.setattr(gateway_cli, "remove_legacy_hermes_units", fake_remove)
        monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)
        monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)

        args = SimpleNamespace(
            gateway_command="migrate-legacy", dry_run=True, yes=False
        )
        gateway_cli.gateway_command(args)

        assert called == {"interactive": True, "dry_run": True}

    def test_migrate_legacy_on_unsupported_platform_prints_message(
        self, monkeypatch, capsys
    ):
        """Without systemd support (and not on macOS) an explanation is printed."""
        monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: False)
        monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)

        args = SimpleNamespace(
            gateway_command="migrate-legacy", dry_run=False, yes=True
        )
        gateway_cli.gateway_command(args)

        out = capsys.readouterr().out
        assert "only applies to systemd" in out
|
||||
|
||||
|
||||
class TestSystemdInstallOffersLegacyRemoval:
    """Check that ``systemd_install`` offers to remove legacy units first."""

    @staticmethod
    def _stub_install_flow(tmp_path, monkeypatch):
        """Neutralise everything ``systemd_install`` does besides the legacy check.

        The unit path is redirected into ``tmp_path``, unit generation
        returns fixed text, ``subprocess.run`` always succeeds and linger
        setup is a no-op. Returns the redirected unit path.
        """
        unit_path = tmp_path / "hermes-gateway.service"
        monkeypatch.setattr(
            gateway_cli, "get_systemd_unit_path", lambda system=False: unit_path
        )
        monkeypatch.setattr(
            gateway_cli,
            "generate_systemd_unit",
            lambda system=False, run_as_user=None: "unit text\n",
        )
        monkeypatch.setattr(
            gateway_cli.subprocess,
            "run",
            lambda cmd, **kw: SimpleNamespace(returncode=0, stdout="", stderr=""),
        )
        monkeypatch.setattr(gateway_cli, "_ensure_linger_enabled", lambda: None)
        return unit_path

    def test_install_offers_removal_when_legacy_detected(
        self, tmp_path, monkeypatch, capsys
    ):
        """With legacy units present and a "yes" answer the removal helper runs."""
        seen = {}

        def _spy_remove(interactive=True, dry_run=False):
            seen["invoked"] = True
            seen["interactive"] = interactive
            return 1, []

        # Legacy detection is forced on and the prompt is answered with "yes".
        monkeypatch.setattr(gateway_cli, "has_legacy_hermes_units", lambda: True)
        monkeypatch.setattr(gateway_cli, "remove_legacy_hermes_units", _spy_remove)
        monkeypatch.setattr(gateway_cli, "print_legacy_unit_warning", lambda: None)
        monkeypatch.setattr(gateway_cli, "prompt_yes_no", lambda *a, **k: True)
        self._stub_install_flow(tmp_path, monkeypatch)

        gateway_cli.systemd_install()

        assert seen.get("invoked") is True
        # The install flow already prompted, so the helper must not prompt again.
        assert seen.get("interactive") is False

    def test_install_declines_legacy_removal_when_user_says_no(
        self, tmp_path, monkeypatch
    ):
        """A "no" answer skips the removal helper but the new unit is still written."""
        seen = {"invoked": False}

        def _spy_remove(interactive=True, dry_run=False):
            seen["invoked"] = True
            return 0, []

        monkeypatch.setattr(gateway_cli, "has_legacy_hermes_units", lambda: True)
        monkeypatch.setattr(gateway_cli, "remove_legacy_hermes_units", _spy_remove)
        monkeypatch.setattr(gateway_cli, "print_legacy_unit_warning", lambda: None)
        monkeypatch.setattr(gateway_cli, "prompt_yes_no", lambda *a, **k: False)
        unit_path = self._stub_install_flow(tmp_path, monkeypatch)

        gateway_cli.systemd_install()

        # The removal helper stayed untouched...
        assert seen["invoked"] is False
        # ...while installation itself went through.
        assert unit_path.exists()
        assert unit_path.read_text() == "unit text\n"

    def test_install_skips_legacy_check_when_none_present(
        self, tmp_path, monkeypatch
    ):
        """No legacy units: neither the prompt nor the removal helper is used."""
        prompt_stats = {"count": 0}
        seen = {"invoked": False}

        def _counting_prompt(*a, **k):
            prompt_stats["count"] += 1
            return True

        def _spy_remove(interactive=True, dry_run=False):
            seen["invoked"] = True
            return 0, []

        monkeypatch.setattr(gateway_cli, "has_legacy_hermes_units", lambda: False)
        monkeypatch.setattr(gateway_cli, "remove_legacy_hermes_units", _spy_remove)
        monkeypatch.setattr(gateway_cli, "prompt_yes_no", _counting_prompt)
        self._stub_install_flow(tmp_path, monkeypatch)

        gateway_cli.systemd_install()

        assert prompt_stats["count"] == 0
        assert seen["invoked"] is False
|
||||
|
|
|
|||
|
|
@ -178,10 +178,6 @@ class TestGeminiContextLength:
|
|||
ctx = get_model_context_length("gemma-4-31b-it", provider="gemini")
|
||||
assert ctx == 256000
|
||||
|
||||
def test_gemma_4_26b_context(self):
|
||||
ctx = get_model_context_length("gemma-4-26b-it", provider="gemini")
|
||||
assert ctx == 256000
|
||||
|
||||
def test_gemini_3_context(self):
|
||||
ctx = get_model_context_length("gemini-3.1-pro-preview", provider="gemini")
|
||||
assert ctx == 1048576
|
||||
|
|
|
|||
|
|
@ -403,7 +403,8 @@ class TestValidateFormatChecks:
|
|||
|
||||
def test_no_slash_model_rejected_if_not_in_api(self):
|
||||
result = _validate("gpt-5.4", api_models=["openai/gpt-5.4"])
|
||||
assert result["accepted"] is True
|
||||
assert result["accepted"] is False
|
||||
assert result["persist"] is False
|
||||
assert "not found" in result["message"]
|
||||
|
||||
|
||||
|
|
@ -429,10 +430,10 @@ class TestValidateApiFound:
|
|||
# -- validate — API not found ------------------------------------------------
|
||||
|
||||
class TestValidateApiNotFound:
|
||||
def test_model_not_in_api_accepted_with_warning(self):
|
||||
def test_model_not_in_api_rejected_with_guidance(self):
|
||||
result = _validate("anthropic/claude-nonexistent")
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
assert result["accepted"] is False
|
||||
assert result["persist"] is False
|
||||
assert "not found" in result["message"]
|
||||
|
||||
def test_warning_includes_suggestions(self):
|
||||
|
|
@ -456,30 +457,29 @@ class TestValidateApiNotFound:
|
|||
assert "not found" in result["message"]
|
||||
|
||||
|
||||
# -- validate — API unreachable — accept and persist everything ----------------
|
||||
# -- validate — API unreachable — reject with guidance ----------------
|
||||
|
||||
class TestValidateApiFallback:
|
||||
def test_any_model_accepted_when_api_down(self):
|
||||
def test_any_model_rejected_when_api_down(self):
|
||||
result = _validate("anthropic/claude-opus-4.6", api_models=None)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
assert result["accepted"] is False
|
||||
assert result["persist"] is False
|
||||
|
||||
def test_unknown_model_also_accepted_when_api_down(self):
|
||||
"""No hardcoded catalog gatekeeping — accept, persist, and warn."""
|
||||
def test_unknown_model_also_rejected_when_api_down(self):
|
||||
result = _validate("anthropic/claude-next-gen", api_models=None)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
assert result["accepted"] is False
|
||||
assert result["persist"] is False
|
||||
assert "could not reach" in result["message"].lower()
|
||||
|
||||
def test_zai_model_accepted_when_api_down(self):
|
||||
def test_zai_model_rejected_when_api_down(self):
|
||||
result = _validate("glm-5", provider="zai", api_models=None)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
assert result["accepted"] is False
|
||||
assert result["persist"] is False
|
||||
|
||||
def test_unknown_provider_accepted_when_api_down(self):
|
||||
def test_unknown_provider_rejected_when_api_down(self):
|
||||
result = _validate("some-model", provider="totally-unknown", api_models=None)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
assert result["accepted"] is False
|
||||
assert result["persist"] is False
|
||||
|
||||
def test_custom_endpoint_warns_with_probed_url_and_v1_hint(self):
|
||||
with patch(
|
||||
|
|
@ -499,8 +499,8 @@ class TestValidateApiFallback:
|
|||
base_url="http://localhost:8000",
|
||||
)
|
||||
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
assert result["accepted"] is False
|
||||
assert result["persist"] is False
|
||||
assert "http://localhost:8000/v1/models" in result["message"]
|
||||
assert "http://localhost:8000/v1" in result["message"]
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ def test_opencode_go_appears_when_api_key_set():
|
|||
opencode_go = next((p for p in providers if p["slug"] == "opencode-go"), None)
|
||||
|
||||
assert opencode_go is not None, "opencode-go should appear when OPENCODE_GO_API_KEY is set"
|
||||
assert opencode_go["models"] == ["glm-5.1", "glm-5", "kimi-k2.5", "mimo-v2-pro", "mimo-v2-omni", "minimax-m2.7", "minimax-m2.5"]
|
||||
assert opencode_go["models"] == ["kimi-k2.5", "glm-5.1", "glm-5", "mimo-v2-pro", "mimo-v2-omni", "minimax-m2.7", "minimax-m2.5"]
|
||||
# opencode-go can appear as "built-in" (from PROVIDER_TO_MODELS_DEV when
|
||||
# models.dev is reachable) or "hermes" (from HERMES_OVERLAYS fallback when
|
||||
# the API is unavailable, e.g. in CI).
|
||||
|
|
|
|||
|
|
@ -799,35 +799,30 @@ class TestEdgeCases:
|
|||
assert default.skill_count == 0
|
||||
|
||||
def test_gateway_running_check_with_pid_file(self, profile_env):
|
||||
"""Verify _check_gateway_running reads pid file and probes os.kill."""
|
||||
"""Verify _check_gateway_running uses the shared gateway PID validator."""
|
||||
from hermes_cli.profiles import _check_gateway_running
|
||||
tmp_path = profile_env
|
||||
default_home = tmp_path / ".hermes"
|
||||
|
||||
# No pid file -> not running
|
||||
assert _check_gateway_running(default_home) is False
|
||||
|
||||
# Write a PID file with a JSON payload
|
||||
pid_file = default_home / "gateway.pid"
|
||||
pid_file.write_text(json.dumps({"pid": 99999}))
|
||||
|
||||
# os.kill(99999, 0) should raise ProcessLookupError -> not running
|
||||
assert _check_gateway_running(default_home) is False
|
||||
|
||||
# Mock os.kill to simulate a running process
|
||||
with patch("os.kill", return_value=None):
|
||||
with patch("gateway.status.get_running_pid", return_value=99999) as mock_get_running_pid:
|
||||
assert _check_gateway_running(default_home) is True
|
||||
mock_get_running_pid.assert_called_once_with(
|
||||
default_home / "gateway.pid",
|
||||
cleanup_stale=False,
|
||||
)
|
||||
|
||||
def test_gateway_running_check_plain_pid(self, profile_env):
|
||||
"""Pid file containing just a number (legacy format)."""
|
||||
"""Shared PID validator returning None means the profile is not running."""
|
||||
from hermes_cli.profiles import _check_gateway_running
|
||||
tmp_path = profile_env
|
||||
default_home = tmp_path / ".hermes"
|
||||
pid_file = default_home / "gateway.pid"
|
||||
pid_file.write_text("99999")
|
||||
|
||||
with patch("os.kill", return_value=None):
|
||||
assert _check_gateway_running(default_home) is True
|
||||
with patch("gateway.status.get_running_pid", return_value=None) as mock_get_running_pid:
|
||||
assert _check_gateway_running(default_home) is False
|
||||
mock_get_running_pid.assert_called_once_with(
|
||||
default_home / "gateway.pid",
|
||||
cleanup_stale=False,
|
||||
)
|
||||
|
||||
def test_profile_name_boundary_single_char(self):
|
||||
"""Single alphanumeric character is valid."""
|
||||
|
|
|
|||
53
tests/hermes_cli/test_tui_npm_install.py
Normal file
53
tests/hermes_cli/test_tui_npm_install.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""_tui_need_npm_install: auto npm when lockfile ahead of node_modules."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def main_mod():
    """Provide the ``hermes_cli.main`` module, imported lazily per test."""
    import hermes_cli.main as cli_main_module

    return cli_main_module
|
||||
|
||||
|
||||
def _touch_ink(root: Path) -> None:
|
||||
ink = root / "node_modules" / "@hermes" / "ink" / "package.json"
|
||||
ink.parent.mkdir(parents=True, exist_ok=True)
|
||||
ink.write_text("{}")
|
||||
|
||||
|
||||
def test_need_install_when_ink_missing(tmp_path: Path, main_mod) -> None:
    """A lockfile without the ink package in ``node_modules`` needs an install."""
    lockfile = tmp_path / "package-lock.json"
    lockfile.write_text("{}")

    needs_install = main_mod._tui_need_npm_install(tmp_path)

    assert needs_install is True
|
||||
|
||||
|
||||
def test_need_install_when_lock_newer_than_marker(tmp_path: Path, main_mod) -> None:
    """A lockfile modified after ``node_modules/.package-lock.json`` needs an install."""
    _touch_ink(tmp_path)
    lockfile = tmp_path / "package-lock.json"
    marker = tmp_path / "node_modules" / ".package-lock.json"
    lockfile.write_text("{}")
    marker.write_text("{}")
    # Force the timestamps: lockfile at t=200, marker at t=100.
    os.utime(lockfile, (200, 200))
    os.utime(marker, (100, 100))

    needs_install = main_mod._tui_need_npm_install(tmp_path)

    assert needs_install is True
|
||||
|
||||
|
||||
def test_no_install_when_lock_older_than_marker(tmp_path: Path, main_mod) -> None:
    """A lockfile older than ``node_modules/.package-lock.json`` needs no install."""
    _touch_ink(tmp_path)
    lockfile = tmp_path / "package-lock.json"
    marker = tmp_path / "node_modules" / ".package-lock.json"
    lockfile.write_text("{}")
    marker.write_text("{}")
    # Force the timestamps: lockfile at t=100, marker at t=200.
    os.utime(lockfile, (100, 100))
    os.utime(marker, (200, 200))

    needs_install = main_mod._tui_need_npm_install(tmp_path)

    assert needs_install is False
|
||||
|
||||
|
||||
def test_need_install_when_marker_missing(tmp_path: Path, main_mod) -> None:
    """Ink and a lockfile exist but there is no marker file: install is needed."""
    _touch_ink(tmp_path)
    lockfile = tmp_path / "package-lock.json"
    lockfile.write_text("{}")

    needs_install = main_mod._tui_need_npm_install(tmp_path)

    assert needs_install is True
|
||||
|
||||
|
||||
def test_no_install_without_lockfile_when_ink_present(tmp_path: Path, main_mod) -> None:
    """Ink is installed and there is no lockfile at all: no install is needed."""
    _touch_ink(tmp_path)

    needs_install = main_mod._tui_need_npm_install(tmp_path)

    assert needs_install is False
|
||||
121
tests/hermes_cli/test_tui_resume_flow.py
Normal file
121
tests/hermes_cli/test_tui_resume_flow.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
from argparse import Namespace
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _args(**overrides):
|
||||
base = {
|
||||
"continue_last": None,
|
||||
"resume": None,
|
||||
"tui": True,
|
||||
}
|
||||
base.update(overrides)
|
||||
return Namespace(**base)
|
||||
|
||||
|
||||
@pytest.fixture
def main_mod(monkeypatch):
    """Provide ``hermes_cli.main`` with provider configuration forced to "present"."""
    import hermes_cli.main as cli_main_module

    def _provider_always_configured():
        return True

    monkeypatch.setattr(
        cli_main_module, "_has_any_provider_configured", _provider_always_configured
    )
    return cli_main_module
|
||||
|
||||
|
||||
def test_cmd_chat_tui_continue_uses_latest_tui_session(monkeypatch, main_mod):
    """``--continue`` in TUI mode resumes the latest TUI session when one exists."""
    lookup_sources = []
    launch_info = {}

    def _resolve_last(source="cli"):
        lookup_sources.append(source)
        if source == "tui":
            return "20260408_235959_a1b2c3"
        return None

    def _launch(resume_session_id=None, tui_dev=False):
        launch_info["resume"] = resume_session_id
        # Abort right after recording so cmd_chat goes no further.
        raise SystemExit(0)

    monkeypatch.setattr(main_mod, "_resolve_last_session", _resolve_last)
    monkeypatch.setattr(main_mod, "_resolve_session_by_name_or_id", lambda val: val)
    monkeypatch.setattr(main_mod, "_launch_tui", _launch)

    chat_args = _args(continue_last=True)
    with pytest.raises(SystemExit):
        main_mod.cmd_chat(chat_args)

    assert lookup_sources == ["tui"]
    assert launch_info["resume"] == "20260408_235959_a1b2c3"
|
||||
|
||||
|
||||
def test_cmd_chat_tui_continue_falls_back_to_latest_cli_session(monkeypatch, main_mod):
    """With no TUI session, ``--continue`` falls back to the latest CLI session."""
    lookup_sources = []
    launch_info = {}
    latest_by_source = {"tui": None, "cli": "20260408_235959_d4e5f6"}

    def _resolve_last(source="cli"):
        lookup_sources.append(source)
        return latest_by_source.get(source)

    def _launch(resume_session_id=None, tui_dev=False):
        launch_info["resume"] = resume_session_id
        # Abort right after recording so cmd_chat goes no further.
        raise SystemExit(0)

    monkeypatch.setattr(main_mod, "_resolve_last_session", _resolve_last)
    monkeypatch.setattr(main_mod, "_resolve_session_by_name_or_id", lambda val: val)
    monkeypatch.setattr(main_mod, "_launch_tui", _launch)

    chat_args = _args(continue_last=True)
    with pytest.raises(SystemExit):
        main_mod.cmd_chat(chat_args)

    # The TUI lookup is tried first, then the CLI lookup.
    assert lookup_sources == ["tui", "cli"]
    assert launch_info["resume"] == "20260408_235959_d4e5f6"
|
||||
|
||||
|
||||
def test_cmd_chat_tui_resume_resolves_title_before_launch(monkeypatch, main_mod):
    """``--resume <title>`` is turned into a session id before the TUI launches."""
    launch_info = {}

    def _launch(resume_session_id=None, tui_dev=False):
        launch_info["resume"] = resume_session_id
        raise SystemExit(0)

    def _resolve_by_name_or_id(val):
        return "20260409_000000_aa11bb"

    monkeypatch.setattr(main_mod, "_resolve_session_by_name_or_id", _resolve_by_name_or_id)
    monkeypatch.setattr(main_mod, "_launch_tui", _launch)

    chat_args = _args(resume="my t0p session")
    with pytest.raises(SystemExit):
        main_mod.cmd_chat(chat_args)

    assert launch_info["resume"] == "20260409_000000_aa11bb"
|
||||
|
||||
|
||||
def test_print_tui_exit_summary_includes_resume_and_token_totals(monkeypatch, capsys):
    """The exit summary prints resume commands and a token breakdown.

    ``hermes_state`` is swapped in ``sys.modules`` for a stub whose
    ``SessionDB`` returns canned session stats and a canned title.
    """
    import hermes_cli.main as main_mod

    session_stats = {
        "message_count": 2,
        "input_tokens": 10,
        "output_tokens": 6,
        "cache_read_tokens": 2,
        "cache_write_tokens": 2,
        "reasoning_tokens": 1,
    }

    class _FakeDB:
        def get_session(self, session_id):
            assert session_id == "20260409_000001_abc123"
            return dict(session_stats)

        def get_session_title(self, _session_id):
            return "demo title"

        def close(self):
            return None

    stub_state_module = types.SimpleNamespace(SessionDB=lambda: _FakeDB())
    monkeypatch.setitem(sys.modules, "hermes_state", stub_state_module)

    main_mod._print_tui_exit_summary("20260409_000001_abc123")
    printed = capsys.readouterr().out

    expected_fragments = (
        "Resume this session with:",
        "hermes --tui --resume 20260409_000001_abc123",
        'hermes --tui -c "demo title"',
        "Tokens: 21 (in 10, out 6, cache 4, reasoning 1)",
    )
    for fragment in expected_fragments:
        assert fragment in printed
|
||||
|
|
@ -13,9 +13,29 @@ from unittest.mock import patch, MagicMock
|
|||
import pytest
|
||||
|
||||
import hermes_cli.gateway as gateway_cli
|
||||
import hermes_cli.main as cli_main
|
||||
from hermes_cli.main import cmd_update
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skip the real-time sleeps inside cmd_update's restart-verification path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _no_restart_verify_sleep(monkeypatch):
    """Turn ``time.sleep`` into a no-op for every test in this module.

    hermes_cli/main.py uses ``time.sleep(3)`` after ``systemctl restart`` to
    verify the service survived.  Tests mock ``subprocess.run`` — nothing
    actually restarts — so the 3s wait is dead time.

    main.py does ``import time as _time`` both at module level and inside
    functions; all of those aliases point at the same module object, so
    replacing the attribute on the ``time`` module covers them.  monkeypatch
    restores the real function when the test finishes.
    """

    def _skip_sleep(*_args, **_kwargs):
        return None

    monkeypatch.setattr("time.sleep", _skip_sleep)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -915,3 +935,183 @@ class TestGatewayModeWritesExitCodeEarly:
|
|||
assert exit_code_existed_at_restart, "systemctl restart was never called"
|
||||
assert exit_code_existed_at_restart[0] is True, \
|
||||
".update_exit_code must exist BEFORE systemctl restart (cgroup kill race)"
|
||||
|
||||
|
||||
class TestCmdUpdateLegacyGatewayWarning:
    """Tests for the legacy hermes.service warning printed by `hermes update`.

    Users who installed Hermes before the service rename often have a
    dormant ``hermes.service`` that starts flap-fighting the current
    ``hermes-gateway.service`` after PR #5646.  Every ``hermes update``
    should remind them to run ``hermes gateway migrate-legacy`` until
    they do.
    """

    _OUR_UNIT_TEXT = (
        "[Unit]\nDescription=Hermes Gateway\n[Service]\n"
        "ExecStart=/usr/bin/python -m hermes_cli.main gateway run --replace\n"
    )

    # ------------------------------------------------------------------
    # Shared setup helpers (not collected by pytest: no ``test_`` prefix)
    # ------------------------------------------------------------------

    @staticmethod
    def _make_scope_dirs(tmp_path, *, create_system=True):
        """Return ``(user_dir, system_dir)`` under ``tmp_path``.

        ``user_dir`` is always created; ``system_dir`` is only created when
        ``create_system`` is true, so a test can point the search path at a
        directory that does not exist.
        """
        user_dir = tmp_path / "user"
        system_dir = tmp_path / "system"
        user_dir.mkdir()
        if create_system:
            system_dir.mkdir()
        return user_dir, system_dir

    def _write_unit(self, directory, unit_name):
        """Write a Hermes-looking unit file named ``unit_name`` into ``directory``."""
        (directory / unit_name).write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

    @staticmethod
    def _run_update(
        mock_run, mock_args, capsys, monkeypatch, user_dir, system_dir,
        *, systemd=True, **side_effect_kwargs,
    ):
        """Patch the gateway platform probes, run ``cmd_update``, return stdout.

        With ``systemd=True`` the platform is presented as a non-macOS,
        non-Termux host that supports systemd services.  With
        ``systemd=False`` it is presented as macOS without systemd support
        and ``is_termux`` is left untouched.  Extra keyword arguments are
        forwarded to ``_make_run_side_effect``.
        """
        monkeypatch.setattr(
            gateway_cli,
            "_legacy_unit_search_paths",
            # Fresh list per call, as (is_system_scope, directory) pairs.
            lambda: [(False, user_dir), (True, system_dir)],
        )
        monkeypatch.setattr(gateway_cli, "is_macos", lambda: not systemd)
        monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: systemd)
        if systemd:
            monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)

        mock_run.side_effect = _make_run_side_effect(
            commit_count="3", **side_effect_kwargs
        )

        with patch.object(gateway_cli, "find_gateway_pids", return_value=[]):
            cmd_update(mock_args)

        return capsys.readouterr().out

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    @patch("shutil.which", return_value=None)
    @patch("subprocess.run")
    def test_update_prints_legacy_warning_when_detected(
        self, mock_run, _mock_which, mock_args, capsys, tmp_path, monkeypatch,
    ):
        """Legacy units present → warning in update output with migrate command."""
        user_dir, system_dir = self._make_scope_dirs(tmp_path)
        self._write_unit(user_dir, "hermes.service")

        captured = self._run_update(
            mock_run, mock_args, capsys, monkeypatch, user_dir, system_dir,
        )

        assert "Legacy Hermes gateway unit(s) detected" in captured
        assert "hermes.service" in captured
        assert "hermes gateway migrate-legacy" in captured
        assert "(user scope)" in captured

    @patch("shutil.which", return_value=None)
    @patch("subprocess.run")
    def test_update_silent_when_no_legacy_units(
        self, mock_run, _mock_which, mock_args, capsys, tmp_path, monkeypatch,
    ):
        """No legacy units → no warning printed."""
        user_dir, system_dir = self._make_scope_dirs(tmp_path)

        captured = self._run_update(
            mock_run, mock_args, capsys, monkeypatch, user_dir, system_dir,
        )

        assert "Legacy Hermes gateway" not in captured
        assert "migrate-legacy" not in captured

    @patch("shutil.which", return_value=None)
    @patch("subprocess.run")
    def test_update_does_not_flag_profile_units(
        self, mock_run, _mock_which, mock_args, capsys, tmp_path, monkeypatch,
    ):
        """Profile units (hermes-gateway-coder.service) must not trigger the warning.

        This is the core safety invariant: the legacy allowlist is
        ``hermes.service`` only, no globs.
        """
        user_dir, system_dir = self._make_scope_dirs(tmp_path)
        # Drop a profile unit that an over-eager glob would match
        self._write_unit(user_dir, "hermes-gateway-coder.service")
        self._write_unit(user_dir, "hermes-gateway.service")

        captured = self._run_update(
            mock_run, mock_args, capsys, monkeypatch, user_dir, system_dir,
        )

        assert "Legacy Hermes gateway" not in captured
        assert "hermes-gateway-coder.service" not in captured  # not flagged

    @patch("shutil.which", return_value=None)
    @patch("subprocess.run")
    def test_update_skips_legacy_check_on_non_systemd_platforms(
        self, mock_run, _mock_which, mock_args, capsys, tmp_path, monkeypatch,
    ):
        """macOS / Windows / Termux — skip check entirely since the rename
        is systemd-specific."""
        # Only the user directory exists here; the system path is left absent.
        user_dir, system_dir = self._make_scope_dirs(tmp_path, create_system=False)
        # Put a file that WOULD match if the check ran
        self._write_unit(user_dir, "hermes.service")

        captured = self._run_update(
            mock_run, mock_args, capsys, monkeypatch, user_dir, system_dir,
            systemd=False, launchctl_loaded=False,
        )

        # Must not print the warning on non-systemd platforms
        assert "Legacy Hermes gateway" not in captured

    @patch("shutil.which", return_value=None)
    @patch("subprocess.run")
    def test_update_lists_system_scope_unit_with_sudo_hint(
        self, mock_run, _mock_which, mock_args, capsys, tmp_path, monkeypatch,
    ):
        """System-scope legacy units need sudo — the warning must point that out."""
        user_dir, system_dir = self._make_scope_dirs(tmp_path)
        self._write_unit(system_dir, "hermes.service")

        captured = self._run_update(
            mock_run, mock_args, capsys, monkeypatch, user_dir, system_dir,
        )

        assert "Legacy Hermes gateway" in captured
        assert "(system scope)" in captured
        assert "sudo" in captured
|
||||
|
|
|
|||
|
|
@ -31,6 +31,31 @@ def _isolate_env(tmp_path, monkeypatch):
|
|||
monkeypatch.delenv("RETAINDB_PROJECT", raising=False)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _cap_retaindb_sleeps(monkeypatch):
    """Cap production-code sleeps so background-thread tests run fast.

    The retaindb ``_WriteQueue._flush_row`` does ``time.sleep(2)`` after
    errors.  Across multiple tests that trigger the retry path, that adds
    up.  The plugin module's ``time`` name is rebound to a stand-in whose
    ``sleep`` waits at most 0.05s — tests don't care about the exact retry
    delay, only that it happens.  The test file's own ``time.sleep`` stays
    real because only the plugin's module-level name is replaced.

    The stand-in mirrors every attribute of the real ``time`` module, so any
    other ``time.*`` call made by the plugin keeps working unchanged.

    Does nothing when the plugin cannot be imported.
    """
    try:
        from plugins.memory import retaindb as _retaindb
    except ImportError:
        # Plugin unavailable in this environment: nothing to patch.
        return

    import types as _types

    real_time = _retaindb.time
    real_sleep = real_time.sleep

    def _capped_sleep(seconds):
        return real_sleep(min(float(seconds), 0.05))

    # Copy the whole module namespace, then override only ``sleep``.
    fake_time = _types.SimpleNamespace(**vars(real_time))
    fake_time.sleep = _capped_sleep
    monkeypatch.setattr(_retaindb, "time", fake_time)
|
||||
|
||||
|
||||
# We need the repo root on sys.path so the plugin can import agent.memory_provider
|
||||
import sys
|
||||
_repo_root = str(Path(__file__).resolve().parents[2])
|
||||
|
|
@ -130,16 +155,18 @@ class TestWriteQueue:
|
|||
def test_enqueue_creates_row(self, tmp_path):
    """Enqueuing one turn leads to exactly one ``ingest_session`` call."""
    q, client, _ = self._make_queue(tmp_path)
    q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
    # shutdown() blocks until the writer thread drains the queue — no need
    # to pre-sleep (the old 1s sleep was a just-in-case wait, but shutdown
    # does the right thing).
    q.shutdown()
    # If ingest succeeded, the row should be deleted
    client.ingest_session.assert_called_once()
|
||||
|
||||
def test_enqueue_persists_to_sqlite(self, tmp_path):
|
||||
client = MagicMock()
|
||||
# Make ingest hang so the row stays in SQLite
|
||||
client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(5))
|
||||
# Make ingest slow so the row is still in SQLite when we peek.
|
||||
# 0.5s is plenty — the test just needs the flush to still be in-flight.
|
||||
client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(0.5))
|
||||
db_path = tmp_path / "test_queue.db"
|
||||
q = _WriteQueue(client, db_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "test"}])
|
||||
|
|
@ -154,8 +181,7 @@ class TestWriteQueue:
|
|||
def test_flush_deletes_row_on_success(self, tmp_path):
|
||||
q, client, db_path = self._make_queue(tmp_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
|
||||
time.sleep(1)
|
||||
q.shutdown()
|
||||
q.shutdown() # blocks until drain
|
||||
# Row should be gone
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
rows = conn.execute("SELECT COUNT(*) FROM pending").fetchone()[0]
|
||||
|
|
@ -168,14 +194,20 @@ class TestWriteQueue:
|
|||
db_path = tmp_path / "test_queue.db"
|
||||
q = _WriteQueue(client, db_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
|
||||
time.sleep(3) # Allow retry + sleep(2) in _flush_row
|
||||
# Poll for the error to be recorded (max 2s), instead of a fixed 3s wait.
|
||||
deadline = time.time() + 2.0
|
||||
last_error = None
|
||||
while time.time() < deadline:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
row = conn.execute("SELECT last_error FROM pending").fetchone()
|
||||
conn.close()
|
||||
if row and row[0]:
|
||||
last_error = row[0]
|
||||
break
|
||||
time.sleep(0.05)
|
||||
q.shutdown()
|
||||
# Row should still exist with error recorded
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
row = conn.execute("SELECT last_error FROM pending").fetchone()
|
||||
conn.close()
|
||||
assert row is not None
|
||||
assert "API down" in row[0]
|
||||
assert last_error is not None
|
||||
assert "API down" in last_error
|
||||
|
||||
def test_thread_local_connection_reuse(self, tmp_path):
|
||||
q, _, _ = self._make_queue(tmp_path)
|
||||
|
|
@ -193,14 +225,27 @@ class TestWriteQueue:
|
|||
client1.ingest_session = MagicMock(side_effect=RuntimeError("fail"))
|
||||
q1 = _WriteQueue(client1, db_path)
|
||||
q1.enqueue("user1", "sess1", [{"role": "user", "content": "lost turn"}])
|
||||
time.sleep(3)
|
||||
# Wait until the error is recorded (poll with short interval).
|
||||
deadline = time.time() + 2.0
|
||||
while time.time() < deadline:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
row = conn.execute("SELECT last_error FROM pending").fetchone()
|
||||
conn.close()
|
||||
if row and row[0]:
|
||||
break
|
||||
time.sleep(0.05)
|
||||
q1.shutdown()
|
||||
|
||||
# Now create a new queue — it should replay the pending rows
|
||||
client2 = MagicMock()
|
||||
client2.ingest_session = MagicMock(return_value={"status": "ok"})
|
||||
q2 = _WriteQueue(client2, db_path)
|
||||
time.sleep(2)
|
||||
# Poll for the replay to happen.
|
||||
deadline = time.time() + 2.0
|
||||
while time.time() < deadline:
|
||||
if client2.ingest_session.called:
|
||||
break
|
||||
time.sleep(0.05)
|
||||
q2.shutdown()
|
||||
|
||||
# The replayed row should have been ingested via client2
|
||||
|
|
|
|||
34
tests/run_agent/conftest.py
Normal file
34
tests/run_agent/conftest.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
"""Fast-path fixtures shared across tests/run_agent/.
|
||||
|
||||
Many tests in this directory exercise the retry/backoff paths in the
|
||||
agent loop. Production code uses ``jittered_backoff(base_delay=5.0)``
|
||||
with a ``while time.time() < sleep_end`` loop — a single retry test
|
||||
spends 5+ seconds of real wall-clock time on backoff waits.
|
||||
|
||||
Mocking ``jittered_backoff`` to return 0.0 collapses the while-loop
|
||||
to a no-op (``time.time() < time.time() + 0`` is false immediately),
|
||||
which handles the most common case without touching ``time.sleep``.
|
||||
|
||||
We deliberately DO NOT mock ``time.sleep`` here — some tests
|
||||
(test_interrupt_propagation, test_primary_runtime_restore, etc.) use
|
||||
the real ``time.sleep`` for threading coordination or assert that it
|
||||
was called with specific values. Tests that want to additionally
|
||||
fast-path direct ``time.sleep(N)`` calls in production code should
|
||||
monkeypatch ``run_agent.time.sleep`` locally (see
|
||||
``test_anthropic_error_handling.py`` for the pattern).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _fast_retry_backoff(monkeypatch):
    """Make ``run_agent.jittered_backoff`` report a zero delay for every test here.

    If ``run_agent`` cannot be imported the fixture leaves everything alone.
    """
    try:
        import run_agent as agent_module
    except ImportError:
        pass
    else:
        def _zero_backoff(*_args, **_kwargs):
            return 0.0

        monkeypatch.setattr(agent_module, "jittered_backoff", _zero_backoff)
|
||||
|
|
@ -32,6 +32,7 @@ class TestGeneric400Heuristic:
|
|||
from run_agent import AIAgent
|
||||
a = AIAgent(
|
||||
api_key="test-key-12345",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
|
|||
|
|
@ -19,6 +19,24 @@ import pytest
|
|||
|
||||
from agent.context_compressor import SUMMARY_PREFIX
|
||||
from run_agent import AIAgent
|
||||
import run_agent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fast backoff for compression retry tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _no_compression_sleep(monkeypatch):
    """Remove real waiting from the compression retry paths.

    Production code has ``time.sleep(2)`` in multiple places after a
    413/context compression, for rate-limit smoothing.  Tests assert
    behavior, not timing, so ``time.sleep`` becomes a no-op and
    ``run_agent.jittered_backoff`` always reports a zero delay.
    """
    import time as _clock

    def _skip_sleep(*_args, **_kwargs):
        return None

    def _zero_backoff(*_args, **_kwargs):
        return 0.0

    monkeypatch.setattr(_clock, "sleep", _skip_sleep)
    monkeypatch.setattr(run_agent, "jittered_backoff", _zero_backoff)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -69,6 +87,7 @@ def agent():
|
|||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ class TestFlushDeduplication:
|
|||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
session_db=session_db,
|
||||
|
|
@ -271,6 +273,8 @@ class TestFlushIdxInit:
|
|||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -283,6 +287,8 @@ class TestFlushIdxInit:
|
|||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
|
|||
|
|
@ -27,6 +27,39 @@ from gateway.config import Platform
|
|||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fast backoff for tests that exercise the retry loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _no_backoff_wait(monkeypatch):
    """Short-circuit retry backoff so tests don't block on real wall-clock waits.

    The production code uses jittered_backoff() with a 5s base delay plus a
    tight time.sleep(0.2) loop.  Without this patch, each 429/500/529 retry
    test burns ~10s of real time on CI — across six tests that's ~60s for
    behavior we're not asserting against timing.

    Tests assert retry counts and final results, never wait durations.

    Three things are patched for the duration of each test:
    ``run_agent.jittered_backoff`` (always 0.0), ``time.sleep`` (no-op) and
    ``asyncio.sleep`` (yields to the loop once without delaying).
    """
    import asyncio as _asyncio
    import time as _time

    monkeypatch.setattr(run_agent, "jittered_backoff", lambda *a, **k: 0.0)
    monkeypatch.setattr(_time, "sleep", lambda *_a, **_k: None)

    # Also fast-path asyncio.sleep — the gateway's _run_agent path has
    # several await asyncio.sleep(...) calls that add real wall-clock time.
    # Capture the real coroutine first; the stub must not call itself.
    _real_asyncio_sleep = _asyncio.sleep

    async def _fast_sleep(delay=0, result=None, *args, **kwargs):
        # Yield to the event loop but skip the actual delay.  ``result`` is
        # passed through so ``await asyncio.sleep(d, result=x)`` still
        # evaluates to ``x`` as it does with the real coroutine.
        return await _real_asyncio_sleep(0, result)

    monkeypatch.setattr(_asyncio, "sleep", _fast_sleep)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -37,6 +37,8 @@ class TestFlushAfterCompression:
|
|||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
session_db=session_db,
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ from run_agent import AIAgent
|
|||
def test_create_openai_client_does_not_mutate_input_kwargs(mock_openai):
|
||||
mock_openai.return_value = MagicMock()
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
|
|||
|
|
@ -23,6 +23,8 @@ from run_agent import AIAgent
|
|||
|
||||
def _make_agent():
|
||||
return AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,24 @@ from unittest.mock import MagicMock, patch, call
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _mock_runtime_provider(monkeypatch):
    """Replace runtime-provider resolution with a canned OpenRouter result.

    run_job calls resolve_runtime_provider, which can try real network
    auto-detection (~4s of socket timeouts in hermetic CI).  These tests
    don't care about provider resolution — the agent is mocked too — so the
    resolver is swapped for a stub that ignores its arguments.
    """
    import hermes_cli.runtime_provider as provider_mod

    canned_runtime = {
        "provider": "openrouter",
        "api_key": "test-key",
        "base_url": "https://openrouter.ai/api/v1",
        "model": "test/model",
        "api_mode": "chat_completions",
    }

    def _fake_resolve(*_args, **_kwargs):
        # New dict per call so callers never share a mutable result.
        return dict(canned_runtime)

    monkeypatch.setattr(provider_mod, "resolve_runtime_provider", _fake_resolve)
|
||||
|
||||
|
||||
class TestCronJobCleanup:
|
||||
"""cron/scheduler.py — end_session + close in the finally block."""
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,16 @@ from unittest.mock import MagicMock, patch
|
|||
import pytest
|
||||
|
||||
from run_agent import AIAgent
|
||||
import run_agent
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _no_fallback_wait(monkeypatch):
    """Remove real waiting from the fallback/recovery paths.

    ``time.sleep`` becomes a no-op so tests don't block on the
    ``min(3 + retry_count, 8)`` wait before a primary retry, and
    ``run_agent.jittered_backoff`` always reports a zero delay.
    """
    import time as _clock

    def _skip_sleep(*_args, **_kwargs):
        return None

    def _zero_backoff(*_args, **_kwargs):
        return 0.0

    monkeypatch.setattr(_clock, "sleep", _skip_sleep)
    monkeypatch.setattr(run_agent, "jittered_backoff", _zero_backoff)
|
||||
|
||||
|
||||
def _make_tool_defs(*names: str) -> list:
|
||||
|
|
@ -36,6 +46,7 @@ def _make_agent(fallback_model=None):
|
|||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ def test_plugin_engine_gets_context_length_on_init():
|
|||
|
||||
agent = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
@ -75,6 +76,7 @@ def test_plugin_engine_update_model_args():
|
|||
agent = AIAgent(
|
||||
model="openrouter/auto",
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ def _make_agent(fallback_model=None):
|
|||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
|
|||
|
|
@ -60,6 +60,9 @@ def _make_agent(monkeypatch, provider, api_mode="chat_completions", base_url="ht
|
|||
)
|
||||
if model:
|
||||
kwargs["model"] = model
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
return AIAgent(**kwargs)
|
||||
|
||||
|
||||
|
|
@ -248,6 +251,19 @@ class TestBuildApiKwargsChatCompletionsServiceTier:
|
|||
assert "service_tier" not in kwargs
|
||||
|
||||
|
||||
class TestBuildApiKwargsKimiFixedTemperature:
    """Temperature reported by ``_build_api_kwargs`` for the Kimi coding model."""

    def test_kimi_for_coding_forces_temperature_on_main_chat_path(self, monkeypatch):
        """An agent built for ``kimi-for-coding`` yields ``temperature == 0.6``."""
        kimi_agent = _make_agent(
            monkeypatch,
            "kimi-coding",
            base_url="https://api.kimi.com/coding/v1",
            model="kimi-for-coding",
        )
        user_turn = {"role": "user", "content": "hi"}

        api_kwargs = kimi_agent._build_api_kwargs([user_turn])

        assert api_kwargs["temperature"] == 0.6
|
||||
|
||||
|
||||
class TestBuildApiKwargsAIGateway:
|
||||
def test_uses_chat_completions_format(self, monkeypatch):
|
||||
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1", model="gpt-4o")
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ def agent():
|
|||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
@ -76,6 +77,7 @@ def agent_with_memory_tool():
|
|||
):
|
||||
a = AIAgent(
|
||||
api_key="test-k...7890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
@ -112,12 +114,14 @@ def test_aiagent_reuses_existing_errors_log_handler():
|
|||
):
|
||||
AIAgent(
|
||||
api_key="test-k...7890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
AIAgent(
|
||||
api_key="test-k...7890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
@ -491,6 +495,7 @@ class TestInit:
|
|||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="openai/gpt-4o",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -542,6 +547,7 @@ class TestInit:
|
|||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
@ -557,6 +563,7 @@ class TestInit:
|
|||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
@ -694,6 +701,7 @@ class TestBuildSystemPrompt:
|
|||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-k...7890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
@ -726,6 +734,7 @@ class TestToolUseEnforcementConfig:
|
|||
a = AIAgent(
|
||||
model=model,
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
@ -822,6 +831,7 @@ class TestToolUseEnforcementConfig:
|
|||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
@ -3433,7 +3443,7 @@ class TestAnthropicBaseUrlPassthrough:
|
|||
):
|
||||
mock_build.return_value = MagicMock()
|
||||
a = AIAgent(
|
||||
api_key="sk-ant-api03-test1234567890",
|
||||
api_key="sk-ant...7890",
|
||||
api_mode="anthropic_messages",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -3457,6 +3467,7 @@ class TestAnthropicCredentialRefresh:
|
|||
mock_build.side_effect = [old_client, new_client]
|
||||
agent = AIAgent(
|
||||
api_key="sk-ant-oat01-stale-token",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_mode="anthropic_messages",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -3487,6 +3498,7 @@ class TestAnthropicCredentialRefresh:
|
|||
):
|
||||
agent = AIAgent(
|
||||
api_key="sk-ant-oat01-same-token",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_mode="anthropic_messages",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -3514,6 +3526,7 @@ class TestAnthropicCredentialRefresh:
|
|||
):
|
||||
agent = AIAgent(
|
||||
api_key="sk-ant-oat01-current-token",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_mode="anthropic_messages",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
|
|||
|
|
@ -12,6 +12,15 @@ sys.modules.setdefault("fal_client", types.SimpleNamespace())
|
|||
import run_agent
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _no_codex_backoff(monkeypatch):
    """Remove real waiting from the Codex retry paths.

    Retry tests would otherwise block on wall-clock waits (5s
    jittered_backoff base delay + tight time.sleep loop).  Here
    ``run_agent.jittered_backoff`` always reports a zero delay and
    ``time.sleep`` becomes a no-op.
    """
    import time as _clock

    def _zero_backoff(*_args, **_kwargs):
        return 0.0

    def _skip_sleep(*_args, **_kwargs):
        return None

    monkeypatch.setattr(run_agent, "jittered_backoff", _zero_backoff)
    monkeypatch.setattr(_clock, "sleep", _skip_sleep)
|
||||
|
||||
|
||||
def _patch_agent_bootstrap(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
run_agent,
|
||||
|
|
|
|||
|
|
@ -80,6 +80,8 @@ class TestStreamingAccumulator:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -120,6 +122,8 @@ class TestStreamingAccumulator:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -167,6 +171,8 @@ class TestStreamingAccumulator:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -205,6 +211,8 @@ class TestStreamingAccumulator:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -245,6 +253,8 @@ class TestStreamingCallbacks:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -277,6 +287,8 @@ class TestStreamingCallbacks:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -308,6 +320,8 @@ class TestStreamingCallbacks:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -346,6 +360,8 @@ class TestStreamingCallbacks:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -381,6 +397,8 @@ class TestStreamingCallbacks:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -428,6 +446,8 @@ class TestStreamingFallback:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -455,6 +475,8 @@ class TestStreamingFallback:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -477,6 +499,8 @@ class TestStreamingFallback:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -500,6 +524,8 @@ class TestStreamingFallback:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -542,6 +568,8 @@ class TestStreamingFallback:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -577,6 +605,8 @@ class TestStreamingFallback:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -619,6 +649,8 @@ class TestReasoningStreaming:
|
|||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -646,6 +678,8 @@ class TestHasStreamConsumers:
|
|||
def test_no_consumers(self):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -656,6 +690,8 @@ class TestHasStreamConsumers:
|
|||
def test_delta_callback_set(self):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -667,6 +703,8 @@ class TestHasStreamConsumers:
|
|||
def test_stream_callback_set(self):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -688,6 +726,8 @@ class TestCodexStreamCallbacks:
|
|||
deltas = []
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -729,6 +769,8 @@ class TestCodexStreamCallbacks:
|
|||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -792,6 +834,8 @@ class TestCodexStreamCallbacks:
|
|||
)
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -810,6 +854,8 @@ class TestCodexStreamCallbacks:
|
|||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
@ -861,6 +907,8 @@ class TestAnthropicStreamCallbacks:
|
|||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ def _make_agent(session_db, *, platform: str):
|
|||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
|
|
|||
28
tests/test_mini_swe_runner.py
Normal file
28
tests/test_mini_swe_runner.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def test_run_task_forces_kimi_fixed_temperature():
|
||||
with patch("openai.OpenAI") as mock_openai:
|
||||
client = MagicMock()
|
||||
client.chat.completions.create.return_value = SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=SimpleNamespace(content="done", tool_calls=[]))]
|
||||
)
|
||||
mock_openai.return_value = client
|
||||
|
||||
from mini_swe_runner import MiniSWERunner
|
||||
|
||||
runner = MiniSWERunner(
|
||||
model="kimi-for-coding",
|
||||
base_url="https://api.kimi.com/coding/v1",
|
||||
api_key="test-key",
|
||||
env_type="local",
|
||||
max_iterations=1,
|
||||
)
|
||||
runner._create_env = MagicMock()
|
||||
runner._cleanup_env = MagicMock()
|
||||
|
||||
result = runner.run_task("2+2")
|
||||
|
||||
assert result["completed"] is True
|
||||
assert client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
|
||||
|
|
@ -34,3 +34,21 @@ def test_messaging_extra_includes_qrcode_for_weixin_setup():
|
|||
|
||||
messaging_extra = optional_dependencies["messaging"]
|
||||
assert any(dep.startswith("qrcode") for dep in messaging_extra)
|
||||
|
||||
|
||||
def test_dingtalk_extra_includes_qrcode_for_qr_auth():
|
||||
"""DingTalk's QR-code device-flow auth (hermes_cli/dingtalk_auth.py)
|
||||
needs the qrcode package."""
|
||||
optional_dependencies = _load_optional_dependencies()
|
||||
|
||||
dingtalk_extra = optional_dependencies["dingtalk"]
|
||||
assert any(dep.startswith("qrcode") for dep in dingtalk_extra)
|
||||
|
||||
|
||||
def test_feishu_extra_includes_qrcode_for_qr_login():
|
||||
"""Feishu's QR login flow (gateway/platforms/feishu.py) needs the
|
||||
qrcode package."""
|
||||
optional_dependencies = _load_optional_dependencies()
|
||||
|
||||
feishu_extra = optional_dependencies["feishu"]
|
||||
assert any(dep.startswith("qrcode") for dep in feishu_extra)
|
||||
|
|
|
|||
|
|
@ -159,18 +159,34 @@ class TestCodeExecutionTZ:
|
|||
return _json.dumps({"error": f"unexpected tool call: {function_name}"})
|
||||
|
||||
def test_tz_injected_when_configured(self):
|
||||
"""When HERMES_TIMEZONE is set, child process sees TZ env var."""
|
||||
"""When HERMES_TIMEZONE is set, child process sees TZ env var.
|
||||
|
||||
Verified alongside leak-prevention + empty-TZ handling in one
|
||||
subprocess call so we don't pay 3x the subprocess startup cost
|
||||
(each execute_code spawns a real Python subprocess ~3s).
|
||||
"""
|
||||
import json as _json
|
||||
os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata"
|
||||
|
||||
# One subprocess, three things checked:
|
||||
# 1) TZ is injected as "Asia/Kolkata"
|
||||
# 2) HERMES_TIMEZONE itself does NOT leak into the child env
|
||||
probe = (
|
||||
'import os; '
|
||||
'print("TZ=" + os.environ.get("TZ", "NOT_SET")); '
|
||||
'print("HERMES_TIMEZONE=" + os.environ.get("HERMES_TIMEZONE", "NOT_SET"))'
|
||||
)
|
||||
with patch("model_tools.handle_function_call", side_effect=self._mock_handle):
|
||||
result = _json.loads(self._execute_code(
|
||||
code='import os; print(os.environ.get("TZ", "NOT_SET"))',
|
||||
task_id="tz-test",
|
||||
code=probe,
|
||||
task_id="tz-combined-test",
|
||||
enabled_tools=[],
|
||||
))
|
||||
assert result["status"] == "success"
|
||||
assert "Asia/Kolkata" in result["output"]
|
||||
assert "TZ=Asia/Kolkata" in result["output"]
|
||||
assert "HERMES_TIMEZONE=NOT_SET" in result["output"], (
|
||||
"HERMES_TIMEZONE should not leak into child env (only TZ)"
|
||||
)
|
||||
|
||||
def test_tz_not_injected_when_empty(self):
|
||||
"""When HERMES_TIMEZONE is not set, child process has no TZ."""
|
||||
|
|
@ -186,20 +202,6 @@ class TestCodeExecutionTZ:
|
|||
assert result["status"] == "success"
|
||||
assert "NOT_SET" in result["output"]
|
||||
|
||||
def test_hermes_timezone_not_leaked_to_child(self):
|
||||
"""HERMES_TIMEZONE itself must NOT appear in child env (only TZ)."""
|
||||
import json as _json
|
||||
os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata"
|
||||
|
||||
with patch("model_tools.handle_function_call", side_effect=self._mock_handle):
|
||||
result = _json.loads(self._execute_code(
|
||||
code='import os; print(os.environ.get("HERMES_TIMEZONE", "NOT_SET"))',
|
||||
task_id="tz-leak-test",
|
||||
enabled_tools=[],
|
||||
))
|
||||
assert result["status"] == "success"
|
||||
assert "NOT_SET" in result["output"]
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Cron timezone-aware scheduling
|
||||
|
|
|
|||
|
|
@ -31,6 +31,29 @@ def test_import_loads_env_from_hermes_home(tmp_path, monkeypatch):
|
|||
assert os.getenv("OPENROUTER_API_KEY") == "from-hermes-home"
|
||||
|
||||
|
||||
def test_generate_summary_custom_client_forces_kimi_temperature():
|
||||
config = CompressionConfig(
|
||||
summarization_model="kimi-for-coding",
|
||||
temperature=0.3,
|
||||
summary_target_tokens=100,
|
||||
max_retries=1,
|
||||
)
|
||||
compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
|
||||
compressor.config = config
|
||||
compressor.logger = MagicMock()
|
||||
compressor._use_call_llm = False
|
||||
compressor.client = MagicMock()
|
||||
compressor.client.chat.completions.create.return_value = SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
|
||||
)
|
||||
|
||||
metrics = TrajectoryMetrics()
|
||||
result = compressor._generate_summary("tool output", metrics)
|
||||
|
||||
assert result.startswith("[CONTEXT SUMMARY]:")
|
||||
assert compressor.client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompressionConfig
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ each asyncio.run() gets a client bound to the current loop.
|
|||
"""
|
||||
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
|
@ -113,3 +114,30 @@ class TestSourceLineVerification:
|
|||
"""_get_async_client method should exist."""
|
||||
src = self._read_file()
|
||||
assert "def _get_async_client(self)" in src
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_summary_async_custom_client_forces_kimi_temperature():
|
||||
from trajectory_compressor import CompressionConfig, TrajectoryCompressor, TrajectoryMetrics
|
||||
|
||||
config = CompressionConfig(
|
||||
summarization_model="kimi-for-coding",
|
||||
temperature=0.3,
|
||||
summary_target_tokens=100,
|
||||
max_retries=1,
|
||||
)
|
||||
compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
|
||||
compressor.config = config
|
||||
compressor.logger = MagicMock()
|
||||
compressor._use_call_llm = False
|
||||
async_client = MagicMock()
|
||||
async_client.chat.completions.create = MagicMock(return_value=SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
|
||||
))
|
||||
compressor._get_async_client = MagicMock(return_value=async_client)
|
||||
|
||||
metrics = TrajectoryMetrics()
|
||||
result = await compressor._generate_summary_async("tool output", metrics)
|
||||
|
||||
assert result.startswith("[CONTEXT SUMMARY]:")
|
||||
assert async_client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
|
||||
|
|
|
|||
440
tests/test_tui_gateway_server.py
Normal file
440
tests/test_tui_gateway_server.py
Normal file
|
|
@ -0,0 +1,440 @@
|
|||
import json
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from tui_gateway import server
|
||||
|
||||
|
||||
class _ChunkyStdout:
|
||||
def __init__(self):
|
||||
self.parts: list[str] = []
|
||||
|
||||
def write(self, text: str) -> int:
|
||||
for ch in text:
|
||||
self.parts.append(ch)
|
||||
time.sleep(0.0001)
|
||||
return len(text)
|
||||
|
||||
def flush(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _BrokenStdout:
|
||||
def write(self, text: str) -> int:
|
||||
raise BrokenPipeError
|
||||
|
||||
def flush(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_write_json_serializes_concurrent_writes(monkeypatch):
|
||||
out = _ChunkyStdout()
|
||||
monkeypatch.setattr(server, "_real_stdout", out)
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=server.write_json, args=({"seq": i, "text": "x" * 24},))
|
||||
for i in range(8)
|
||||
]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
lines = "".join(out.parts).splitlines()
|
||||
|
||||
assert len(lines) == 8
|
||||
assert {json.loads(line)["seq"] for line in lines} == set(range(8))
|
||||
|
||||
|
||||
def test_write_json_returns_false_on_broken_pipe(monkeypatch):
|
||||
monkeypatch.setattr(server, "_real_stdout", _BrokenStdout())
|
||||
|
||||
assert server.write_json({"ok": True}) is False
|
||||
|
||||
|
||||
def test_status_callback_emits_kind_and_text():
|
||||
with patch("tui_gateway.server._emit") as emit:
|
||||
cb = server._agent_cbs("sid")["status_callback"]
|
||||
cb("context_pressure", "85% to compaction")
|
||||
|
||||
emit.assert_called_once_with(
|
||||
"status.update",
|
||||
"sid",
|
||||
{"kind": "context_pressure", "text": "85% to compaction"},
|
||||
)
|
||||
|
||||
|
||||
def test_status_callback_accepts_single_message_argument():
|
||||
with patch("tui_gateway.server._emit") as emit:
|
||||
cb = server._agent_cbs("sid")["status_callback"]
|
||||
cb("thinking...")
|
||||
|
||||
emit.assert_called_once_with(
|
||||
"status.update",
|
||||
"sid",
|
||||
{"kind": "status", "text": "thinking..."},
|
||||
)
|
||||
|
||||
|
||||
def _session(agent=None, **extra):
|
||||
return {
|
||||
"agent": agent if agent is not None else types.SimpleNamespace(),
|
||||
"session_key": "session-key",
|
||||
"history": [],
|
||||
"history_lock": threading.Lock(),
|
||||
"history_version": 0,
|
||||
"running": False,
|
||||
"attached_images": [],
|
||||
"image_counter": 0,
|
||||
"cols": 80,
|
||||
"slash_worker": None,
|
||||
"show_reasoning": False,
|
||||
"tool_progress_mode": "all",
|
||||
**extra,
|
||||
}
|
||||
|
||||
|
||||
def test_config_set_yolo_toggles_session_scope():
|
||||
from tools.approval import clear_session, is_session_yolo_enabled
|
||||
|
||||
server._sessions["sid"] = _session()
|
||||
try:
|
||||
resp_on = server.handle_request({"id": "1", "method": "config.set", "params": {"session_id": "sid", "key": "yolo"}})
|
||||
assert resp_on["result"]["value"] == "1"
|
||||
assert is_session_yolo_enabled("session-key") is True
|
||||
|
||||
resp_off = server.handle_request({"id": "2", "method": "config.set", "params": {"session_id": "sid", "key": "yolo"}})
|
||||
assert resp_off["result"]["value"] == "0"
|
||||
assert is_session_yolo_enabled("session-key") is False
|
||||
finally:
|
||||
clear_session("session-key")
|
||||
server._sessions.clear()
|
||||
|
||||
|
||||
def test_enable_gateway_prompts_sets_gateway_env(monkeypatch):
|
||||
monkeypatch.delenv("HERMES_EXEC_ASK", raising=False)
|
||||
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
|
||||
monkeypatch.delenv("HERMES_INTERACTIVE", raising=False)
|
||||
|
||||
server._enable_gateway_prompts()
|
||||
|
||||
assert server.os.environ["HERMES_GATEWAY_SESSION"] == "1"
|
||||
assert server.os.environ["HERMES_EXEC_ASK"] == "1"
|
||||
assert server.os.environ["HERMES_INTERACTIVE"] == "1"
|
||||
|
||||
|
||||
def test_setup_status_reports_provider_config(monkeypatch):
|
||||
monkeypatch.setattr("hermes_cli.main._has_any_provider_configured", lambda: False)
|
||||
|
||||
resp = server.handle_request({"id": "1", "method": "setup.status", "params": {}})
|
||||
|
||||
assert resp["result"]["provider_configured"] is False
|
||||
|
||||
|
||||
def test_config_set_reasoning_updates_live_session_and_agent(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(server, "_hermes_home", tmp_path)
|
||||
agent = types.SimpleNamespace(reasoning_config=None)
|
||||
server._sessions["sid"] = _session(agent=agent)
|
||||
|
||||
resp_effort = server.handle_request(
|
||||
{"id": "1", "method": "config.set", "params": {"session_id": "sid", "key": "reasoning", "value": "low"}}
|
||||
)
|
||||
assert resp_effort["result"]["value"] == "low"
|
||||
assert agent.reasoning_config == {"enabled": True, "effort": "low"}
|
||||
|
||||
resp_show = server.handle_request(
|
||||
{"id": "2", "method": "config.set", "params": {"session_id": "sid", "key": "reasoning", "value": "show"}}
|
||||
)
|
||||
assert resp_show["result"]["value"] == "show"
|
||||
assert server._sessions["sid"]["show_reasoning"] is True
|
||||
|
||||
|
||||
def test_config_set_verbose_updates_session_mode_and_agent(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(server, "_hermes_home", tmp_path)
|
||||
agent = types.SimpleNamespace(verbose_logging=False)
|
||||
server._sessions["sid"] = _session(agent=agent)
|
||||
|
||||
resp = server.handle_request(
|
||||
{"id": "1", "method": "config.set", "params": {"session_id": "sid", "key": "verbose", "value": "cycle"}}
|
||||
)
|
||||
|
||||
assert resp["result"]["value"] == "verbose"
|
||||
assert server._sessions["sid"]["tool_progress_mode"] == "verbose"
|
||||
assert agent.verbose_logging is True
|
||||
|
||||
|
||||
def test_config_set_model_uses_live_switch_path(monkeypatch):
|
||||
server._sessions["sid"] = _session()
|
||||
seen = {}
|
||||
|
||||
def _fake_apply(sid, session, raw):
|
||||
seen["args"] = (sid, session["session_key"], raw)
|
||||
return {"value": "new/model", "warning": "catalog unreachable"}
|
||||
|
||||
monkeypatch.setattr(server, "_apply_model_switch", _fake_apply)
|
||||
resp = server.handle_request(
|
||||
{"id": "1", "method": "config.set", "params": {"session_id": "sid", "key": "model", "value": "new/model"}}
|
||||
)
|
||||
|
||||
assert resp["result"]["value"] == "new/model"
|
||||
assert resp["result"]["warning"] == "catalog unreachable"
|
||||
assert seen["args"] == ("sid", "session-key", "new/model")
|
||||
|
||||
|
||||
def test_config_set_model_global_persists(monkeypatch):
|
||||
class _Agent:
|
||||
provider = "openrouter"
|
||||
model = "old/model"
|
||||
base_url = ""
|
||||
api_key = "sk-old"
|
||||
|
||||
def switch_model(self, **kwargs):
|
||||
return None
|
||||
|
||||
result = types.SimpleNamespace(
|
||||
success=True,
|
||||
new_model="anthropic/claude-sonnet-4.6",
|
||||
target_provider="anthropic",
|
||||
api_key="sk-new",
|
||||
base_url="https://api.anthropic.com",
|
||||
api_mode="anthropic_messages",
|
||||
warning_message="",
|
||||
)
|
||||
seen = {}
|
||||
saved = {}
|
||||
|
||||
def _switch_model(**kwargs):
|
||||
seen.update(kwargs)
|
||||
return result
|
||||
|
||||
server._sessions["sid"] = _session(agent=_Agent())
|
||||
monkeypatch.setattr("hermes_cli.model_switch.switch_model", _switch_model)
|
||||
monkeypatch.setattr(server, "_restart_slash_worker", lambda session: None)
|
||||
monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr("hermes_cli.config.save_config", lambda cfg: saved.update(cfg))
|
||||
|
||||
resp = server.handle_request(
|
||||
{"id": "1", "method": "config.set", "params": {"session_id": "sid", "key": "model", "value": "anthropic/claude-sonnet-4.6 --global"}}
|
||||
)
|
||||
|
||||
assert resp["result"]["value"] == "anthropic/claude-sonnet-4.6"
|
||||
assert seen["is_global"] is True
|
||||
assert saved["model"]["default"] == "anthropic/claude-sonnet-4.6"
|
||||
assert saved["model"]["provider"] == "anthropic"
|
||||
assert saved["model"]["base_url"] == "https://api.anthropic.com"
|
||||
|
||||
|
||||
def test_config_set_personality_rejects_unknown_name(monkeypatch):
|
||||
monkeypatch.setattr(server, "_available_personalities", lambda cfg=None: {"helpful": "You are helpful."})
|
||||
resp = server.handle_request(
|
||||
{"id": "1", "method": "config.set", "params": {"key": "personality", "value": "bogus"}}
|
||||
)
|
||||
|
||||
assert "error" in resp
|
||||
assert "Unknown personality" in resp["error"]["message"]
|
||||
|
||||
|
||||
def test_config_set_personality_resets_history_and_returns_info(monkeypatch):
|
||||
session = _session(agent=types.SimpleNamespace(), history=[{"role": "user", "text": "hi"}], history_version=4)
|
||||
new_agent = types.SimpleNamespace(model="x")
|
||||
emits = []
|
||||
|
||||
server._sessions["sid"] = session
|
||||
monkeypatch.setattr(server, "_available_personalities", lambda cfg=None: {"helpful": "You are helpful."})
|
||||
monkeypatch.setattr(server, "_make_agent", lambda sid, key, session_id=None: new_agent)
|
||||
monkeypatch.setattr(server, "_session_info", lambda agent: {"model": getattr(agent, "model", "?")})
|
||||
monkeypatch.setattr(server, "_restart_slash_worker", lambda session: None)
|
||||
monkeypatch.setattr(server, "_emit", lambda *args: emits.append(args))
|
||||
monkeypatch.setattr(server, "_write_config_key", lambda path, value: None)
|
||||
|
||||
resp = server.handle_request(
|
||||
{"id": "1", "method": "config.set", "params": {"session_id": "sid", "key": "personality", "value": "helpful"}}
|
||||
)
|
||||
|
||||
assert resp["result"]["history_reset"] is True
|
||||
assert resp["result"]["info"] == {"model": "x"}
|
||||
assert session["history"] == []
|
||||
assert session["history_version"] == 5
|
||||
assert ("session.info", "sid", {"model": "x"}) in emits
|
||||
|
||||
|
||||
def test_session_compress_uses_compress_helper(monkeypatch):
|
||||
agent = types.SimpleNamespace()
|
||||
server._sessions["sid"] = _session(agent=agent)
|
||||
|
||||
monkeypatch.setattr(server, "_compress_session_history", lambda session, focus_topic=None: (2, {"total": 42}))
|
||||
monkeypatch.setattr(server, "_session_info", lambda _agent: {"model": "x"})
|
||||
|
||||
with patch("tui_gateway.server._emit") as emit:
|
||||
resp = server.handle_request({"id": "1", "method": "session.compress", "params": {"session_id": "sid"}})
|
||||
|
||||
assert resp["result"]["removed"] == 2
|
||||
assert resp["result"]["usage"]["total"] == 42
|
||||
emit.assert_called_once_with("session.info", "sid", {"model": "x"})
|
||||
|
||||
|
||||
def test_prompt_submit_sets_approval_session_key(monkeypatch):
|
||||
from tools.approval import get_current_session_key
|
||||
|
||||
captured = {}
|
||||
|
||||
class _Agent:
|
||||
def run_conversation(self, prompt, conversation_history=None, stream_callback=None):
|
||||
captured["session_key"] = get_current_session_key(default="")
|
||||
return {"final_response": "ok", "messages": [{"role": "assistant", "content": "ok"}]}
|
||||
|
||||
class _ImmediateThread:
|
||||
def __init__(self, target=None, daemon=None):
|
||||
self._target = target
|
||||
|
||||
def start(self):
|
||||
self._target()
|
||||
|
||||
server._sessions["sid"] = _session(agent=_Agent())
|
||||
monkeypatch.setattr(server.threading, "Thread", _ImmediateThread)
|
||||
monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(server, "make_stream_renderer", lambda cols: None)
|
||||
monkeypatch.setattr(server, "render_message", lambda raw, cols: None)
|
||||
|
||||
resp = server.handle_request({"id": "1", "method": "prompt.submit", "params": {"session_id": "sid", "text": "ping"}})
|
||||
|
||||
assert resp["result"]["status"] == "streaming"
|
||||
assert captured["session_key"] == "session-key"
|
||||
|
||||
|
||||
def test_prompt_submit_expands_context_refs(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
class _Agent:
|
||||
model = "test/model"
|
||||
base_url = ""
|
||||
api_key = ""
|
||||
|
||||
def run_conversation(self, prompt, conversation_history=None, stream_callback=None):
|
||||
captured["prompt"] = prompt
|
||||
return {"final_response": "ok", "messages": [{"role": "assistant", "content": "ok"}]}
|
||||
|
||||
class _ImmediateThread:
|
||||
def __init__(self, target=None, daemon=None):
|
||||
self._target = target
|
||||
|
||||
def start(self):
|
||||
self._target()
|
||||
|
||||
fake_ctx = types.ModuleType("agent.context_references")
|
||||
fake_ctx.preprocess_context_references = lambda message, **kwargs: types.SimpleNamespace(
|
||||
blocked=False, message="expanded prompt", warnings=[], references=[], injected_tokens=0
|
||||
)
|
||||
fake_meta = types.ModuleType("agent.model_metadata")
|
||||
fake_meta.get_model_context_length = lambda *args, **kwargs: 100000
|
||||
|
||||
server._sessions["sid"] = _session(agent=_Agent())
|
||||
monkeypatch.setattr(server.threading, "Thread", _ImmediateThread)
|
||||
monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(server, "make_stream_renderer", lambda cols: None)
|
||||
monkeypatch.setattr(server, "render_message", lambda raw, cols: None)
|
||||
monkeypatch.setitem(sys.modules, "agent.context_references", fake_ctx)
|
||||
monkeypatch.setitem(sys.modules, "agent.model_metadata", fake_meta)
|
||||
|
||||
server.handle_request({"id": "1", "method": "prompt.submit", "params": {"session_id": "sid", "text": "@diff"}})
|
||||
|
||||
assert captured["prompt"] == "expanded prompt"
|
||||
|
||||
|
||||
def test_image_attach_appends_local_image(monkeypatch):
|
||||
fake_cli = types.ModuleType("cli")
|
||||
fake_cli._IMAGE_EXTENSIONS = {".png"}
|
||||
fake_cli._split_path_input = lambda raw: (raw, "")
|
||||
fake_cli._resolve_attachment_path = lambda raw: Path("/tmp/cat.png")
|
||||
|
||||
server._sessions["sid"] = _session()
|
||||
monkeypatch.setitem(sys.modules, "cli", fake_cli)
|
||||
|
||||
resp = server.handle_request({"id": "1", "method": "image.attach", "params": {"session_id": "sid", "path": "/tmp/cat.png"}})
|
||||
|
||||
assert resp["result"]["attached"] is True
|
||||
assert resp["result"]["name"] == "cat.png"
|
||||
assert len(server._sessions["sid"]["attached_images"]) == 1
|
||||
|
||||
|
||||
def test_command_dispatch_exec_nonzero_surfaces_error(monkeypatch):
|
||||
monkeypatch.setattr(server, "_load_cfg", lambda: {"quick_commands": {"boom": {"type": "exec", "command": "boom"}}})
|
||||
monkeypatch.setattr(
|
||||
server.subprocess,
|
||||
"run",
|
||||
lambda *args, **kwargs: types.SimpleNamespace(returncode=1, stdout="", stderr="failed"),
|
||||
)
|
||||
|
||||
resp = server.handle_request({"id": "1", "method": "command.dispatch", "params": {"name": "boom"}})
|
||||
|
||||
assert "error" in resp
|
||||
assert "failed" in resp["error"]["message"]
|
||||
|
||||
|
||||
def test_plugins_list_surfaces_loader_error(monkeypatch):
|
||||
with patch("hermes_cli.plugins.get_plugin_manager", side_effect=Exception("boom")):
|
||||
resp = server.handle_request({"id": "1", "method": "plugins.list", "params": {}})
|
||||
|
||||
assert "error" in resp
|
||||
assert "boom" in resp["error"]["message"]
|
||||
|
||||
|
||||
def test_complete_slash_surfaces_completer_error(monkeypatch):
|
||||
with patch("hermes_cli.commands.SlashCommandCompleter", side_effect=Exception("no completer")):
|
||||
resp = server.handle_request({"id": "1", "method": "complete.slash", "params": {"text": "/mo"}})
|
||||
|
||||
assert "error" in resp
|
||||
assert "no completer" in resp["error"]["message"]
|
||||
|
||||
|
||||
def test_input_detect_drop_attaches_image(monkeypatch):
|
||||
fake_cli = types.ModuleType("cli")
|
||||
fake_cli._detect_file_drop = lambda raw: {
|
||||
"path": Path("/tmp/cat.png"),
|
||||
"is_image": True,
|
||||
"remainder": "",
|
||||
}
|
||||
|
||||
server._sessions["sid"] = _session()
|
||||
monkeypatch.setitem(sys.modules, "cli", fake_cli)
|
||||
|
||||
resp = server.handle_request(
|
||||
{"id": "1", "method": "input.detect_drop", "params": {"session_id": "sid", "text": "/tmp/cat.png"}}
|
||||
)
|
||||
|
||||
assert resp["result"]["matched"] is True
|
||||
assert resp["result"]["is_image"] is True
|
||||
assert resp["result"]["text"] == "[User attached image: cat.png]"
|
||||
|
||||
|
||||
def test_rollback_restore_resolves_number_and_file_path():
|
||||
calls = {}
|
||||
|
||||
class _Mgr:
|
||||
enabled = True
|
||||
|
||||
def list_checkpoints(self, cwd):
|
||||
return [{"hash": "aaa111"}, {"hash": "bbb222"}]
|
||||
|
||||
def restore(self, cwd, target, file_path=None):
|
||||
calls["args"] = (cwd, target, file_path)
|
||||
return {"success": True, "message": "done"}
|
||||
|
||||
server._sessions["sid"] = _session(agent=types.SimpleNamespace(_checkpoint_mgr=_Mgr()), history=[])
|
||||
resp = server.handle_request(
|
||||
{
|
||||
"id": "1",
|
||||
"method": "rollback.restore",
|
||||
"params": {"session_id": "sid", "hash": "2", "file_path": "src/app.tsx"},
|
||||
}
|
||||
)
|
||||
|
||||
assert resp["result"]["success"] is True
|
||||
assert calls["args"][1] == "bbb222"
|
||||
assert calls["args"][2] == "src/app.tsx"
|
||||
199
tests/tools/test_accretion_caps.py
Normal file
199
tests/tools/test_accretion_caps.py
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
"""Accretion caps for _read_tracker (file_tools) and _completion_consumed
|
||||
(process_registry).
|
||||
|
||||
Both structures are process-lifetime singletons that previously grew
|
||||
unbounded in long-running CLI / gateway sessions:
|
||||
|
||||
file_tools._read_tracker[task_id]
|
||||
├─ read_history (set) — one entry per unique (path, offset, limit)
|
||||
├─ dedup (dict) — one entry per unique (path, offset, limit)
|
||||
└─ read_timestamps (dict) — one entry per unique resolved path
|
||||
process_registry._completion_consumed (set) — one entry per session_id
|
||||
ever polled / waited / logged
|
||||
|
||||
None of these were ever trimmed. A 10k-read CLI session accumulated
|
||||
roughly 1.5MB of tracker state; a gateway with high background-process
|
||||
churn accumulated ~20B per session_id until the process exited.
|
||||
|
||||
These tests pin the new caps + prune hooks.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestReadTrackerCaps:
    """Pin the size caps that ``file_tools._cap_read_tracker_data`` enforces."""

    # Module-level cap constants that the tests override via monkeypatch.
    _CAP_NAMES = ("_READ_HISTORY_CAP", "_DEDUP_CAP", "_READ_TIMESTAMPS_CAP")

    def setup_method(self):
        """Empty the module-level read tracker before each test."""
        from tools import file_tools

        # Clean slate per test.
        with file_tools._read_tracker_lock:
            file_tools._read_tracker.clear()

    def test_read_history_capped(self, monkeypatch):
        """read_history set is bounded by _READ_HISTORY_CAP."""
        from tools import file_tools as ft

        monkeypatch.setattr(ft, "_READ_HISTORY_CAP", 10)
        oversized_history = {(f"/p{i}", 0, 500) for i in range(50)}
        task_data = {
            "last_key": None,
            "consecutive": 0,
            "read_history": oversized_history,
            "dedup": {},
            "read_timestamps": {},
        }

        ft._cap_read_tracker_data(task_data)

        assert len(task_data["read_history"]) == 10

    def test_dedup_capped_oldest_first(self, monkeypatch):
        """dedup dict is bounded; oldest entries evicted first."""
        from tools import file_tools as ft

        monkeypatch.setattr(ft, "_DEDUP_CAP", 5)
        oversized_dedup = {(f"/p{i}", 0, 500): float(i) for i in range(20)}
        task_data = {
            "read_history": set(),
            "dedup": oversized_dedup,
            "read_timestamps": {},
        }

        ft._cap_read_tracker_data(task_data)

        remaining = task_data["dedup"]
        assert len(remaining) == 5
        # Entries 15-19 (inserted last) should survive.
        for kept in (19, 15):
            assert (f"/p{kept}", 0, 500) in remaining
        # Entries 0-14 should be evicted.
        for dropped in (0, 14):
            assert (f"/p{dropped}", 0, 500) not in remaining

    def test_read_timestamps_capped_oldest_first(self, monkeypatch):
        """read_timestamps dict is bounded; oldest entries evicted first."""
        from tools import file_tools as ft

        monkeypatch.setattr(ft, "_READ_TIMESTAMPS_CAP", 3)
        oversized_timestamps = {f"/path/{i}": float(i) for i in range(10)}
        task_data = {
            "read_history": set(),
            "dedup": {},
            "read_timestamps": oversized_timestamps,
        }

        ft._cap_read_tracker_data(task_data)

        remaining = task_data["read_timestamps"]
        assert len(remaining) == 3
        assert "/path/9" in remaining
        assert "/path/7" in remaining
        assert "/path/0" not in remaining

    def test_cap_is_idempotent_under_cap(self, monkeypatch):
        """When containers are under cap, _cap_read_tracker_data is a no-op."""
        from tools import file_tools as ft

        for cap_name in self._CAP_NAMES:
            monkeypatch.setattr(ft, cap_name, 100)

        task_data = {
            "read_history": {("/a", 0, 500), ("/b", 0, 500)},
            "dedup": {("/a", 0, 500): 1.0},
            "read_timestamps": {"/a": 1.0},
        }
        # Snapshot copies so later mutation of task_data cannot affect them.
        rh_before = task_data["read_history"].copy()
        dedup_before = task_data["dedup"].copy()
        ts_before = task_data["read_timestamps"].copy()

        ft._cap_read_tracker_data(task_data)

        assert task_data["read_history"] == rh_before
        assert task_data["dedup"] == dedup_before
        assert task_data["read_timestamps"] == ts_before

    def test_cap_handles_missing_containers(self):
        """Missing sub-keys don't cause AttributeError."""
        from tools import file_tools as ft

        sparse_inputs = (
            {},  # no containers at all
            {"read_history": None},
            {"dedup": None},
        )
        for sparse in sparse_inputs:
            ft._cap_read_tracker_data(sparse)

    def test_live_cap_applied_after_read_add(self, tmp_path, monkeypatch):
        """Live read_file path enforces caps."""
        from tools import file_tools as ft

        for cap_name in self._CAP_NAMES:
            monkeypatch.setattr(ft, cap_name, 3)

        # Create 10 distinct files and read each once.
        for idx in range(10):
            target = tmp_path / f"file_{idx}.txt"
            target.write_text(f"content {idx}\n" * 10)
            ft.read_file_tool(path=str(target), task_id="long-session")

        with ft._read_tracker_lock:
            tracked = ft._read_tracker["long-session"]
            assert len(tracked["read_history"]) <= 3
            assert len(tracked["dedup"]) <= 3
            assert len(tracked["read_timestamps"]) <= 3
|
||||
|
||||
|
||||
class TestCompletionConsumedPrune:
    """``_prune_if_needed`` must not leave stale ids in ``_completion_consumed``."""

    def test_prune_drops_completion_entry_with_expired_session(self):
        """When a finished session is pruned, _completion_consumed is
        cleared for the same session_id."""
        import time

        from tools.process_registry import FINISHED_TTL_SECONDS, ProcessRegistry

        class _ExpiredSession:
            """Finished-session stand-in whose started_at is older than the TTL."""

            def __init__(self, sid):
                self.id = sid
                self.exited = True
                self.started_at = time.time() - (FINISHED_TTL_SECONDS + 100)

        registry = ProcessRegistry()
        registry._finished["stale-1"] = _ExpiredSession("stale-1")
        registry._completion_consumed.add("stale-1")

        with registry._lock:
            registry._prune_if_needed()

        assert "stale-1" not in registry._finished
        assert "stale-1" not in registry._completion_consumed

    def test_prune_drops_completion_entry_for_lru_evicted(self):
        """Same contract for the LRU path (over MAX_PROCESSES)."""
        import time

        from tools import process_registry as pr

        class _FinishedSession:
            """Finished-session stand-in with a caller-chosen start time."""

            def __init__(self, sid, started):
                self.id = sid
                self.exited = True
                self.started_at = started

        registry = pr.ProcessRegistry()

        # Fill above MAX_PROCESSES with recently-finished sessions.
        reference_time = time.time()
        for age in range(pr.MAX_PROCESSES + 5):
            session_id = f"sess-{age}"
            # sess-0 is the newest; each later one started a second earlier.
            registry._finished[session_id] = _FinishedSession(
                session_id, reference_time - age
            )
            registry._completion_consumed.add(session_id)

        with registry._lock:
            # _prune_if_needed removes one oldest finished per invocation;
            # call it enough times to trim back down.
            for _ in range(10):
                registry._prune_if_needed()

        # The _completion_consumed set should not contain session IDs that
        # are no longer in _running or _finished.
        still_tracked = registry._running.keys() | registry._finished.keys()
        orphaned = registry._completion_consumed - still_tracked
        assert orphaned == set()

    def test_prune_clears_dangling_completion_entries(self):
        """Stale entries in _completion_consumed without a backing session
        record are cleared out (belt-and-suspenders invariant)."""
        from tools.process_registry import ProcessRegistry

        registry = ProcessRegistry()
        dangling_id = "dangling-never-tracked"
        # This id was never placed in _running or _finished.
        registry._completion_consumed.add(dangling_id)

        with registry._lock:
            registry._prune_if_needed()

        assert dangling_id not in registry._completion_consumed
|
||||
|
|
@ -2,11 +2,13 @@
|
|||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch as mock_patch
|
||||
|
||||
import tools.approval as approval_module
|
||||
from tools.approval import (
|
||||
_get_approval_mode,
|
||||
_smart_approve,
|
||||
approve_session,
|
||||
detect_dangerous_command,
|
||||
is_approved,
|
||||
|
|
@ -26,6 +28,21 @@ class TestApprovalModeParsing:
|
|||
assert _get_approval_mode() == "off"
|
||||
|
||||
|
||||
class TestSmartApproval:
    """Smart approval obtains its verdict through ``agent.auxiliary_client.call_llm``."""

    def test_smart_approval_uses_call_llm(self):
        """An "APPROVE" reply yields "approve" and call_llm receives the pinned kwargs."""
        llm_message = SimpleNamespace(content="APPROVE")
        response = SimpleNamespace(choices=[SimpleNamespace(message=llm_message)])

        with mock_patch("agent.auxiliary_client.call_llm", return_value=response) as mock_call:
            result = _smart_approve("python -c \"print('hello')\"", "script execution via -c flag")

        assert result == "approve"
        mock_call.assert_called_once()
        sent_kwargs = mock_call.call_args.kwargs
        assert sent_kwargs["task"] == "approval"
        assert sent_kwargs["temperature"] == 0
        assert sent_kwargs["max_tokens"] == 16
|
||||
|
||||
|
||||
class TestDetectDangerousRm:
|
||||
def test_rm_rf_detected(self):
|
||||
is_dangerous, key, desc = detect_dangerous_command("rm -rf /home/user")
|
||||
|
|
@ -820,4 +837,3 @@ class TestChmodExecuteCombo:
|
|||
dangerous, _, _ = detect_dangerous_command(cmd)
|
||||
assert dangerous is False
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -77,3 +77,42 @@ class TestResolveCdpOverride:
|
|||
"https://cdp.browser-use.example/session/json/version",
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
|
||||
class TestGetCdpOverride:
    """``_get_cdp_override`` resolves the CDP URL from the env var or the config."""

    @staticmethod
    def _version_response():
        """Build a mock HTTP reply whose JSON carries the websocket debugger URL."""
        reply = Mock()
        reply.raise_for_status.return_value = None
        reply.json.return_value = {"webSocketDebuggerUrl": WS_URL}
        return reply

    def test_prefers_env_var_over_config(self, monkeypatch):
        """With BROWSER_CDP_URL set, the env value is resolved, not the config one."""
        import tools.browser_tool as browser_tool

        monkeypatch.setenv("BROWSER_CDP_URL", HTTP_URL)
        # raising=False: the attribute is allowed to be absent on the module.
        monkeypatch.setattr(
            browser_tool,
            "read_raw_config",
            lambda: {"browser": {"cdp_url": "http://config-host:9222"}},
            raising=False,
        )
        response = self._version_response()

        with patch("tools.browser_tool.requests.get", return_value=response) as mock_get:
            resolved = browser_tool._get_cdp_override()

        assert resolved == WS_URL
        mock_get.assert_called_once_with(VERSION_URL, timeout=10)

    def test_uses_config_browser_cdp_url_when_env_missing(self, monkeypatch):
        """Without the env var, ``browser.cdp_url`` from the raw config is used."""
        import tools.browser_tool as browser_tool

        monkeypatch.delenv("BROWSER_CDP_URL", raising=False)
        response = self._version_response()
        raw_config = {"browser": {"cdp_url": HTTP_URL}}

        with patch("hermes_cli.config.read_raw_config", return_value=raw_config):
            with patch("tools.browser_tool.requests.get", return_value=response) as mock_get:
                resolved = browser_tool._get_cdp_override()

        assert resolved == WS_URL
        mock_get.assert_called_once_with(VERSION_URL, timeout=10)
|
||||
|
|
|
|||
|
|
@ -28,12 +28,22 @@ def _isolate_sessions():
|
|||
bt._active_sessions.update(orig)
|
||||
|
||||
|
||||
def _make_socket_dir(tmpdir, session_name, pid=None):
|
||||
"""Create a fake agent-browser socket directory with optional PID file."""
|
||||
def _make_socket_dir(tmpdir, session_name, pid=None, owner_pid=None):
|
||||
"""Create a fake agent-browser socket directory with optional PID files.
|
||||
|
||||
Args:
|
||||
tmpdir: base temp directory
|
||||
session_name: name like "h_abc1234567" or "cdp_abc1234567"
|
||||
pid: daemon PID to write to <session>.pid (None = no file)
|
||||
owner_pid: owning hermes PID to write to <session>.owner_pid
|
||||
(None = no file; tests the legacy path)
|
||||
"""
|
||||
d = tmpdir / f"agent-browser-{session_name}"
|
||||
d.mkdir()
|
||||
if pid is not None:
|
||||
(d / f"{session_name}.pid").write_text(str(pid))
|
||||
if owner_pid is not None:
|
||||
(d / f"{session_name}.owner_pid").write_text(str(owner_pid))
|
||||
return d
|
||||
|
||||
|
||||
|
|
@ -62,7 +72,10 @@ class TestReapOrphanedBrowserSessions:
|
|||
assert not d.exists()
|
||||
|
||||
def test_orphaned_alive_daemon_is_killed(self, fake_tmpdir):
|
||||
"""Alive daemon not tracked by _active_sessions gets SIGTERM."""
|
||||
"""Alive daemon not tracked by _active_sessions gets SIGTERM (legacy path).
|
||||
|
||||
No owner_pid file => falls back to tracked_names check.
|
||||
"""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
|
||||
d = _make_socket_dir(fake_tmpdir, "h_orphan12345", pid=12345)
|
||||
|
|
@ -84,7 +97,7 @@ class TestReapOrphanedBrowserSessions:
|
|||
assert (12345, signal.SIGTERM) in kill_calls
|
||||
|
||||
def test_tracked_session_is_not_reaped(self, fake_tmpdir):
|
||||
"""Sessions tracked in _active_sessions are left alone."""
|
||||
"""Sessions tracked in _active_sessions are left alone (legacy path)."""
|
||||
import tools.browser_tool as bt
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
|
||||
|
|
@ -156,3 +169,240 @@ class TestReapOrphanedBrowserSessions:
|
|||
|
||||
_reap_orphaned_browser_sessions()
|
||||
assert not d.exists()
|
||||
|
||||
|
||||
class TestOwnerPidCrossProcess:
    """Cross-process reaping rules driven by the ``<session>.owner_pid`` file.

    The owner_pid file records which hermes process owns a daemon, so that
    concurrent hermes processes don't reap each other's active browser
    sessions. Added to fix orphan accumulation from crashed processes.
    """

    def test_alive_owner_is_not_reaped_even_when_untracked(self, fake_tmpdir):
        """A daemon with an alive owner_pid is NOT reaped, even if untracked here.

        Core cross-process safety check: Process B scanning while Process A
        is using a browser must not kill A's daemon.
        """
        from tools.browser_tool import _reap_orphaned_browser_sessions

        # Our own PID is guaranteed to be alive, so it plays the "owner".
        socket_dir = _make_socket_dir(
            fake_tmpdir, "h_alive_owner", pid=12345, owner_pid=os.getpid()
        )
        kill_calls = []

        def fake_kill(pid, sig):
            # Record the call and return normally: a signal-0 probe that does
            # not raise means "exists" for owner and daemon alike, and nothing
            # is ever actually killed.
            kill_calls.append((pid, sig))

        with patch("os.kill", side_effect=fake_kill):
            _reap_orphaned_browser_sessions()

        # The daemon must never have been sent SIGTERM.
        assert (12345, signal.SIGTERM) not in kill_calls
        # Dir should still exist
        assert socket_dir.exists()

    def test_dead_owner_triggers_reap(self, fake_tmpdir):
        """Daemon whose owner_pid is dead gets reaped."""
        from tools.browser_tool import _reap_orphaned_browser_sessions

        # PID 999999999 almost certainly doesn't exist
        socket_dir = _make_socket_dir(
            fake_tmpdir, "h_dead_owner1", pid=12345, owner_pid=999999999
        )
        kill_calls = []

        def fake_kill(pid, sig):
            kill_calls.append((pid, sig))
            if (pid, sig) == (999999999, 0):
                raise ProcessLookupError  # owner dead
            # Any other call (daemon probe, SIGTERM to daemon) is a no-op.

        with patch("os.kill", side_effect=fake_kill):
            _reap_orphaned_browser_sessions()

        # Owner checked (returned dead), daemon checked (alive), daemon killed
        assert (999999999, 0) in kill_calls
        assert (12345, 0) in kill_calls
        assert (12345, signal.SIGTERM) in kill_calls
        # Dir cleaned up
        assert not socket_dir.exists()

    def test_corrupt_owner_pid_falls_back_to_legacy(self, fake_tmpdir):
        """Corrupt owner_pid file → fall back to tracked_names check."""
        import tools.browser_tool as bt
        from tools.browser_tool import _reap_orphaned_browser_sessions

        session_name = "h_corrupt_own"
        socket_dir = _make_socket_dir(fake_tmpdir, session_name, pid=12345)
        # Write garbage to owner_pid file
        (socket_dir / f"{session_name}.owner_pid").write_text("not-a-pid")

        # Register session so legacy fallback leaves it alone
        bt._active_sessions["task"] = {"session_name": session_name}

        kill_calls = []

        def fake_kill(pid, sig):
            kill_calls.append((pid, sig))

        with patch("os.kill", side_effect=fake_kill):
            _reap_orphaned_browser_sessions()

        # Legacy path took over → tracked → not reaped
        assert (12345, signal.SIGTERM) not in kill_calls
        assert socket_dir.exists()

    def test_owner_pid_permission_error_treated_as_alive(self, fake_tmpdir):
        """If os.kill(owner, 0) raises PermissionError, treat owner as alive.

        PermissionError means the PID exists but is owned by a different user —
        we must not assume the owner is dead (could kill someone else's daemon).
        """
        from tools.browser_tool import _reap_orphaned_browser_sessions

        socket_dir = _make_socket_dir(
            fake_tmpdir, "h_perm_owner1", pid=12345, owner_pid=22222
        )
        kill_calls = []

        def fake_kill(pid, sig):
            kill_calls.append((pid, sig))
            if (pid, sig) == (22222, 0):
                raise PermissionError("not our user")

        with patch("os.kill", side_effect=fake_kill):
            _reap_orphaned_browser_sessions()

        # Must NOT have tried to kill the daemon
        assert (12345, signal.SIGTERM) not in kill_calls
        assert socket_dir.exists()

    def test_write_owner_pid_creates_file_with_current_pid(
        self, fake_tmpdir, monkeypatch
    ):
        """_write_owner_pid(dir, session) writes <session>.owner_pid with os.getpid()."""
        import tools.browser_tool as bt

        session_name = "h_ownertest01"
        socket_dir = fake_tmpdir / f"agent-browser-{session_name}"
        socket_dir.mkdir()

        bt._write_owner_pid(str(socket_dir), session_name)

        written = socket_dir / f"{session_name}.owner_pid"
        assert written.exists()
        assert written.read_text().strip() == str(os.getpid())

    def test_write_owner_pid_is_idempotent(self, fake_tmpdir):
        """Calling _write_owner_pid twice leaves a single owner_pid file."""
        import tools.browser_tool as bt

        session_name = "h_idempot1234"
        socket_dir = fake_tmpdir / f"agent-browser-{session_name}"
        socket_dir.mkdir()

        for _ in range(2):
            bt._write_owner_pid(str(socket_dir), session_name)

        owner_files = list(socket_dir.glob("*.owner_pid"))
        assert len(owner_files) == 1
        assert owner_files[0].read_text().strip() == str(os.getpid())

    def test_write_owner_pid_swallows_oserror(self, fake_tmpdir, monkeypatch):
        """OSError (e.g. permission denied) doesn't propagate — the reaper
        falls back to the legacy tracked_names heuristic in that case.
        """
        import tools.browser_tool as bt

        def _always_denied(*args, **kwargs):
            raise OSError("permission denied")

        # Every open() call now fails, whatever the arguments.
        monkeypatch.setattr("builtins.open", _always_denied)

        # Must not raise
        bt._write_owner_pid(str(fake_tmpdir), "h_readonly123")

    def test_run_browser_command_calls_write_owner_pid(
        self, fake_tmpdir, monkeypatch
    ):
        """_run_browser_command wires _write_owner_pid after mkdir."""
        import tools.browser_tool as bt

        session_name = "h_wiringtest1"

        class _ExplodingPopen:
            """Popen stand-in that aborts the command right at process spawn."""

            def __init__(self, *args, **kwargs):
                raise RuntimeError("short-circuit after owner_pid")

        monkeypatch.setattr(bt.subprocess, "Popen", _ExplodingPopen)
        monkeypatch.setattr(bt, "_find_agent_browser", lambda: "/bin/true")
        monkeypatch.setattr(
            bt, "_requires_real_termux_browser_install", lambda *a: False
        )
        monkeypatch.setattr(
            bt,
            "_get_session_info",
            lambda task_id: {"session_name": session_name},
        )

        recorded_args = []
        real_write = bt._write_owner_pid

        def _recording_write(*args, **kwargs):
            # Keep the positional args, then delegate to the real writer.
            recorded_args.append(args)
            real_write(*args, **kwargs)

        monkeypatch.setattr(bt, "_write_owner_pid", _recording_write)

        with patch("tools.browser_tool._socket_safe_tmpdir", return_value=str(fake_tmpdir)):
            try:
                bt._run_browser_command(task_id="test_task", command="goto", args=[])
            except Exception:
                # Expected: the fake Popen (or anything after the owner_pid
                # write) may raise; only the recorded call matters here.
                pass

        assert recorded_args, "_run_browser_command must call _write_owner_pid"
        # First positional arg is the socket_dir, second is the session_name
        first_call = recorded_args[0]
        socket_dir_arg = first_call[0]
        session_name_arg = first_call[1]
        assert session_name_arg == session_name
        assert session_name in socket_dir_arg
|
||||
|
||||
|
||||
class TestEmergencyCleanupRunsReaper:
    """The emergency cleanup must sweep orphans even when no session is active."""

    def test_emergency_cleanup_calls_reaper(self, fake_tmpdir, monkeypatch):
        """_emergency_cleanup_all_sessions must call _reap_orphaned_browser_sessions."""
        import tools.browser_tool as bt

        # Reset the _cleanup_done flag so the cleanup actually runs
        monkeypatch.setattr(bt, "_cleanup_done", False)

        real_reaper = bt._reap_orphaned_browser_sessions
        invocations = []

        def _recording_reaper():
            # Note the call, then run the genuine reaper.
            invocations.append(True)
            real_reaper()

        monkeypatch.setattr(bt, "_reap_orphaned_browser_sessions", _recording_reaper)

        # No active sessions — reaper should still run
        bt._emergency_cleanup_all_sessions()

        assert invocations, (
            "Reaper must run on exit even with no active sessions"
        )
|
||||
|
|
|
|||
|
|
@ -250,6 +250,15 @@ class TestWslHasImage:
|
|||
mock_run.return_value = MagicMock(stdout="False\n", returncode=0)
|
||||
assert _wsl_has_image() is False
|
||||
|
||||
def test_falls_back_to_get_clipboard_image(self):
    """A "False" first probe followed by a "True" second probe reports an image."""
    probe_results = [
        MagicMock(stdout="False\n", returncode=0),
        MagicMock(stdout="True\n", returncode=0),
    ]
    with patch("hermes_cli.clipboard.subprocess.run") as mock_run:
        # One result per subprocess call, consumed in order.
        mock_run.side_effect = probe_results
        has_image = _wsl_has_image()
        assert has_image is True
        assert mock_run.call_count == 2
|
||||
|
||||
def test_powershell_not_found(self):
|
||||
with patch("hermes_cli.clipboard.subprocess.run", side_effect=FileNotFoundError):
|
||||
assert _wsl_has_image() is False
|
||||
|
|
@ -269,6 +278,18 @@ class TestWslSave:
|
|||
assert _wsl_save(dest) is True
|
||||
assert dest.read_bytes() == FAKE_PNG
|
||||
|
||||
def test_falls_back_to_get_clipboard_extraction(self, tmp_path):
    """After a failed first call, a second call returning base64 PNG is saved."""
    dest = tmp_path / "out.png"
    encoded_png = base64.b64encode(FAKE_PNG).decode()
    attempts = [
        MagicMock(stdout="", returncode=1),
        MagicMock(stdout=encoded_png + "\n", returncode=0),
    ]
    with patch("hermes_cli.clipboard.subprocess.run") as mock_run:
        # One result per subprocess call, consumed in order.
        mock_run.side_effect = attempts
        saved = _wsl_save(dest)
        assert saved is True
        assert mock_run.call_count == 2
        assert dest.read_bytes() == FAKE_PNG
|
||||
|
||||
def test_no_image_returns_false(self, tmp_path):
|
||||
dest = tmp_path / "out.png"
|
||||
with patch("hermes_cli.clipboard.subprocess.run") as mock_run:
|
||||
|
|
@ -528,6 +549,16 @@ class TestWindowsHasImage:
|
|||
mock_run.return_value = MagicMock(stdout="False\n", returncode=0)
|
||||
assert _windows_has_image() is False
|
||||
|
||||
def test_falls_back_to_get_clipboard_image(self):
    """A "False" first probe followed by a "True" second probe reports an image."""
    probe_results = [
        MagicMock(stdout="False\n", returncode=0),
        MagicMock(stdout="True\n", returncode=0),
    ]
    with patch("hermes_cli.clipboard._get_ps_exe", return_value="powershell"):
        with patch("hermes_cli.clipboard.subprocess.run") as mock_run:
            # One result per subprocess call, consumed in order.
            mock_run.side_effect = probe_results
            has_image = _windows_has_image()
            assert has_image is True
            assert mock_run.call_count == 2
|
||||
|
||||
def test_no_powershell_available(self):
|
||||
with patch("hermes_cli.clipboard._get_ps_exe", return_value=None):
|
||||
assert _windows_has_image() is False
|
||||
|
|
@ -559,6 +590,20 @@ class TestWindowsSave:
|
|||
assert _windows_save(dest) is True
|
||||
assert dest.read_bytes() == FAKE_PNG
|
||||
|
||||
def test_falls_back_to_filedrop_image(self, tmp_path):
    """Two failed calls followed by one returning base64 PNG still saves the image."""
    dest = tmp_path / "out.png"
    encoded_png = base64.b64encode(FAKE_PNG).decode()
    attempts = [
        MagicMock(stdout="", returncode=1),
        MagicMock(stdout="", returncode=1),
        MagicMock(stdout=encoded_png + "\n", returncode=0),
    ]
    with patch("hermes_cli.clipboard._get_ps_exe", return_value="powershell"):
        with patch("hermes_cli.clipboard.subprocess.run") as mock_run:
            # One result per subprocess call, consumed in order.
            mock_run.side_effect = attempts
            saved = _windows_save(dest)
            assert saved is True
            assert mock_run.call_count == 3
            assert dest.read_bytes() == FAKE_PNG
|
||||
|
||||
def test_no_image_returns_false(self, tmp_path):
|
||||
dest = tmp_path / "out.png"
|
||||
with patch("hermes_cli.clipboard._get_ps_exe", return_value="powershell"):
|
||||
|
|
@ -734,6 +779,18 @@ class TestHasClipboardImage:
|
|||
assert has_clipboard_image() is True
|
||||
m.assert_called_once()
|
||||
|
||||
def test_wsl_falls_through_to_wayland_when_windows_path_empty(self):
    """WSLg often bridges images to wl-paste even when powershell.exe check fails."""
    with patch("hermes_cli.clipboard.sys") as mock_sys:
        mock_sys.platform = "linux"
        # WSL probe says "no image"; the Wayland probe says "image present".
        with patch("hermes_cli.clipboard._is_wsl", return_value=True), \
                patch("hermes_cli.clipboard._wsl_has_image", return_value=False) as wsl, \
                patch.dict(os.environ, {"WAYLAND_DISPLAY": "wayland-0"}), \
                patch("hermes_cli.clipboard._wayland_has_image", return_value=True) as wl:
            found = has_clipboard_image()
            assert found is True
            wsl.assert_called_once()
            wl.assert_called_once()
|
||||
|
||||
def test_linux_wayland_dispatch(self):
|
||||
with patch("hermes_cli.clipboard.sys") as mock_sys:
|
||||
mock_sys.platform = "linux"
|
||||
|
|
|
|||
62
tests/tools/test_feishu_tools.py
Normal file
62
tests/tools/test_feishu_tools.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
"""Tests for feishu_doc_tool and feishu_drive_tool — registration and schema validation."""
|
||||
|
||||
import importlib
|
||||
import unittest
|
||||
|
||||
from tools.registry import registry
|
||||
|
||||
# Trigger tool discovery so feishu tools get registered
|
||||
importlib.import_module("tools.feishu_doc_tool")
|
||||
importlib.import_module("tools.feishu_drive_tool")
|
||||
|
||||
|
||||
class TestFeishuToolRegistration(unittest.TestCase):
    """Check that each feishu tool is registered and exposes a well-formed schema."""

    # Tool name -> toolset the registry entry is expected to report.
    EXPECTED_TOOLS = {
        "feishu_doc_read": "feishu_doc",
        "feishu_drive_list_comments": "feishu_drive",
        "feishu_drive_list_comment_replies": "feishu_drive",
        "feishu_drive_reply_comment": "feishu_drive",
        "feishu_drive_add_comment": "feishu_drive",
    }

    def test_all_tools_registered(self):
        """Every expected tool has a registry entry under the expected toolset."""
        for tool_name in self.EXPECTED_TOOLS:
            expected_toolset = self.EXPECTED_TOOLS[tool_name]
            entry = registry.get_entry(tool_name)
            self.assertIsNotNone(entry, f"{tool_name} not registered")
            self.assertEqual(entry.toolset, expected_toolset)

    def test_schemas_have_required_fields(self):
        """Each schema carries name, description and object-typed parameters."""
        for tool_name in self.EXPECTED_TOOLS:
            schema = registry.get_entry(tool_name).schema
            self.assertIn("name", schema)
            self.assertEqual(schema["name"], tool_name)
            self.assertIn("description", schema)
            self.assertIn("parameters", schema)
            parameters = schema["parameters"]
            self.assertIn("type", parameters)
            self.assertEqual(parameters["type"], "object")

    def test_handlers_are_callable(self):
        """Each registry entry exposes a callable handler."""
        for tool_name in self.EXPECTED_TOOLS:
            handler = registry.get_entry(tool_name).handler
            self.assertTrue(callable(handler))

    def test_doc_read_schema_params(self):
        """feishu_doc_read declares a ``doc_token`` property."""
        schema = registry.get_entry("feishu_doc_read").schema
        declared = schema["parameters"].get("properties", {})
        self.assertIn("doc_token", declared)

    def test_drive_tools_require_file_token(self):
        """Every tool except feishu_doc_read declares file_token and file_type."""
        drive_tools = [
            name for name in self.EXPECTED_TOOLS if name != "feishu_doc_read"
        ]
        for tool_name in drive_tools:
            schema = registry.get_entry(tool_name).schema
            declared = schema["parameters"].get("properties", {})
            self.assertIn("file_token", declared, f"{tool_name} missing file_token param")
            self.assertIn("file_type", declared, f"{tool_name} missing file_type param")
|
||||
self.assertIn("file_type", props, f"{tool_name} missing file_type param")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
178
tests/tools/test_file_ops_cwd_tracking.py
Normal file
178
tests/tools/test_file_ops_cwd_tracking.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
"""Regression tests for cwd-staleness in ShellFileOperations.
|
||||
|
||||
The bug: ShellFileOperations captured the terminal env's cwd at __init__
|
||||
time and used that stale value for every subsequent _exec() call. When
|
||||
a user ran ``cd`` via the terminal tool, ``env.cwd`` updated but
|
||||
``ops.cwd`` did not. Relative paths passed to patch/read/write/search
|
||||
then targeted the wrong directory — typically the session's start dir
|
||||
instead of the current working directory.
|
||||
|
||||
Observed symptom: patch_replace() returned ``success=True`` with a
|
||||
plausible diff, but the user's ``git diff`` showed no change (because
|
||||
the patch landed in a different directory's copy of the same file).
|
||||
|
||||
Fix: _exec() now prefers the LIVE ``env.cwd`` over the init-time
|
||||
``self.cwd``. Explicit ``cwd`` arg to _exec still wins over both.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.file_operations import ShellFileOperations
|
||||
|
||||
|
||||
class _FakeEnv:
|
||||
"""Minimal terminal env that tracks cwd across execute() calls.
|
||||
|
||||
Matches the real ``BaseEnvironment`` contract: ``cwd`` attribute plus
|
||||
an ``execute(command, cwd=...)`` method whose return dict carries
|
||||
``output`` and ``returncode``. Commands are executed in a real
|
||||
subdirectory so file system effects match production.
|
||||
"""
|
||||
|
||||
def __init__(self, start_cwd: str):
|
||||
self.cwd = start_cwd
|
||||
self.calls: list[dict] = []
|
||||
|
||||
def execute(self, command: str, cwd: str = None, **kwargs) -> dict:
|
||||
import subprocess
|
||||
self.calls.append({"command": command, "cwd": cwd})
|
||||
# Simulate cd by updating self.cwd (the real env does the same
|
||||
# via _extract_cwd_from_output after a successful command)
|
||||
if command.strip().startswith("cd "):
|
||||
new = command.strip()[3:].strip()
|
||||
self.cwd = new
|
||||
return {"output": "", "returncode": 0}
|
||||
# Actually run the command — handle stdin via subprocess
|
||||
stdin_data = kwargs.get("stdin_data")
|
||||
proc = subprocess.run(
|
||||
["bash", "-c", command],
|
||||
cwd=cwd or self.cwd,
|
||||
input=stdin_data,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
return {
|
||||
"output": proc.stdout + proc.stderr,
|
||||
"returncode": proc.returncode,
|
||||
}
|
||||
|
||||
|
||||
class TestShellFileOpsCwdTracking:
|
||||
"""_exec() must use live env.cwd, not the init-time cached cwd."""
|
||||
|
||||
def test_exec_follows_env_cwd_after_cd(self, tmp_path):
|
||||
dir_a = tmp_path / "a"
|
||||
dir_b = tmp_path / "b"
|
||||
dir_a.mkdir()
|
||||
dir_b.mkdir()
|
||||
(dir_a / "target.txt").write_text("content-a\n")
|
||||
(dir_b / "target.txt").write_text("content-b\n")
|
||||
|
||||
env = _FakeEnv(start_cwd=str(dir_a))
|
||||
ops = ShellFileOperations(env, cwd=str(dir_a))
|
||||
assert ops.cwd == str(dir_a) # init-time
|
||||
|
||||
# Simulate the user running `cd b` in terminal
|
||||
env.execute(f"cd {dir_b}")
|
||||
assert env.cwd == str(dir_b)
|
||||
assert ops.cwd == str(dir_a), "ops.cwd is still init-time (fallback only)"
|
||||
|
||||
# Reading a relative path must now hit dir_b, not dir_a
|
||||
result = ops._exec("cat target.txt")
|
||||
assert result.exit_code == 0
|
||||
assert "content-b" in result.stdout, (
|
||||
f"Expected dir_b content, got {result.stdout!r}. "
|
||||
"Stale ops.cwd leaked through — _exec must prefer env.cwd."
|
||||
)
|
||||
|
||||
def test_patch_replace_targets_live_cwd_not_init_cwd(self, tmp_path):
|
||||
"""The exact bug reported: patch lands in wrong dir after cd."""
|
||||
dir_a = tmp_path / "main"
|
||||
dir_b = tmp_path / "worktree"
|
||||
dir_a.mkdir()
|
||||
dir_b.mkdir()
|
||||
(dir_a / "t.txt").write_text("shared text\n")
|
||||
(dir_b / "t.txt").write_text("shared text\n")
|
||||
|
||||
env = _FakeEnv(start_cwd=str(dir_a))
|
||||
ops = ShellFileOperations(env, cwd=str(dir_a))
|
||||
|
||||
# Emulate user cd'ing into the worktree
|
||||
env.execute(f"cd {dir_b}")
|
||||
assert env.cwd == str(dir_b)
|
||||
|
||||
# Patch with a RELATIVE path — must target the worktree, not main
|
||||
result = ops.patch_replace("t.txt", "shared text\n", "PATCHED\n")
|
||||
assert result.success is True
|
||||
|
||||
assert (dir_b / "t.txt").read_text() == "PATCHED\n", (
|
||||
"patch must land in the live-cwd dir (worktree)"
|
||||
)
|
||||
assert (dir_a / "t.txt").read_text() == "shared text\n", (
|
||||
"patch must NOT land in the init-time dir (main)"
|
||||
)
|
||||
|
||||
def test_explicit_cwd_arg_still_wins(self, tmp_path):
|
||||
"""An explicit cwd= arg to _exec must override both env.cwd and self.cwd."""
|
||||
dir_a = tmp_path / "a"
|
||||
dir_b = tmp_path / "b"
|
||||
dir_c = tmp_path / "c"
|
||||
for d in (dir_a, dir_b, dir_c):
|
||||
d.mkdir()
|
||||
(dir_a / "target.txt").write_text("from-a\n")
|
||||
(dir_b / "target.txt").write_text("from-b\n")
|
||||
(dir_c / "target.txt").write_text("from-c\n")
|
||||
|
||||
env = _FakeEnv(start_cwd=str(dir_a))
|
||||
ops = ShellFileOperations(env, cwd=str(dir_a))
|
||||
env.execute(f"cd {dir_b}")
|
||||
|
||||
# Explicit cwd=dir_c should win over env.cwd (dir_b) and self.cwd (dir_a)
|
||||
result = ops._exec("cat target.txt", cwd=str(dir_c))
|
||||
assert "from-c" in result.stdout
|
||||
|
||||
def test_env_without_cwd_attribute_falls_back_to_self_cwd(self, tmp_path):
    """An env object exposing no ``cwd`` attribute still runs in the init-time directory."""
    fixed_dir = tmp_path / "fixed"
    fixed_dir.mkdir()
    (fixed_dir / "target.txt").write_text("fixed-content\n")

    class _NoCwdEnv:
        # Minimal backend: only execute(), deliberately without a cwd attribute.
        def execute(self, command, cwd=None, **kwargs):
            import subprocess

            completed = subprocess.run(
                ["bash", "-c", command],
                cwd=cwd,
                capture_output=True,
                text=True,
            )
            return {"output": completed.stdout, "returncode": completed.returncode}

    ops = ShellFileOperations(_NoCwdEnv(), cwd=str(fixed_dir))
    outcome = ops._exec("cat target.txt")

    assert outcome.exit_code == 0
    assert "fixed-content" in outcome.stdout
|
||||
|
||||
def test_patch_returns_success_only_when_file_actually_written(self, tmp_path):
    """Invariant: a successful patch_replace means the file on disk holds the new content.

    Nothing here corrupts the write on purpose, so the bug itself is not
    reproduced. The test only pins the contract that ``success=True``
    implies the on-disk content equals the intended replacement, so a
    later regression in the write path shows up here.
    """
    workdir = str(tmp_path)
    victim = tmp_path / "file.txt"
    victim.write_text("old content\n")

    ops = ShellFileOperations(_FakeEnv(start_cwd=workdir), cwd=workdir)
    outcome = ops.patch_replace(str(victim), "old content\n", "new content\n")

    assert outcome.success is True
    assert outcome.error is None
    on_disk = victim.read_text()
    assert on_disk == "new content\n", (
        "patch_replace claimed success but file wasn't written correctly"
    )
|
||||
|
|
@ -86,6 +86,7 @@ class TestProviderEnvBlocklist:
|
|||
"MINIMAX_API_KEY": "mm-key",
|
||||
"MINIMAX_CN_API_KEY": "mmcn-key",
|
||||
"DEEPSEEK_API_KEY": "deepseek-key",
|
||||
"NVIDIA_API_KEY": "nvidia-key",
|
||||
}
|
||||
result_env = _run_with_env(extra_os_env=registry_vars)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
from gateway.config import Platform
|
||||
from tools.send_message_tool import (
|
||||
_derive_forum_thread_name,
|
||||
_parse_target_ref,
|
||||
_send_discord,
|
||||
_send_matrix_via_adapter,
|
||||
|
|
@ -1234,3 +1235,419 @@ class TestSendMatrixUrlEncoding:
|
|||
put_url = mock_session.put.call_args[0][0]
|
||||
assert "%21HLOQwxYGgFPMPJUSNR%3Amatrix.org" in put_url
|
||||
assert "!HLOQwxYGgFPMPJUSNR:matrix.org" not in put_url
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _derive_forum_thread_name
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeriveForumThreadName:
    """Expected thread titles produced by _derive_forum_thread_name for various messages."""

    def test_single_line_message(self):
        title = _derive_forum_thread_name("Hello world")
        assert title == "Hello world"

    def test_multi_line_uses_first_line(self):
        title = _derive_forum_thread_name("First line\nSecond line")
        assert title == "First line"

    def test_strips_markdown_heading(self):
        title = _derive_forum_thread_name("## My Heading")
        assert title == "My Heading"

    def test_strips_multiple_hash_levels(self):
        title = _derive_forum_thread_name("### Deep heading")
        assert title == "Deep heading"

    def test_empty_message_falls_back_to_default(self):
        title = _derive_forum_thread_name("")
        assert title == "New Post"

    def test_whitespace_only_falls_back(self):
        title = _derive_forum_thread_name(" \n ")
        assert title == "New Post"

    def test_hash_only_falls_back(self):
        title = _derive_forum_thread_name("###")
        assert title == "New Post"

    def test_truncates_to_100_chars(self):
        # A 200-character single-line message must come back 100 characters long.
        title = _derive_forum_thread_name("A" * 200)
        assert len(title) == 100

    def test_strips_whitespace_around_first_line(self):
        title = _derive_forum_thread_name(" Title \nBody")
        assert title == "Title"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _send_discord with forum channel support
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendDiscordForum:
    """Forum-channel handling in _send_discord: forum targets get a thread post."""

    @staticmethod
    def _build_mock(response_status, response_data=None, response_text="error body"):
        """Return ``(session, response)`` mocks that both work as async context managers.

        The one response object is returned from the session's ``post`` and ``get``.
        """
        response = MagicMock()
        response.status = response_status
        response.json = AsyncMock(return_value=response_data or {})
        response.text = AsyncMock(return_value=response_text)
        response.__aenter__ = AsyncMock(return_value=response)
        response.__aexit__ = AsyncMock(return_value=None)

        session = MagicMock()
        session.__aenter__ = AsyncMock(return_value=session)
        session.__aexit__ = AsyncMock(return_value=None)
        session.post = MagicMock(return_value=response)
        session.get = MagicMock(return_value=response)

        return session, response

    def test_directory_forum_creates_thread(self):
        """A 'forum' answer from the directory results in a thread post."""
        created = {"id": "t123", "message": {"id": "m456"}}
        session, _ = self._build_mock(200, response_data=created)

        with patch("aiohttp.ClientSession", return_value=session):
            with patch("gateway.channel_directory.lookup_channel_type", return_value="forum"):
                outcome = asyncio.run(_send_discord("tok", "forum_ch", "Hello forum"))

        assert outcome["success"] is True
        assert outcome["thread_id"] == "t123"
        assert outcome["message_id"] == "m456"
        # The POST must hit the threads endpoint rather than the messages one.
        posted_url = session.post.call_args.args[0]
        assert "/threads" in posted_url
        assert "/messages" not in posted_url

    def test_directory_forum_skips_probe(self):
        """No GET probe happens once the directory has already said 'forum'."""
        session, _ = self._build_mock(
            200, response_data={"id": "t123", "message": {"id": "m456"}}
        )

        with patch("aiohttp.ClientSession", return_value=session):
            with patch("gateway.channel_directory.lookup_channel_type", return_value="forum"):
                asyncio.run(_send_discord("tok", "forum_ch", "Hello"))

        # The directory resolved the type, so get() must stay untouched.
        session.get.assert_not_called()

    def test_directory_channel_skips_forum(self):
        """A 'channel' answer from the directory goes through the regular messages endpoint."""
        session, _ = self._build_mock(200, response_data={"id": "msg1"})

        with patch("aiohttp.ClientSession", return_value=session):
            with patch("gateway.channel_directory.lookup_channel_type", return_value="channel"):
                outcome = asyncio.run(_send_discord("tok", "ch1", "Hello"))

        assert outcome["success"] is True
        posted_url = session.post.call_args.args[0]
        assert "/messages" in posted_url
        assert "/threads" not in posted_url

    def test_directory_none_probes_and_detects_forum(self):
        """With no directory entry, GET /channels/{id} is probed and type 15 is treated as forum."""
        # First session: answers the channel probe with type 15.
        probe_reply = MagicMock()
        probe_reply.status = 200
        probe_reply.json = AsyncMock(return_value={"type": 15})
        probe_reply.__aenter__ = AsyncMock(return_value=probe_reply)
        probe_reply.__aexit__ = AsyncMock(return_value=None)

        first_session = MagicMock()
        first_session.__aenter__ = AsyncMock(return_value=first_session)
        first_session.__aexit__ = AsyncMock(return_value=None)
        first_session.get = MagicMock(return_value=probe_reply)

        # Second session: answers the thread-creation POST.
        thread_reply = MagicMock()
        thread_reply.status = 200
        thread_reply.json = AsyncMock(
            return_value={"id": "t999", "message": {"id": "m888"}}
        )
        thread_reply.text = AsyncMock(return_value="")
        thread_reply.__aenter__ = AsyncMock(return_value=thread_reply)
        thread_reply.__aexit__ = AsyncMock(return_value=None)

        second_session = MagicMock()
        second_session.__aenter__ = AsyncMock(return_value=second_session)
        second_session.__aexit__ = AsyncMock(return_value=None)
        second_session.post = MagicMock(return_value=thread_reply)

        pending_sessions = iter([first_session, second_session])

        with patch("aiohttp.ClientSession", side_effect=lambda **kw: next(pending_sessions)):
            with patch("gateway.channel_directory.lookup_channel_type", return_value=None):
                outcome = asyncio.run(_send_discord("tok", "forum_ch", "Hello probe"))

        assert outcome["success"] is True
        assert outcome["thread_id"] == "t999"

    def test_directory_lookup_exception_falls_through_to_probe(self):
        """An exception from lookup_channel_type leads to the API probe instead."""
        session, _ = self._build_mock(200, response_data={"id": "msg1"})
        failing_lookup = patch(
            "gateway.channel_directory.lookup_channel_type",
            side_effect=Exception("io error"),
        )

        with patch("aiohttp.ClientSession", return_value=session):
            with failing_lookup:
                outcome = asyncio.run(_send_discord("tok", "ch1", "Hello"))

        assert outcome["success"] is True
        # The probe (GET) is the fallback path.
        session.get.assert_called_once()

    def test_forum_thread_creation_error(self):
        """A 403 from thread creation comes back as an error dict mentioning the status."""
        session, _ = self._build_mock(403, response_text="Forbidden")

        with patch("aiohttp.ClientSession", return_value=session):
            with patch("gateway.channel_directory.lookup_channel_type", return_value="forum"):
                outcome = asyncio.run(_send_discord("tok", "forum_ch", "Hello"))

        assert "error" in outcome
        assert "403" in outcome["error"]
|
||||
|
||||
|
||||
|
||||
class TestSendToPlatformDiscordForum:
    """_send_to_platform hands Discord sends (and thus forum detection) to _send_discord."""

    def test_send_to_platform_discord_delegates_to_send_discord(self):
        """A Discord send is forwarded to _send_discord with empty media and no thread id."""
        fake_send = AsyncMock(return_value={"success": True, "message_id": "1"})
        platform_cfg = SimpleNamespace(enabled=True, token="tok", extra={})

        with patch("tools.send_message_tool._send_discord", fake_send):
            outcome = asyncio.run(
                _send_to_platform(Platform.DISCORD, platform_cfg, "forum_ch", "Hello forum")
            )

        assert outcome["success"] is True
        fake_send.assert_awaited_once_with(
            "tok",
            "forum_ch",
            "Hello forum",
            media_files=[],
            thread_id=None,
        )

    def test_send_to_platform_discord_with_thread_id(self):
        """A caller-supplied thread id reaches _send_discord unchanged."""
        fake_send = AsyncMock(return_value={"success": True, "message_id": "1"})
        platform_cfg = SimpleNamespace(enabled=True, token="tok", extra={})

        with patch("tools.send_message_tool._send_discord", fake_send):
            outcome = asyncio.run(
                _send_to_platform(
                    Platform.DISCORD,
                    platform_cfg,
                    "ch1",
                    "Hello thread",
                    thread_id="17585",
                )
            )

        assert outcome["success"] is True
        forwarded_kwargs = fake_send.await_args[1]
        assert forwarded_kwargs["thread_id"] == "17585"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _send_discord forum + media multipart upload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendDiscordForumMedia:
    """_send_discord uploads media as part of the starter message when the target is a forum."""

    @staticmethod
    def _build_thread_resp(thread_id="th_999", msg_id="msg_500"):
        """Build a 201 thread-creation response mock usable as an async context manager."""
        resp = MagicMock()
        resp.status = 201
        resp.json = AsyncMock(return_value={"id": thread_id, "message": {"id": msg_id}})
        resp.text = AsyncMock(return_value="")
        resp.__aenter__ = AsyncMock(return_value=resp)
        resp.__aexit__ = AsyncMock(return_value=None)
        return resp

    @staticmethod
    def _build_session(post):
        """Build a ClientSession mock (async context manager) whose ``post`` is *post*."""
        session = MagicMock()
        session.__aenter__ = AsyncMock(return_value=session)
        session.__aexit__ = AsyncMock(return_value=None)
        session.post = post
        return session

    @staticmethod
    def _recording_post(response):
        """Return ``(post_mock, calls)``.

        Every call to ``post_mock`` appends ``{"url": ..., "kwargs": ...}`` to
        ``calls`` and returns *response*.
        """
        calls = []

        def track_post(url, **kwargs):
            calls.append({"url": url, "kwargs": kwargs})
            return response

        return MagicMock(side_effect=track_post), calls

    def test_forum_with_media_uses_multipart(self, tmp_path, monkeypatch):
        """Forum + media → single multipart POST to /threads carrying the starter + files."""
        from tools import send_message_tool as smt

        img = tmp_path / "photo.png"
        img.write_bytes(b"\x89PNGbytes")

        # Patch both the tool module's own name (if it has one; raising=False
        # tolerates its absence) and the directory module's function.
        monkeypatch.setattr(smt, "lookup_channel_type", lambda p, cid: "forum", raising=False)
        monkeypatch.setattr(
            "gateway.channel_directory.lookup_channel_type", lambda p, cid: "forum"
        )

        post, post_calls = self._recording_post(self._build_thread_resp())
        session = self._build_session(post)

        with patch("aiohttp.ClientSession", return_value=session):
            result = asyncio.run(
                _send_discord("tok", "forum_ch", "Thread title\nbody", media_files=[(str(img), False)])
            )

        assert result["success"] is True
        assert result["thread_id"] == "th_999"
        assert result["message_id"] == "msg_500"
        # Exactly one POST — the combined thread-creation + attachments call
        assert len(post_calls) == 1
        assert post_calls[0]["url"].endswith("/threads")
        # Multipart form, not JSON
        assert post_calls[0]["kwargs"].get("data") is not None
        assert post_calls[0]["kwargs"].get("json") is None

    def test_forum_without_media_still_json_only(self, tmp_path, monkeypatch):
        """Forum + no media → JSON POST (no multipart overhead)."""
        monkeypatch.setattr(
            "gateway.channel_directory.lookup_channel_type", lambda p, cid: "forum"
        )

        post, post_calls = self._recording_post(self._build_thread_resp("t1", "m1"))
        session = self._build_session(post)

        with patch("aiohttp.ClientSession", return_value=session):
            result = asyncio.run(_send_discord("tok", "forum_ch", "Hello forum"))

        assert result["success"] is True
        assert len(post_calls) == 1
        # JSON path, no multipart
        assert post_calls[0]["kwargs"].get("json") is not None
        assert post_calls[0]["kwargs"].get("data") is None

    def test_forum_missing_media_file_collected_as_warning(self, tmp_path, monkeypatch):
        """Missing media files produce warnings but the thread is still created."""
        monkeypatch.setattr(
            "gateway.channel_directory.lookup_channel_type", lambda p, cid: "forum"
        )

        session = self._build_session(MagicMock(return_value=self._build_thread_resp()))

        with patch("aiohttp.ClientSession", return_value=session):
            result = asyncio.run(
                _send_discord(
                    "tok", "forum_ch", "hi",
                    media_files=[("/nonexistent/does-not-exist.png", False)],
                )
            )

        assert result["success"] is True
        assert "warnings" in result
        assert any("not found" in w for w in result["warnings"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for the process-local forum-probe cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestForumProbeCache:
    """_DISCORD_CHANNEL_TYPE_PROBE_CACHE memoizes forum detection results."""

    def setup_method(self):
        """Empty the process-local probe cache so tests do not see each other's entries."""
        from tools import send_message_tool as smt
        smt._DISCORD_CHANNEL_TYPE_PROBE_CACHE.clear()

    def test_cache_round_trip(self):
        """Unknown ids read back as None; remembered values read back and can be overwritten."""
        from tools.send_message_tool import (
            _probe_is_forum_cached,
            _remember_channel_is_forum,
        )
        assert _probe_is_forum_cached("xyz") is None
        _remember_channel_is_forum("xyz", True)
        assert _probe_is_forum_cached("xyz") is True
        _remember_channel_is_forum("xyz", False)
        assert _probe_is_forum_cached("xyz") is False

    def test_probe_result_is_memoized(self, monkeypatch):
        """An API-probed channel type is cached so subsequent sends skip the probe."""
        from tools import send_message_tool as smt

        monkeypatch.setattr(
            "gateway.channel_directory.lookup_channel_type", lambda p, cid: None
        )

        # Probe response: type=15 (forum)
        probe_resp = MagicMock()
        probe_resp.status = 200
        probe_resp.json = AsyncMock(return_value={"type": 15})
        probe_resp.__aenter__ = AsyncMock(return_value=probe_resp)
        probe_resp.__aexit__ = AsyncMock(return_value=None)

        thread_resp = MagicMock()
        thread_resp.status = 201
        thread_resp.json = AsyncMock(return_value={"id": "t1", "message": {"id": "m1"}})
        thread_resp.__aenter__ = AsyncMock(return_value=thread_resp)
        thread_resp.__aexit__ = AsyncMock(return_value=None)

        probe_session = MagicMock()
        probe_session.__aenter__ = AsyncMock(return_value=probe_session)
        probe_session.__aexit__ = AsyncMock(return_value=None)
        probe_session.get = MagicMock(return_value=probe_resp)

        thread_session = MagicMock()
        thread_session.__aenter__ = AsyncMock(return_value=thread_session)
        thread_session.__aexit__ = AsyncMock(return_value=None)
        thread_session.post = MagicMock(return_value=thread_resp)

        sessions_created = []

        def session_factory(**kwargs):
            # Alternate between the two mocks: even-numbered ClientSession()
            # calls get the probe session, odd-numbered ones the thread session.
            idx = len(sessions_created)
            sessions_created.append(idx)
            if idx % 2 == 0:
                return probe_session
            return thread_session

        # First send: nothing cached yet, so a probe followed by thread creation.
        with patch("aiohttp.ClientSession", side_effect=session_factory):
            result1 = asyncio.run(_send_discord("tok", "ch1", "first"))
        assert result1["success"] is True
        assert smt._probe_is_forum_cached("ch1") is True
        probe_session.get.assert_called_once()

        # Second send: only the thread session is available. The cache must
        # answer the forum question, so no further GET may be issued anywhere.
        with patch("aiohttp.ClientSession", return_value=thread_session):
            result2 = asyncio.run(_send_discord("tok", "ch1", "second"))
        assert result2["success"] is True
        # Still exactly one probe overall, and none on the thread session.
        probe_session.get.assert_called_once()
        thread_session.get.assert_not_called()
|
||||
|
|
|
|||
0
tests/tui_gateway/__init__.py
Normal file
0
tests/tui_gateway/__init__.py
Normal file
233
tests/tui_gateway/test_protocol.py
Normal file
233
tests/tui_gateway/test_protocol.py
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
"""Tests for tui_gateway JSON-RPC protocol plumbing."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
_original_stdout = sys.stdout
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _restore_stdout():
    """After every test, point ``sys.stdout`` back at the stream captured at import time."""
    yield
    # Teardown: undo any stdout replacement made while the test ran.
    sys.stdout = _original_stdout
|
||||
|
||||
|
||||
# Fixture: imports ``tui_gateway.server`` while four Hermes modules
# (hermes_constants, hermes_cli.env_loader, hermes_cli.banner, hermes_state)
# are replaced by MagicMock stand-ins in ``sys.modules``, then yields the
# imported module to the test.
# After the test, ``_sessions``, ``_pending``, ``_answers`` and ``_methods`` on
# the module are cleared and the module is reloaded.
# NOTE(review): indentation is lost in this rendering, so it cannot be told
# from here whether the teardown runs inside or outside the ``patch.dict``
# context — confirm against the real file before restructuring.
@pytest.fixture()
|
||||
def server():
|
||||
with patch.dict("sys.modules", {
|
||||
"hermes_constants": MagicMock(get_hermes_home=MagicMock(return_value="/tmp/hermes_test")),
|
||||
"hermes_cli.env_loader": MagicMock(),
|
||||
"hermes_cli.banner": MagicMock(),
|
||||
"hermes_state": MagicMock(),
|
||||
}):
|
||||
import importlib
|
||||
mod = importlib.import_module("tui_gateway.server")
|
||||
# Hand the imported module to the test; everything below is teardown.
yield mod
|
||||
mod._sessions.clear()
|
||||
mod._pending.clear()
|
||||
mod._answers.clear()
|
||||
mod._methods.clear()
|
||||
importlib.reload(mod)
|
||||
|
||||
|
||||
@pytest.fixture()
def capture(server):
    """Swap the server's ``_real_stdout`` for an in-memory buffer; return ``(server, buffer)``."""
    sink = io.StringIO()
    server._real_stdout = sink
    return server, sink
|
||||
|
||||
|
||||
# ── JSON-RPC envelope ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_unknown_method(server):
    """A request naming an unregistered method is answered with error code -32601."""
    request = {"id": "1", "method": "bogus"}
    response = server.handle_request(request)
    assert response["error"]["code"] == -32601
|
||||
|
||||
|
||||
def test_ok_envelope(server):
    """``_ok`` wraps a result in a JSON-RPC 2.0 success envelope."""
    expected = {"jsonrpc": "2.0", "id": "r1", "result": {"x": 1}}
    envelope = server._ok("r1", {"x": 1})
    assert envelope == expected
|
||||
|
||||
|
||||
def test_err_envelope(server):
    """``_err`` wraps a code and message in a JSON-RPC 2.0 error envelope."""
    expected = {
        "jsonrpc": "2.0",
        "id": "r2",
        "error": {"code": 4001, "message": "nope"},
    }
    envelope = server._err("r2", 4001, "nope")
    assert envelope == expected
|
||||
|
||||
|
||||
# ── write_json ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_write_json(capture):
    """``write_json`` reports success and emits the object as parseable JSON."""
    server, sink = capture
    wrote = server.write_json({"test": True})
    assert wrote
    emitted = json.loads(sink.getvalue())
    assert emitted == {"test": True}
|
||||
|
||||
|
||||
def test_write_json_broken_pipe(server):
    """``write_json`` returns False when the output stream raises BrokenPipeError."""

    class _Broken:
        # Stream stand-in whose every operation fails like a closed pipe.
        def write(self, _):
            raise BrokenPipeError

        def flush(self):
            raise BrokenPipeError

    server._real_stdout = _Broken()
    outcome = server.write_json({"x": 1})
    assert outcome is False
|
||||
|
||||
|
||||
# ── _emit ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_emit_with_payload(capture):
    """``_emit`` writes an "event" message carrying type, session id and payload."""
    server, sink = capture
    server._emit("test.event", "s1", {"key": "val"})

    message = json.loads(sink.getvalue())
    assert message["method"] == "event"

    params = message["params"]
    assert params["type"] == "test.event"
    assert params["session_id"] == "s1"
    assert params["payload"]["key"] == "val"
|
||||
|
||||
|
||||
def test_emit_without_payload(capture):
    """``_emit`` called without a payload leaves the "payload" key out of params."""
    server, sink = capture
    server._emit("ping", "s2")

    params = json.loads(sink.getvalue())["params"]
    assert "payload" not in params
|
||||
|
||||
|
||||
# ── Blocking prompt round-trip ───────────────────────────────────────
|
||||
|
||||
|
||||
def test_block_and_respond(capture):
    """``_block`` returns the answer stored under its request id once its event is set."""
    server, _ = capture
    result = [None]

    def _run_block():
        result[0] = server._block("test.prompt", "s1", {"q": "?"}, timeout=5)

    # daemon=True so a wedged _block cannot keep the interpreter alive if the
    # assertions below fail.
    worker = threading.Thread(target=_run_block, daemon=True)
    worker.start()

    # Poll (up to ~1s) until the worker has registered its pending request.
    for _ in range(100):
        if server._pending:
            break
        threading.Event().wait(0.01)
    assert server._pending, "_block never registered a pending request"

    rid = next(iter(server._pending))
    server._answers[rid] = "my_answer"
    server._pending[rid].set()

    # Wait for the worker itself instead of sleeping a fixed interval, so the
    # result is never read before _block has returned.
    worker.join(timeout=5)
    assert not worker.is_alive(), "_block did not return after its event was set"
    assert result[0] == "my_answer"
|
||||
|
||||
|
||||
def test_clear_pending(server):
    """``_clear_pending`` sets each waiting event and records an empty-string answer."""
    waiter = threading.Event()
    server._pending["r1"] = waiter

    server._clear_pending()

    assert waiter.is_set()
    recorded = server._answers["r1"]
    assert recorded == ""
|
||||
|
||||
|
||||
# ── Session lookup ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_sess_missing(server):
    """Looking up an unknown session id yields an error envelope with code 4001."""
    lookup = server._sess({"session_id": "nope"}, "r1")
    _, failure = lookup
    assert failure["error"]["code"] == 4001
|
||||
|
||||
|
||||
def test_sess_found(server):
    """Looking up a registered session id yields the session and no error."""
    server._sessions["abc"] = {"agent": MagicMock()}

    session, failure = server._sess({"session_id": "abc"}, "r1")

    assert session is not None
    assert failure is None
|
||||
|
||||
|
||||
# ── session.resume payload ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_session_resume_returns_hydrated_messages(server, monkeypatch):
    """``session.resume`` returns only the user, assistant-with-text and tool messages.

    Of the six stored messages, the blank assistant message, the ``None``
    assistant message and the unknown "narrator" role are expected to be absent
    from the response.
    """
    stored_conversation = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "yo"},
        {"role": "tool", "content": "searched"},
        {"role": "assistant", "content": " "},
        {"role": "assistant", "content": None},
        {"role": "narrator", "content": "skip"},
    ]

    class _DB:
        # Stand-in database exposing only the calls this test stubs.
        def get_session(self, _sid):
            return {"id": "20260409_010101_abc123"}

        def get_session_by_title(self, _title):
            return None

        def reopen_session(self, _sid):
            return None

        def get_messages_as_conversation(self, _sid):
            return list(stored_conversation)

    # Replace every collaborator the handler touches with an inert stub.
    monkeypatch.setattr(server, "_get_db", lambda: _DB())
    monkeypatch.setattr(server, "_make_agent", lambda sid, key, session_id=None: object())
    monkeypatch.setattr(server, "_init_session", lambda sid, key, agent, history, cols=80: None)
    monkeypatch.setattr(server, "_session_info", lambda _agent: {"model": "test/model"})

    request = {
        "id": "r1",
        "method": "session.resume",
        "params": {"session_id": "20260409_010101_abc123", "cols": 100},
    }
    resp = server.handle_request(request)

    expected_messages = [
        {"role": "user", "text": "hello"},
        {"role": "assistant", "text": "yo"},
        {"role": "tool", "name": "tool", "context": ""},
    ]
    assert "error" not in resp
    assert resp["result"]["message_count"] == 3
    assert resp["result"]["messages"] == expected_messages
|
||||
|
||||
|
||||
# ── Config I/O ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_config_load_missing(server, tmp_path):
    """``_load_cfg`` returns an empty dict when the home directory is a fresh, empty one."""
    server._hermes_home = tmp_path
    loaded = server._load_cfg()
    assert loaded == {}
|
||||
|
||||
|
||||
def test_config_roundtrip(server, tmp_path):
    """A value saved with ``_save_cfg`` is read back by ``_load_cfg``."""
    server._hermes_home = tmp_path
    server._save_cfg({"model": "test/model"})

    reloaded = server._load_cfg()
    assert reloaded["model"] == "test/model"
|
||||
|
||||
|
||||
# ── _cli_exec_blocked ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "argv",
    [[], ["setup"], ["gateway"], ["sessions", "browse"], ["config", "edit"]],
)
def test_cli_exec_blocked(server, argv):
    """``_cli_exec_blocked`` returns a non-None value for each of these argument vectors."""
    verdict = server._cli_exec_blocked(argv)
    assert verdict is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("argv", [["version"], ["sessions", "list"]])
def test_cli_exec_allowed(server, argv):
    """``_cli_exec_blocked`` returns None for each of these argument vectors."""
    verdict = server._cli_exec_blocked(argv)
    assert verdict is None
|
||||
67
tests/tui_gateway/test_render.py
Normal file
67
tests/tui_gateway/test_render.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
"""Tests for tui_gateway.render — rendering bridge fallback behavior."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tui_gateway.render import make_stream_renderer, render_diff, render_message
|
||||
|
||||
|
||||
def _stub_rich(mock_mod):
|
||||
return patch.dict("sys.modules", {"agent.rich_output": mock_mod})
|
||||
|
||||
|
||||
def _no_rich():
|
||||
return patch.dict("sys.modules", {"agent.rich_output": None})
|
||||
|
||||
|
||||
# ── render_message ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_render_message_none_without_module():
    """``render_message`` returns None while ``agent.rich_output`` is blocked."""
    with _no_rich():
        rendered = render_message("hello")
        assert rendered is None
|
||||
|
||||
|
||||
def test_render_message_formatted():
    """``render_message`` hands back what the stubbed ``format_response`` returns."""
    rich_stub = MagicMock()
    rich_stub.format_response.return_value = "<b>hi</b>"

    with _stub_rich(rich_stub):
        rendered = render_message("hi", 100)
        assert rendered == "<b>hi</b>"
|
||||
|
||||
|
||||
def test_render_message_type_error_fallback():
    """A TypeError on the first ``format_response`` call is followed by a second, successful call."""
    rich_stub = MagicMock()
    # First call raises TypeError, second call returns "fallback".
    rich_stub.format_response.side_effect = [TypeError, "fallback"]

    with _stub_rich(rich_stub):
        rendered = render_message("hi")
        assert rendered == "fallback"
|
||||
|
||||
|
||||
def test_render_message_exception_returns_none():
    """``render_message`` returns None when ``format_response`` raises RuntimeError."""
    rich_stub = MagicMock()
    rich_stub.format_response.side_effect = RuntimeError

    with _stub_rich(rich_stub):
        rendered = render_message("hi")
        assert rendered is None
|
||||
|
||||
|
||||
# ── render_diff / make_stream_renderer ───────────────────────────────
|
||||
|
||||
|
||||
def test_render_diff_none_without_module():
    """``render_diff`` returns None while ``agent.rich_output`` is blocked."""
    with _no_rich():
        rendered = render_diff("+line")
        assert rendered is None
|
||||
|
||||
|
||||
def test_stream_renderer_none_without_module():
    """``make_stream_renderer`` returns None while ``agent.rich_output`` is blocked."""
    with _no_rich():
        renderer = make_stream_renderer()
        assert renderer is None
|
||||
|
||||
|
||||
def test_stream_renderer_returns_instance():
    """``make_stream_renderer`` returns the object built by the stub's ``StreamingRenderer``."""
    expected = MagicMock()
    rich_stub = MagicMock()
    rich_stub.StreamingRenderer.return_value = expected

    with _stub_rich(rich_stub):
        produced = make_stream_renderer(120)
        assert produced is expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue