Merge branch 'main' into fix/show-reasoning-per-platform

This commit is contained in:
Leon 2026-04-18 11:38:56 +08:00
commit ed7b42f889
401 changed files with 66696 additions and 1966 deletions

View file

@ -42,9 +42,10 @@ class TestToolProgressCallback:
def test_emits_tool_call_start(self, mock_conn, event_loop_fixture):
"""Tool progress should emit a ToolCallStart update."""
tool_call_ids = {}
tool_call_meta = {}
loop = event_loop_fixture
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
# Run callback in the event loop context
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
@ -66,9 +67,10 @@ class TestToolProgressCallback:
def test_handles_string_args(self, mock_conn, event_loop_fixture):
"""If args is a JSON string, it should be parsed."""
tool_call_ids = {}
tool_call_meta = {}
loop = event_loop_fixture
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
future = MagicMock(spec=Future)
@ -82,9 +84,10 @@ class TestToolProgressCallback:
def test_handles_non_dict_args(self, mock_conn, event_loop_fixture):
"""If args is not a dict, it should be wrapped."""
tool_call_ids = {}
tool_call_meta = {}
loop = event_loop_fixture
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
future = MagicMock(spec=Future)
@ -98,10 +101,11 @@ class TestToolProgressCallback:
def test_duplicate_same_name_tool_calls_use_fifo_ids(self, mock_conn, event_loop_fixture):
"""Multiple same-name tool calls should be tracked independently in order."""
tool_call_ids = {}
tool_call_meta = {}
loop = event_loop_fixture
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
future = MagicMock(spec=Future)
@ -163,7 +167,7 @@ class TestStepCallback:
tool_call_ids = {"terminal": "tc-abc123"}
loop = event_loop_fixture
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
future = MagicMock(spec=Future)
@ -181,7 +185,7 @@ class TestStepCallback:
tool_call_ids = {}
loop = event_loop_fixture
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
cb(1, [{"name": "unknown_tool", "result": "ok"}])
@ -193,7 +197,7 @@ class TestStepCallback:
tool_call_ids = {"read_file": "tc-def456"}
loop = event_loop_fixture
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
future = MagicMock(spec=Future)
@ -212,7 +216,7 @@ class TestStepCallback:
tool_call_ids = {"terminal": deque(["tc-xyz789"])}
loop = event_loop_fixture
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
patch("acp_adapter.events.build_tool_complete") as mock_btc:
@ -224,7 +228,7 @@ class TestStepCallback:
cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}])
mock_btc.assert_called_once_with(
"tc-xyz789", "terminal", result='{"output": "hello"}'
"tc-xyz789", "terminal", result='{"output": "hello"}', function_args=None, snapshot=None
)
def test_none_result_passed_through(self, mock_conn, event_loop_fixture):
@ -234,7 +238,7 @@ class TestStepCallback:
tool_call_ids = {"web_search": deque(["tc-aaa"])}
loop = event_loop_fixture
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
patch("acp_adapter.events.build_tool_complete") as mock_btc:
@ -244,7 +248,50 @@ class TestStepCallback:
cb(1, [{"name": "web_search", "result": None}])
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None)
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None, function_args=None, snapshot=None)
def test_step_callback_passes_arguments_and_snapshot(self, mock_conn, event_loop_fixture):
    """Step completion forwards the step's own arguments plus the stored snapshot.

    The metadata stored for ``tc-write`` carries different args
    (``fallback.txt``); the expectation is that the arguments supplied with
    the step result are the ones handed to ``build_tool_complete``.
    """
    from collections import deque

    pending_ids = {"write_file": deque(["tc-write"])}
    stored_meta = {
        "tc-write": {"args": {"path": "fallback.txt"}, "snapshot": "snap"},
    }
    step_cb = make_step_cb(
        mock_conn, "session-1", event_loop_fixture, pending_ids, stored_meta
    )
    finished_step = {
        "name": "write_file",
        "result": '{"bytes_written": 23}',
        "arguments": {"path": "diff-test.txt"},
    }

    with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
        with patch("acp_adapter.events.build_tool_complete") as mock_btc:
            done = MagicMock(spec=Future)
            done.result.return_value = None
            mock_rcts.return_value = done

            step_cb(1, [finished_step])

    mock_btc.assert_called_once_with(
        "tc-write",
        "write_file",
        result='{"bytes_written": 23}',
        function_args={"path": "diff-test.txt"},
        snapshot="snap",
    )
def test_tool_progress_captures_snapshot_metadata(self, mock_conn, event_loop_fixture):
    """A ``tool.started`` event records the call's args and snapshot under its id."""
    id_queues = {}
    meta_by_id = {}

    with patch("acp_adapter.events.make_tool_call_id", return_value="tc-meta"):
        with patch("acp_adapter.events._send_update") as mock_send:
            with patch("agent.display.capture_local_edit_snapshot", return_value="snapshot"):
                progress_cb = make_tool_progress_cb(
                    mock_conn, "session-1", event_loop_fixture, id_queues, meta_by_id
                )
                progress_cb(
                    "tool.started",
                    "write_file",
                    None,
                    {"path": "diff-test.txt", "content": "hello"},
                )

    # The generated id is queued under the tool name ...
    assert list(id_queues["write_file"]) == ["tc-meta"]
    # ... and the metadata is keyed by that id.
    expected_meta = {
        "args": {"path": "diff-test.txt", "content": "hello"},
        "snapshot": "snapshot",
    }
    assert meta_by_id["tc-meta"] == expected_meta
    mock_send.assert_called_once()
# ---------------------------------------------------------------------------

View file

@ -29,6 +29,7 @@ from acp.schema import (
from acp_adapter.server import HermesACPAgent
from acp_adapter.session import SessionManager
from acp_adapter.tools import build_tool_start
# ---------------------------------------------------------------------------
@ -181,6 +182,25 @@ class TestMcpRegistrationE2E:
assert complete_event.raw_output is not None
assert "hello" in str(complete_event.raw_output)
def test_patch_mode_tool_start_emits_diff_blocks_for_v4a_patch(self):
    """A V4A patch touching two files yields one diff content item per file."""
    v4a_patch = (
        "*** Begin Patch\n"
        "*** Update File: src/app.py\n"
        "@@\n"
        "-old line\n"
        "+new line\n"
        "*** Add File: src/new.py\n"
        "+hello\n"
        "*** End Patch"
    )
    tool_args = {"mode": "patch", "patch": v4a_patch}

    update = build_tool_start("tc-1", "patch", tool_args)

    assert len(update.content) == 2

    # First item: the in-place update of src/app.py.
    updated = update.content[0]
    assert updated.type == "diff"
    assert updated.path == "src/app.py"
    assert updated.old_text == "old line"
    assert updated.new_text == "new line"

    # Second item: the newly added src/new.py.
    added = update.content[1]
    assert added.type == "diff"
    assert added.path == "src/new.py"
    assert added.new_text == "hello"
@pytest.mark.asyncio
async def test_prompt_tool_results_paired_by_call_id(self, acp_agent, mock_manager):
"""The ToolCallUpdate's toolCallId must match the ToolCallStart's."""

View file

@ -20,7 +20,9 @@ from acp.schema import (
NewSessionResponse,
PromptResponse,
ResumeSessionResponse,
SessionModelState,
SetSessionConfigOptionResponse,
SetSessionModelResponse,
SetSessionModeResponse,
SessionInfo,
TextContentBlock,
@ -127,6 +129,25 @@ class TestSessionOps:
assert state is not None
assert state.cwd == "/home/user/project"
@pytest.mark.asyncio
async def test_new_session_returns_model_state(self):
    """``new_session`` reports provider-prefixed model ids with descriptions."""

    def build_agent():
        return SimpleNamespace(model="gpt-5.4", provider="openai-codex")

    session_manager = SessionManager(agent_factory=build_agent)
    acp_agent = HermesACPAgent(session_manager=session_manager)
    curated = [("gpt-5.4", "recommended"), ("gpt-5.4-mini", "")]

    with patch("hermes_cli.models.curated_models_for_provider", return_value=curated):
        resp = await acp_agent.new_session(cwd="/tmp")

    assert isinstance(resp.models, SessionModelState)
    assert resp.models.current_model_id == "openai-codex:gpt-5.4"

    first_model = resp.models.available_models[0]
    assert first_model.model_id == "openai-codex:gpt-5.4"
    assert first_model.description is not None
    assert "Provider:" in first_model.description
@pytest.mark.asyncio
async def test_available_commands_include_help(self, agent):
help_cmd = next(
@ -204,6 +225,33 @@ class TestListAndFork:
assert fork_resp.session_id
assert fork_resp.session_id != new_resp.session_id
@pytest.mark.asyncio
async def test_list_sessions_includes_title_and_updated_at(self, agent):
    """Listed sessions expose the title and a string form of ``updated_at``."""
    stored_row = {
        "session_id": "session-1",
        "cwd": "/tmp/project",
        "title": "Fix Zed session history",
        "updated_at": 123.0,
    }

    with patch.object(agent.session_manager, "list_sessions", return_value=[stored_row]):
        resp = await agent.list_sessions(cwd="/tmp/project")

    assert isinstance(resp.sessions[0], SessionInfo)
    assert resp.sessions[0].title == "Fix Zed session history"
    # The manager returned a float; the response is expected to carry a string.
    assert resp.sessions[0].updated_at == "123.0"
@pytest.mark.asyncio
async def test_list_sessions_passes_cwd_filter(self, agent):
    """The requested cwd is handed to the session manager as a keyword argument."""
    list_patch = patch.object(agent.session_manager, "list_sessions", return_value=[])
    with list_patch as mock_list:
        await agent.list_sessions(cwd="/mnt/e/Projects/AI/browser-link-3")

    mock_list.assert_called_once_with(cwd="/mnt/e/Projects/AI/browser-link-3")
# ---------------------------------------------------------------------------
# session configuration / model routing
# ---------------------------------------------------------------------------
@ -257,6 +305,53 @@ class TestSessionConfiguration:
assert result == {}
assert state.model == "gpt-5.4"
@pytest.mark.asyncio
async def test_set_session_model_accepts_provider_prefixed_choice(self, tmp_path, monkeypatch):
    """A ``provider:model`` id switches both the session's provider and model."""
    requested_providers = []

    def fake_resolve_runtime_provider(requested=None, **kwargs):
        # Record each request so the test can check which provider was resolved last.
        requested_providers.append(requested)
        provider = requested or "openrouter"
        if provider == "anthropic":
            api_mode = "anthropic_messages"
        else:
            api_mode = "chat_completions"
        return {
            "provider": provider,
            "api_mode": api_mode,
            "base_url": f"https://{provider}.example/v1",
            "api_key": f"{provider}-key",
            "command": None,
            "args": [],
        }

    def fake_agent(**kwargs):
        # Stand-in agent exposing only the attributes asserted on below.
        return SimpleNamespace(
            model=kwargs.get("model"),
            provider=kwargs.get("provider"),
            base_url=kwargs.get("base_url"),
            api_mode=kwargs.get("api_mode"),
        )

    def fake_load_config():
        return {"model": {"provider": "openrouter", "default": "openrouter/gpt-5"}}

    monkeypatch.setattr("hermes_cli.config.load_config", fake_load_config)
    monkeypatch.setattr(
        "hermes_cli.runtime_provider.resolve_runtime_provider",
        fake_resolve_runtime_provider,
    )
    session_manager = SessionManager(db=SessionDB(tmp_path / "state.db"))

    with patch("run_agent.AIAgent", side_effect=fake_agent):
        acp_agent = HermesACPAgent(session_manager=session_manager)
        state = session_manager.create_session(cwd="/tmp")
        result = await acp_agent.set_session_model(
            model_id="anthropic:claude-sonnet-4-6",
            session_id=state.session_id,
        )

    assert isinstance(result, SetSessionModelResponse)
    assert state.model == "claude-sonnet-4-6"
    assert state.agent.provider == "anthropic"
    assert state.agent.base_url == "https://anthropic.example/v1"
    assert requested_providers[-1] == "anthropic"
# ---------------------------------------------------------------------------
# prompt
@ -354,6 +449,31 @@ class TestPrompt:
update = last_call[1].get("update") or last_call[0][1]
assert update.session_update == "agent_message_chunk"
@pytest.mark.asyncio
async def test_prompt_auto_titles_session(self, agent):
    """A prompt triggers one auto-title call with the session id, prompt and reply."""
    user_text = "fix the broken ACP history"
    reply_text = "Here is the fix."

    new_resp = await agent.new_session(cwd=".")
    state = agent.session_manager.get_session(new_resp.session_id)
    conversation = {
        "final_response": reply_text,
        "messages": [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": reply_text},
        ],
    }
    state.agent.run_conversation = MagicMock(return_value=conversation)

    mock_conn = MagicMock(spec=acp.Client)
    mock_conn.session_update = AsyncMock()
    agent._conn = mock_conn

    with patch("agent.title_generator.maybe_auto_title") as mock_title:
        prompt = [TextContentBlock(type="text", text=user_text)]
        await agent.prompt(prompt=prompt, session_id=new_resp.session_id)

    mock_title.assert_called_once()
    # Positional slots 1..3 are checked; slot 0 is not inspected by this test.
    assert mock_title.call_args.args[1] == new_resp.session_id
    assert mock_title.call_args.args[2] == "fix the broken ACP history"
    assert mock_title.call_args.args[3] == "Here is the fix."
@pytest.mark.asyncio
async def test_prompt_populates_usage_from_top_level_run_conversation_fields(self, agent):
"""ACP should map top-level token fields into PromptResponse.usage."""

View file

@ -3,6 +3,7 @@
import contextlib
import io
import json
import time
from types import SimpleNamespace
import pytest
from unittest.mock import MagicMock, patch
@ -100,15 +101,23 @@ class TestListAndCleanup:
def test_list_sessions_returns_created(self, manager):
    """Both sessions that received a user message appear in the listing."""
    first = manager.create_session(cwd="/a")
    second = manager.create_session(cwd="/b")
    first.history.append({"role": "user", "content": "hello from a"})
    second.history.append({"role": "user", "content": "hello from b"})

    listing = manager.list_sessions()

    listed_ids = {entry["session_id"] for entry in listing}
    assert first.session_id in listed_ids
    assert second.session_id in listed_ids
    assert len(listing) == 2
def test_list_sessions_hides_empty_threads(self, manager):
    """A session with no history is not included in the listing."""
    manager.create_session(cwd="/empty")

    listing = manager.list_sessions()

    assert listing == []
def test_cleanup_clears_all(self, manager):
    """``cleanup`` removes every session, leaving the listing empty.

    Each session gets one user message before listing, because the sibling
    test ``test_list_sessions_hides_empty_threads`` expects sessions without
    history to be hidden. The two bare ``manager.create_session()`` calls
    left over from the pre-merge version are dropped: they created extra
    sessions that nothing referenced.
    """
    s1 = manager.create_session()
    s2 = manager.create_session()
    s1.history.append({"role": "user", "content": "one"})
    s2.history.append({"role": "user", "content": "two"})
    assert len(manager.list_sessions()) == 2

    manager.cleanup()

    assert manager.list_sessions() == []
@ -194,6 +203,8 @@ class TestPersistence:
def test_list_sessions_includes_db_only(self, manager):
"""Sessions only in DB (not in memory) appear in list_sessions."""
state = manager.create_session(cwd="/db-only")
state.history.append({"role": "user", "content": "database only thread"})
manager.save_session(state.session_id)
sid = state.session_id
# Drop from memory.
@ -204,6 +215,53 @@ class TestPersistence:
ids = {s["session_id"] for s in listing}
assert sid in ids
def test_list_sessions_filters_by_cwd(self, manager):
    """Only sessions created under the requested cwd are listed."""
    wanted = manager.create_session(cwd="/keep")
    unwanted = manager.create_session(cwd="/drop")
    wanted.history.append({"role": "user", "content": "keep me"})
    unwanted.history.append({"role": "user", "content": "drop me"})

    listed_ids = {
        entry["session_id"] for entry in manager.list_sessions(cwd="/keep")
    }

    assert wanted.session_id in listed_ids
    assert unwanted.session_id not in listed_ids
def test_list_sessions_matches_windows_and_wsl_paths(self, manager):
    """A session created under a ``/mnt/e/...`` cwd is found via the ``E:\\...`` form."""
    wsl_session = manager.create_session(cwd="/mnt/e/Projects/AI/browser-link-3")
    wsl_session.history.append({"role": "user", "content": "same project from WSL"})

    windows_listing = manager.list_sessions(cwd=r"E:\Projects\AI\browser-link-3")

    listed_ids = {entry["session_id"] for entry in windows_listing}
    assert wsl_session.session_id in listed_ids
def test_list_sessions_prefers_title_then_preview(self, manager):
    """The stored title wins; with an empty title the first user message is used."""
    session = manager.create_session(cwd="/named")
    session.history.append(
        {"role": "user", "content": "Investigate broken ACP history in Zed"}
    )
    manager.save_session(session.session_id)
    session_db = manager._get_db()

    # With an explicit title stored, the listing reports it verbatim.
    session_db.set_session_title(session.session_id, "Fix Zed ACP history")
    titled = manager.list_sessions(cwd="/named")
    assert titled[0]["title"] == "Fix Zed ACP history"

    # Once the title is blanked, the listing falls back to the message text.
    session_db.set_session_title(session.session_id, "")
    untitled = manager.list_sessions(cwd="/named")
    assert untitled[0]["title"].startswith("Investigate broken ACP history")
def test_list_sessions_sorted_by_most_recent_activity(self, manager):
    """The session saved last is listed first, and both entries carry ``updated_at``."""
    earlier = manager.create_session(cwd="/ordered")
    earlier.history.append({"role": "user", "content": "older"})
    manager.save_session(earlier.session_id)

    # Short pause between the two saves so they do not happen back to back.
    time.sleep(0.02)

    later = manager.create_session(cwd="/ordered")
    later.history.append({"role": "user", "content": "newer"})
    manager.save_session(later.session_id)

    listing = manager.list_sessions(cwd="/ordered")

    top_two_ids = [entry["session_id"] for entry in listing[:2]]
    assert top_two_ids == [later.session_id, earlier.session_id]
    assert listing[0]["updated_at"]
    assert listing[1]["updated_at"]
def test_fork_restores_source_from_db(self, manager):
"""Forking a session that is only in DB should work."""
original = manager.create_session()

View file

@ -215,6 +215,46 @@ class TestBuildToolComplete:
assert len(display_text) < 6000
assert "truncated" in display_text
def test_build_tool_complete_for_patch_uses_diff_blocks(self):
    """A patch result containing a unified diff is reported as one file-edit item."""
    patch_result = (
        '{"success": true, "diff": "--- a/README.md\\n+++ b/README.md\\n@@ -1 +1,2 @@\\n old line\\n+new line\\n", '
        '"files_modified": ["README.md"]}'
    )

    completed = build_tool_complete("tc-p1", "patch", patch_result)

    assert isinstance(completed, ToolCallProgress)
    assert len(completed.content) == 1

    edit = completed.content[0]
    assert isinstance(edit, FileEditToolCallContent)
    assert edit.path == "README.md"
    # The context line appears on both sides; the added line only in new_text.
    assert edit.old_text == "old line"
    assert edit.new_text == "old line\nnew line"
def test_build_tool_complete_for_patch_falls_back_to_text_when_no_diff(self):
    """Without a ``diff`` key in the result, the first content item is plain content."""
    completed = build_tool_complete("tc-p2", "patch", '{"success": true}')

    assert isinstance(completed, ToolCallProgress)
    first_item = completed.content[0]
    assert isinstance(first_item, ContentToolCallContent)
def test_build_tool_complete_for_write_file_uses_snapshot_diff(self, tmp_path):
    """write_file completion builds a file-edit item from the snapshot and the file.

    The snapshot's ``before`` entry for the target path is ``None`` and the
    file is written afterwards; the resulting item is expected to have no
    old text and the written content (without the trailing newline) as new text.
    """
    target = tmp_path / "diff-test.txt"
    snapshot_attrs = {"paths": [target], "before": {str(target): None}}
    snapshot = type("Snapshot", (), snapshot_attrs)()
    target.write_text("hello from hermes\n", encoding="utf-8")

    completed = build_tool_complete(
        "tc-wf1",
        "write_file",
        '{"bytes_written": 18, "dirs_created": false}',
        function_args={"path": str(target), "content": "hello from hermes\n"},
        snapshot=snapshot,
    )

    assert isinstance(completed, ToolCallProgress)
    assert len(completed.content) == 1

    edit = completed.content[0]
    assert isinstance(edit, FileEditToolCallContent)
    assert edit.path.endswith("diff-test.txt")
    assert edit.old_text is None
    assert edit.new_text == "hello from hermes"
# ---------------------------------------------------------------------------
# extract_locations

View file

@ -696,6 +696,95 @@ class TestIsConnectionError:
assert _is_connection_error(err) is False
class TestKimiForCodingTemperature:
    """Calls to ``kimi-for-coding`` are expected to always send temperature 0.6."""

    def test_build_call_kwargs_forces_fixed_temperature(self):
        """A caller-supplied temperature of 0.3 is replaced by 0.6."""
        from agent.auxiliary_client import _build_call_kwargs

        call_kwargs = _build_call_kwargs(
            provider="kimi-coding",
            model="kimi-for-coding",
            messages=[{"role": "user", "content": "hello"}],
            temperature=0.3,
        )

        assert call_kwargs["temperature"] == 0.6

    def test_build_call_kwargs_injects_temperature_when_missing(self):
        """Passing ``temperature=None`` still results in 0.6 being set."""
        from agent.auxiliary_client import _build_call_kwargs

        call_kwargs = _build_call_kwargs(
            provider="kimi-coding",
            model="kimi-for-coding",
            messages=[{"role": "user", "content": "hello"}],
            temperature=None,
        )

        assert call_kwargs["temperature"] == 0.6

    def test_auto_routed_kimi_for_coding_sync_call_uses_fixed_temperature(self):
        """``call_llm`` resolved via ``auto`` sends 0.6 even when asked for 0.1."""
        llm_client = MagicMock()
        llm_client.base_url = "https://api.kimi.com/coding/v1"
        canned = MagicMock()
        llm_client.chat.completions.create.return_value = canned

        with patch(
            "agent.auxiliary_client._get_cached_client",
            return_value=(llm_client, "kimi-for-coding"),
        ):
            with patch(
                "agent.auxiliary_client._resolve_task_provider_model",
                return_value=("auto", "kimi-for-coding", None, None, None),
            ):
                outcome = call_llm(
                    task="session_search",
                    messages=[{"role": "user", "content": "hello"}],
                    temperature=0.1,
                )

        assert outcome is canned
        sent = llm_client.chat.completions.create.call_args.kwargs
        assert sent["model"] == "kimi-for-coding"
        assert sent["temperature"] == 0.6

    @pytest.mark.asyncio
    async def test_auto_routed_kimi_for_coding_async_call_uses_fixed_temperature(self):
        """``async_call_llm`` resolved via ``auto`` sends 0.6 even when asked for 0.1."""
        llm_client = MagicMock()
        llm_client.base_url = "https://api.kimi.com/coding/v1"
        canned = MagicMock()
        llm_client.chat.completions.create = AsyncMock(return_value=canned)

        with patch(
            "agent.auxiliary_client._get_cached_client",
            return_value=(llm_client, "kimi-for-coding"),
        ):
            with patch(
                "agent.auxiliary_client._resolve_task_provider_model",
                return_value=("auto", "kimi-for-coding", None, None, None),
            ):
                outcome = await async_call_llm(
                    task="session_search",
                    messages=[{"role": "user", "content": "hello"}],
                    temperature=0.1,
                )

        assert outcome is canned
        sent = llm_client.chat.completions.create.call_args.kwargs
        assert sent["model"] == "kimi-for-coding"
        assert sent["temperature"] == 0.6

    def test_non_kimi_model_still_preserves_temperature(self):
        """A different model on the same provider keeps the requested 0.3."""
        from agent.auxiliary_client import _build_call_kwargs

        call_kwargs = _build_call_kwargs(
            provider="kimi-coding",
            model="kimi-k2.5",
            messages=[{"role": "user", "content": "hello"}],
            temperature=0.3,
        )

        assert call_kwargs["temperature"] == 0.3
# ---------------------------------------------------------------------------
# async_call_llm payment / connection fallback (#7512 bug 2)
# ---------------------------------------------------------------------------

View file

@ -0,0 +1,311 @@
"""Regression tests for the ``auto`` → main-model-first policy.
Prior to this change, aggregator users (OpenRouter / Nous Portal) had aux
tasks routed through a cheap provider-side default (Gemini Flash) while
non-aggregator users got their main model. This made behavior inconsistent
and surprising: users picked Claude but got Gemini Flash summaries.
The current policy: ``auto`` means "use my main chat model" for every user,
regardless of provider type. Explicit per-task overrides in ``config.yaml``
(``auxiliary.<task>.provider``) still win. The cheap fallback chain only
runs when the main provider has no working client.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
# ── Text aux tasks — _resolve_auto ──────────────────────────────────────────
class TestResolveAutoMainFirst:
    """_resolve_auto() must prefer main provider + main model for every user.

    Each test patches ``_read_main_provider`` / ``_read_main_model`` to
    simulate the user's configuration and patches ``resolve_provider_client``
    (or ``_try_openrouter``) so no real client is built.
    """

    def test_openrouter_main_uses_main_model_for_aux(self, monkeypatch):
        """OpenRouter main user → aux uses their picked OR model, not Gemini Flash."""
        monkeypatch.setenv("OPENROUTER_API_KEY", "or-test-key")
        with patch(
            "agent.auxiliary_client._read_main_provider",
            return_value="openrouter",
        ), patch(
            "agent.auxiliary_client._read_main_model",
            return_value="anthropic/claude-sonnet-4.6",
        ), patch(
            "agent.auxiliary_client.resolve_provider_client"
        ) as mock_resolve:
            mock_client = MagicMock()
            mock_resolve.return_value = (mock_client, "anthropic/claude-sonnet-4.6")

            from agent.auxiliary_client import _resolve_auto

            client, model = _resolve_auto()

        assert client is mock_client
        assert model == "anthropic/claude-sonnet-4.6"
        # Verify it asked resolve_provider_client for the MAIN provider+model,
        # not a fallback-chain provider
        mock_resolve.assert_called_once()
        assert mock_resolve.call_args.args[0] == "openrouter"
        assert mock_resolve.call_args.args[1] == "anthropic/claude-sonnet-4.6"

    def test_nous_main_uses_main_model_for_aux(self, monkeypatch):
        """Nous Portal main user → aux uses their picked Nous model, not free-tier MiMo."""
        # No OPENROUTER_API_KEY → ensures if main failed we'd fall to chain.
        # Unset it explicitly so a key in the developer's environment cannot
        # change which path this test exercises.
        monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
        with patch(
            "agent.auxiliary_client._read_main_provider", return_value="nous",
        ), patch(
            "agent.auxiliary_client._read_main_model",
            return_value="anthropic/claude-opus-4.6",
        ), patch(
            "agent.auxiliary_client.resolve_provider_client"
        ) as mock_resolve:
            mock_client = MagicMock()
            mock_resolve.return_value = (mock_client, "anthropic/claude-opus-4.6")

            from agent.auxiliary_client import _resolve_auto

            client, model = _resolve_auto()

        assert client is mock_client
        assert model == "anthropic/claude-opus-4.6"
        assert mock_resolve.call_args.args[0] == "nous"

    def test_non_aggregator_main_still_uses_main(self, monkeypatch):
        """Non-aggregator main (DeepSeek) → unchanged behavior, main model used."""
        monkeypatch.setenv("DEEPSEEK_API_KEY", "ds-test")
        with patch(
            "agent.auxiliary_client._read_main_provider", return_value="deepseek",
        ), patch(
            "agent.auxiliary_client._read_main_model", return_value="deepseek-chat",
        ), patch(
            "agent.auxiliary_client.resolve_provider_client"
        ) as mock_resolve:
            mock_client = MagicMock()
            mock_resolve.return_value = (mock_client, "deepseek-chat")

            from agent.auxiliary_client import _resolve_auto

            client, model = _resolve_auto()

        assert client is mock_client
        assert model == "deepseek-chat"
        assert mock_resolve.call_args.args[0] == "deepseek"

    def test_main_unavailable_falls_through_to_chain(self, monkeypatch):
        """Main provider with no working client → fall back to aux chain."""
        monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
        chain_client = MagicMock()
        with patch(
            "agent.auxiliary_client._read_main_provider", return_value="anthropic",
        ), patch(
            "agent.auxiliary_client._read_main_model", return_value="claude-opus",
        ), patch(
            "agent.auxiliary_client.resolve_provider_client",
            return_value=(None, None),  # main provider has no client
        ), patch(
            "agent.auxiliary_client._try_openrouter",
            return_value=(chain_client, "google/gemini-3-flash-preview"),
        ):
            from agent.auxiliary_client import _resolve_auto

            client, model = _resolve_auto()

        assert client is chain_client
        assert model == "google/gemini-3-flash-preview"

    def test_no_main_config_uses_chain_directly(self):
        """No main provider configured → skip step 1, use chain (no regression)."""
        chain_client = MagicMock()
        with patch(
            "agent.auxiliary_client._read_main_provider", return_value="",
        ), patch(
            "agent.auxiliary_client._read_main_model", return_value="",
        ), patch(
            "agent.auxiliary_client._try_openrouter",
            return_value=(chain_client, "google/gemini-3-flash-preview"),
        ):
            from agent.auxiliary_client import _resolve_auto

            client, model = _resolve_auto()

        assert client is chain_client
        # Also pin the model, matching the fall-through test above.
        assert model == "google/gemini-3-flash-preview"

    def test_runtime_override_wins_over_config(self, monkeypatch):
        """main_runtime kwarg overrides config-read main provider/model."""
        with patch(
            "agent.auxiliary_client._read_main_provider",
            return_value="openrouter",
        ), patch(
            "agent.auxiliary_client._read_main_model", return_value="config-model",
        ), patch(
            "agent.auxiliary_client.resolve_provider_client"
        ) as mock_resolve:
            mock_resolve.return_value = (MagicMock(), "runtime-model")

            from agent.auxiliary_client import _resolve_auto

            _resolve_auto(main_runtime={
                "provider": "anthropic",
                "model": "runtime-model",
                "base_url": "",
                "api_key": "",
                "api_mode": "",
            })

        # Runtime override wins
        assert mock_resolve.call_args.args[0] == "anthropic"
        assert mock_resolve.call_args.args[1] == "runtime-model"
# ── Vision — resolve_vision_provider_client ─────────────────────────────────
class TestResolveVisionMainFirst:
    """Vision auto-detection prefers main provider + main model first.

    ``_resolve_task_provider_model`` is patched to report ``"auto"`` (or an
    explicit provider) so each test controls which resolution path runs.
    """

    def test_openrouter_main_vision_uses_main_model(self, monkeypatch):
        """OpenRouter main with vision-capable model → aux vision uses main model."""
        monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
        with patch(
            "agent.auxiliary_client._read_main_provider", return_value="openrouter",
        ), patch(
            "agent.auxiliary_client._read_main_model",
            return_value="anthropic/claude-sonnet-4.6",
        ), patch(
            "agent.auxiliary_client.resolve_provider_client"
        ) as mock_resolve, patch(
            "agent.auxiliary_client._resolve_task_provider_model",
            return_value=("auto", None, None, None, None),
        ):
            mock_client = MagicMock()
            mock_resolve.return_value = (mock_client, "anthropic/claude-sonnet-4.6")

            from agent.auxiliary_client import resolve_vision_provider_client

            provider, client, model = resolve_vision_provider_client()

        assert provider == "openrouter"
        assert client is mock_client
        assert model == "anthropic/claude-sonnet-4.6"
        # Verify it did NOT call the strict vision backend for OpenRouter
        # (which would have used a cheap gemini-flash-preview default)
        mock_resolve.assert_called_once()
        assert mock_resolve.call_args.args[0] == "openrouter"
        assert mock_resolve.call_args.args[1] == "anthropic/claude-sonnet-4.6"

    def test_nous_main_vision_uses_main_model(self):
        """Nous Portal main → aux vision uses main model, not free-tier MiMo-V2-Omni."""
        with patch(
            "agent.auxiliary_client._read_main_provider", return_value="nous",
        ), patch(
            "agent.auxiliary_client._read_main_model",
            return_value="openai/gpt-5",
        ), patch(
            "agent.auxiliary_client.resolve_provider_client"
        ) as mock_resolve, patch(
            "agent.auxiliary_client._resolve_task_provider_model",
            return_value=("auto", None, None, None, None),
        ):
            mock_client = MagicMock()
            mock_resolve.return_value = (mock_client, "openai/gpt-5")

            from agent.auxiliary_client import resolve_vision_provider_client

            provider, client, model = resolve_vision_provider_client()

        assert provider == "nous"
        # Check the client identity too, as the OpenRouter test above does.
        assert client is mock_client
        assert model == "openai/gpt-5"

    def test_exotic_provider_with_vision_override_preserved(self):
        """xiaomi → mimo-v2-omni override still wins over main_model."""
        with patch(
            "agent.auxiliary_client._read_main_provider", return_value="xiaomi",
        ), patch(
            "agent.auxiliary_client._read_main_model",
            return_value="mimo-v2-pro",  # text model
        ), patch(
            "agent.auxiliary_client.resolve_provider_client"
        ) as mock_resolve, patch(
            "agent.auxiliary_client._resolve_task_provider_model",
            return_value=("auto", None, None, None, None),
        ):
            mock_resolve.return_value = (MagicMock(), "mimo-v2-omni")

            from agent.auxiliary_client import resolve_vision_provider_client

            # Only the provider and the requested model are checked here.
            provider, _client, _model = resolve_vision_provider_client()

        assert provider == "xiaomi"
        # Should use mimo-v2-omni (vision override), not mimo-v2-pro (text main)
        assert mock_resolve.call_args.args[1] == "mimo-v2-omni"

    def test_main_unavailable_vision_falls_through_to_aggregators(self):
        """Main provider fails → fall back to OpenRouter/Nous strict backends."""
        fallback_client = MagicMock()
        with patch(
            "agent.auxiliary_client._read_main_provider", return_value="deepseek",
        ), patch(
            "agent.auxiliary_client._read_main_model", return_value="deepseek-chat",
        ), patch(
            "agent.auxiliary_client.resolve_provider_client",
            return_value=(None, None),
        ), patch(
            "agent.auxiliary_client._resolve_strict_vision_backend",
            return_value=(fallback_client, "google/gemini-3-flash-preview"),
        ), patch(
            "agent.auxiliary_client._resolve_task_provider_model",
            return_value=("auto", None, None, None, None),
        ):
            from agent.auxiliary_client import resolve_vision_provider_client

            provider, client, _model = resolve_vision_provider_client()

        assert client is fallback_client
        assert provider in ("openrouter", "nous")

    def test_explicit_provider_override_still_wins(self):
        """Explicit config override bypasses main-first policy."""
        with patch(
            "agent.auxiliary_client._read_main_provider", return_value="openrouter",
        ), patch(
            "agent.auxiliary_client._read_main_model",
            return_value="anthropic/claude-opus-4.6",
        ), patch(
            "agent.auxiliary_client._resolve_task_provider_model",
            return_value=("nous", None, None, None, None),  # explicit override
        ), patch(
            "agent.auxiliary_client._resolve_strict_vision_backend"
        ) as mock_strict:
            mock_strict.return_value = (MagicMock(), "nous-default-model")

            from agent.auxiliary_client import resolve_vision_provider_client

            provider, _client, _model = resolve_vision_provider_client()

        # Explicit "nous" override → uses strict backend, NOT main model path
        assert provider == "nous"
        mock_strict.assert_called_once_with("nous")
# ── Constant cleanup ────────────────────────────────────────────────────────
def test_aggregator_providers_constant_removed():
    """Guard against ``_AGGREGATOR_PROVIDERS`` reappearing in the module.

    The constant was deleted once the main-first policy made the
    aggregator-skip guard unnecessary; its presence would hint at a regression.
    """
    import agent.auxiliary_client as aux_mod

    failure_hint = (
        "_AGGREGATOR_PROVIDERS was removed when _resolve_auto stopped "
        "treating aggregators specially. If you re-added it, the main-first "
        "policy may have regressed."
    )
    constant_present = hasattr(aux_mod, "_AGGREGATOR_PROVIDERS")
    assert not constant_present, failure_hint

View file

@ -826,6 +826,160 @@ class TestGeminiCloudCodeClient:
finally:
client.close()
class TestGeminiHttpErrorParsing:
    """Checks how ``_gemini_http_error`` turns Google error envelopes into errors.

    Each test feeds a duck-typed response (status code, text body, headers)
    and inspects the resulting error's ``status_code``, ``code``,
    ``retry_after``, ``details``, ``response`` and rendered message.
    """

    @staticmethod
    def _fake_response(status: int, body: dict | str = "", headers=None):
        """Return an object exposing ``status_code``, ``text`` and ``headers``.

        A dict ``body`` is serialised with ``json.dumps``; a string is used as-is.
        """
        body_text = json.dumps(body) if isinstance(body, dict) else body

        class _FakeResponse:
            def __init__(self, status_code, text, response_headers):
                self.status_code = status_code
                self.text = text
                self.headers = response_headers

        return _FakeResponse(status, body_text, headers or {})

    def test_model_capacity_exhausted_produces_friendly_message(self):
        """A MODEL_CAPACITY_EXHAUSTED 429 yields a readable, structured error."""
        from agent.gemini_cloudcode_adapter import _gemini_http_error

        error_info = {
            "@type": "type.googleapis.com/google.rpc.ErrorInfo",
            "reason": "MODEL_CAPACITY_EXHAUSTED",
            "domain": "googleapis.com",
            "metadata": {"model": "gemini-2.5-pro"},
        }
        retry_info = {
            "@type": "type.googleapis.com/google.rpc.RetryInfo",
            "retryDelay": "30s",
        }
        envelope = {
            "error": {
                "code": 429,
                "message": "Resource has been exhausted (e.g. check quota).",
                "status": "RESOURCE_EXHAUSTED",
                "details": [error_info, retry_info],
            }
        }

        err = _gemini_http_error(self._fake_response(429, envelope))

        assert err.status_code == 429
        assert err.code == "code_assist_capacity_exhausted"
        assert err.retry_after == 30.0
        assert err.details["reason"] == "MODEL_CAPACITY_EXHAUSTED"

        # The rendered text should name the model, the condition and the delay.
        rendered = str(err)
        assert "gemini-2.5-pro" in rendered
        assert "capacity exhausted" in rendered.lower()
        assert "30s" in rendered

        # The originating response object must remain attached to the error.
        assert err.response is not None

    def test_resource_exhausted_without_reason(self):
        """A 429 lacking ErrorInfo details is reported as a plain rate limit."""
        from agent.gemini_cloudcode_adapter import _gemini_http_error

        envelope = {
            "error": {
                "code": 429,
                "message": "Quota exceeded for requests per minute.",
                "status": "RESOURCE_EXHAUSTED",
            }
        }

        err = _gemini_http_error(self._fake_response(429, envelope))

        assert err.status_code == 429
        assert err.code == "code_assist_rate_limited"
        assert "quota" in str(err).lower()

    def test_404_model_not_found_produces_model_retired_message(self):
        """A NOT_FOUND response mentions unavailability and echoes the model name."""
        from agent.gemini_cloudcode_adapter import _gemini_http_error

        envelope = {
            "error": {
                "code": 404,
                "message": "models/gemma-4-26b-it is not found for API version v1internal",
                "status": "NOT_FOUND",
            }
        }

        err = _gemini_http_error(self._fake_response(404, envelope))

        assert err.status_code == 404
        rendered = str(err)
        lowered = rendered.lower()
        assert "not available" in lowered or "retired" in lowered
        # The model identifier from Google's own message must be carried over.
        assert "gemma-4-26b-it" in rendered

    def test_unauthorized_preserves_status_code(self):
        """A 401 envelope keeps its status code and maps to the unauthorized code."""
        from agent.gemini_cloudcode_adapter import _gemini_http_error

        envelope = {
            "error": {"code": 401, "message": "Invalid token", "status": "UNAUTHENTICATED"},
        }
        response = self._fake_response(401, envelope)

        err = _gemini_http_error(response)

        assert err.status_code == 401
        assert err.code == "code_assist_unauthorized"

    def test_retry_after_header_fallback(self):
        """Without a RetryInfo detail, ``retry_after`` comes from the Retry-After header."""
        from agent.gemini_cloudcode_adapter import _gemini_http_error

        envelope = {
            "error": {"code": 429, "message": "Rate limited", "status": "RESOURCE_EXHAUSTED"},
        }
        response = self._fake_response(429, envelope, headers={"Retry-After": "45"})

        err = _gemini_http_error(response)

        assert err.retry_after == 45.0

    def test_malformed_body_still_produces_structured_error(self):
        """A non-JSON body still produces an error carrying the HTTP status."""
        from agent.gemini_cloudcode_adapter import _gemini_http_error

        response = self._fake_response(500, "<html>internal error</html>")

        err = _gemini_http_error(response)

        assert err.status_code == 500
        # The status code must also be visible in the rendered message.
        assert "500" in str(err)

    def test_status_code_flows_through_error_classifier(self):
        """End-to-end: an error built from a 429 must classify as ``rate_limit``.

        The error produced by ``_gemini_http_error`` is passed to
        ``classify_api_error``; the classification must report status 429
        and ``FailoverReason.rate_limit``.
        """
        from agent.gemini_cloudcode_adapter import _gemini_http_error
        from agent.error_classifier import classify_api_error, FailoverReason

        error_info = {
            "@type": "type.googleapis.com/google.rpc.ErrorInfo",
            "reason": "MODEL_CAPACITY_EXHAUSTED",
            "metadata": {"model": "gemini-2.5-pro"},
        }
        envelope = {
            "error": {
                "code": 429,
                "message": "Resource has been exhausted",
                "status": "RESOURCE_EXHAUSTED",
                "details": [error_info],
            }
        }

        err = _gemini_http_error(self._fake_response(429, envelope))
        classified = classify_api_error(
            err, provider="google-gemini-cli", model="gemini-2.5-pro",
        )

        assert classified.status_code == 429
        assert classified.reason == FailoverReason.rate_limit
# =============================================================================
# Provider registration
# =============================================================================

View file

@ -0,0 +1,71 @@
"""Tests for CLI /copy command."""
from unittest.mock import MagicMock, patch
from cli import HermesCLI
def _make_cli() -> HermesCLI:
    """Create a ``HermesCLI`` without running ``__init__`` and seed the attributes these tests set."""
    cli_obj = HermesCLI.__new__(HermesCLI)
    seeded_attributes = {
        "config": {},
        "console": MagicMock(),
        "agent": None,
        "conversation_history": [],
        "session_id": "sess-copy-test",
        "_pending_input": MagicMock(),
        "_app": None,
    }
    for attr_name, attr_value in seeded_attributes.items():
        setattr(cli_obj, attr_name, attr_value)
    return cli_obj
def test_copy_copies_latest_assistant_message():
    """Bare ``/copy`` sends the last assistant message to the clipboard writer."""
    history = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "first"},
        {"role": "assistant", "content": "latest"},
    ]
    cli_obj = _make_cli()
    cli_obj.conversation_history = history

    with patch.object(cli_obj, "_write_osc52_clipboard") as clipboard_writer:
        handled = cli_obj.process_command("/copy")

    assert handled is True
    clipboard_writer.assert_called_once_with("latest")
def test_copy_with_index_uses_requested_assistant_message():
    """``/copy 1`` copies the first assistant message rather than the latest."""
    cli_obj = _make_cli()
    cli_obj.conversation_history = [
        {"role": "assistant", "content": text} for text in ("one", "two")
    ]

    with patch.object(cli_obj, "_write_osc52_clipboard") as clipboard_writer:
        cli_obj.process_command("/copy 1")

    clipboard_writer.assert_called_once_with("one")
def test_copy_strips_reasoning_blocks_before_copy():
    """Only the text outside the REASONING_SCRATCHPAD block reaches the clipboard."""
    assistant_turn = {
        "role": "assistant",
        "content": "<REASONING_SCRATCHPAD>internal</REASONING_SCRATCHPAD>\nVisible answer",
    }
    cli_obj = _make_cli()
    cli_obj.conversation_history = [assistant_turn]

    with patch.object(cli_obj, "_write_osc52_clipboard") as clipboard_writer:
        cli_obj.process_command("/copy")

    clipboard_writer.assert_called_once_with("Visible answer")
def test_copy_invalid_index_does_not_copy():
    """An out-of-range index prints an error and never touches the clipboard."""
    cli_obj = _make_cli()
    cli_obj.conversation_history = [{"role": "assistant", "content": "only"}]

    with patch.object(cli_obj, "_write_osc52_clipboard") as clipboard_writer:
        with patch("cli._cprint") as printer:
            cli_obj.process_command("/copy 99")

    clipboard_writer.assert_not_called()
    printed_calls = [str(call) for call in printer.call_args_list]
    assert any("Invalid response number" in rendered for rendered in printed_calls)

View file

@ -64,6 +64,24 @@ class TestSaveConfigValueAtomic:
result = yaml.safe_load(config_env.read_text())
assert result["display"]["skin"] == "ares"
def test_preserves_env_ref_templates_in_unrelated_fields(self, config_env):
    """Saving ``model.default`` must leave a ``${VAR}`` api_key template untouched."""
    initial_config = {
        "custom_providers": [{
            "name": "tuzi",
            "api_key": "${TU_ZI_API_KEY}",
            "model": "claude-opus-4-6",
        }],
        "model": {"default": "test-model", "provider": "openrouter"},
    }
    config_env.write_text(yaml.dump(initial_config))

    from cli import save_config_value

    save_config_value("model.default", "doubao-pro")

    saved = yaml.safe_load(config_env.read_text())
    # The targeted key changed ...
    assert saved["model"]["default"] == "doubao-pro"
    # ... while the env reference is still the literal template text.
    assert saved["custom_providers"][0]["api_key"] == "${TU_ZI_API_KEY}"
def test_file_not_truncated_on_error(self, config_env, monkeypatch):
"""If atomic_yaml_write raises, the original file is untouched."""
original_content = config_env.read_text()

View file

@ -2,7 +2,8 @@
Surrogates (U+D800..U+DFFF) are invalid in UTF-8 and crash json.dumps()
inside the OpenAI SDK. They can appear via clipboard paste from rich-text
editors like Google Docs.
editors like Google Docs, OR from byte-level reasoning models (xiaomi/mimo,
kimi, glm) emitting lone halves in reasoning output.
"""
import json
import pytest
@ -11,6 +12,7 @@ from unittest.mock import MagicMock, patch
from run_agent import (
_sanitize_surrogates,
_sanitize_messages_surrogates,
_sanitize_structure_surrogates,
_SURROGATE_RE,
)
@ -109,6 +111,186 @@ class TestSanitizeMessagesSurrogates:
assert "\ufffd" in msgs[0]["content"]
class TestReasoningFieldSurrogates:
    """Surrogates in reasoning fields (byte-level reasoning models).

    xiaomi/mimo, kimi, glm and similar byte-level tokenizers can emit lone
    surrogates in reasoning output. These fields are carried through to the
    API as `reasoning_content` on assistant messages, and must be sanitized
    or json.dumps() crashes with 'utf-8' codec can't encode surrogates.
    """

    def test_reasoning_field_sanitized(self):
        """A lone surrogate in ``reasoning`` is replaced with U+FFFD."""
        msgs = [
            {"role": "assistant", "content": "ok", "reasoning": "thought \udce2 here"},
        ]
        assert _sanitize_messages_surrogates(msgs) is True
        assert "\udce2" not in msgs[0]["reasoning"]
        assert "\ufffd" in msgs[0]["reasoning"]

    def test_reasoning_content_field_sanitized(self):
        """api_messages carry `reasoning_content` built from `reasoning`."""
        msgs = [
            {"role": "assistant", "content": "ok", "reasoning_content": "thought \udce2 here"},
        ]
        assert _sanitize_messages_surrogates(msgs) is True
        assert "\udce2" not in msgs[0]["reasoning_content"]
        assert "\ufffd" in msgs[0]["reasoning_content"]

    def test_reasoning_details_nested_sanitized(self):
        """reasoning_details is a list of dicts with nested string fields."""
        msgs = [
            {
                "role": "assistant",
                "content": "ok",
                "reasoning_details": [
                    {"type": "reasoning.summary", "summary": "summary \udce2 text"},
                    {"type": "reasoning.text", "text": "chain \udc00 of thought"},
                ],
            },
        ]
        assert _sanitize_messages_surrogates(msgs) is True
        details = msgs[0]["reasoning_details"]
        assert "\udce2" not in details[0]["summary"]
        assert "\ufffd" in details[0]["summary"]
        assert "\udc00" not in details[1]["text"]
        assert "\ufffd" in details[1]["text"]

    def test_deeply_nested_reasoning_sanitized(self):
        """Nested dicts / lists inside extra fields are recursed into."""
        msgs = [
            {
                "role": "assistant",
                "content": "ok",
                "reasoning_details": [
                    {
                        "type": "reasoning.encrypted",
                        "content": {
                            "encrypted_content": "opaque",
                            "text_parts": ["part1", "part2 \udce2 part"],
                        },
                    },
                ],
            },
        ]
        assert _sanitize_messages_surrogates(msgs) is True
        assert (
            msgs[0]["reasoning_details"][0]["content"]["text_parts"][1]
            == "part2 \ufffd part"
        )

    def test_reasoning_end_to_end_json_serialization(self):
        """After sanitization, the full message dict must serialize clean."""
        msgs = [
            {
                "role": "assistant",
                "content": "answer",
                "reasoning_content": "reasoning with \udce2 surrogate",
                "reasoning_details": [
                    {"summary": "nested \udcb0 surrogate"},
                ],
            },
        ]
        _sanitize_messages_surrogates(msgs)
        # encode() raises UnicodeEncodeError if any lone surrogate survived,
        # so reaching the next line already proves the payload is encodable.
        payload = json.dumps(msgs, ensure_ascii=False).encode("utf-8")
        # The previous check here sliced payload[:0] (always empty) and so
        # could never fail. Verify a real round trip instead: the bytes must
        # decode back to exactly the sanitized structure.
        assert json.loads(payload.decode("utf-8")) == msgs
        # Both the flat and the nested field carry the replacement character.
        assert "\ufffd" in msgs[0]["reasoning_content"]
        assert "\ufffd" in msgs[0]["reasoning_details"][0]["summary"]

    def test_no_surrogates_returns_false(self):
        """Clean reasoning fields don't trigger a modification."""
        msgs = [
            {
                "role": "assistant",
                "content": "ok",
                "reasoning": "clean thought",
                "reasoning_content": "also clean",
                "reasoning_details": [{"summary": "clean summary"}],
            },
        ]
        assert _sanitize_messages_surrogates(msgs) is False
class TestSanitizeStructureSurrogates:
    """Test the _sanitize_structure_surrogates() helper for nested payloads.

    The helper is expected to mutate the payload in place and return True
    only when at least one string was changed.
    """

    def test_empty_payload(self):
        """Empty containers report no modification."""
        assert _sanitize_structure_surrogates({}) is False
        assert _sanitize_structure_surrogates([]) is False

    def test_flat_dict(self):
        """Only the dirty value of a flat dict is rewritten."""
        payload = {"a": "clean", "b": "dirty \udce2 text"}
        assert _sanitize_structure_surrogates(payload) is True
        assert payload["a"] == "clean"
        assert "\ufffd" in payload["b"]

    def test_flat_list(self):
        """Only the dirty element of a flat list is rewritten."""
        payload = ["clean", "dirty \udce2"]
        assert _sanitize_structure_surrogates(payload) is True
        assert payload[0] == "clean"
        assert "\ufffd" in payload[1]

    def test_nested_dict_in_list(self):
        """Dicts nested inside a list are visited."""
        payload = [{"x": "dirty \udce2"}, {"x": "clean"}]
        assert _sanitize_structure_surrogates(payload) is True
        assert "\ufffd" in payload[0]["x"]
        assert payload[1]["x"] == "clean"

    def test_deeply_nested(self):
        """Strings three levels down (dict -> dict -> list -> dict) are sanitized."""
        payload = {
            "level1": {
                "level2": [
                    {"level3": "deep \udce2 surrogate"},
                ],
            },
        }
        assert _sanitize_structure_surrogates(payload) is True
        assert "\ufffd" in payload["level1"]["level2"][0]["level3"]

    def test_clean_payload_returns_false(self):
        """A payload without surrogates reports no modification."""
        payload = {"a": "clean", "b": [{"c": "also clean"}]}
        assert _sanitize_structure_surrogates(payload) is False

    def test_non_string_values_ignored(self):
        """Non-string leaves neither trigger a modification nor get altered."""
        payload = {"int": 42, "list": [1, 2, 3], "dict": {"none": None}, "bool": True}
        assert _sanitize_structure_surrogates(payload) is False
        # Check every key, not just a sample: the whole structure must be intact.
        assert payload == {
            "int": 42,
            "list": [1, 2, 3],
            "dict": {"none": None},
            "bool": True,
        }
        # Equality alone would accept 1 for True, so pin the singletons too.
        assert payload["dict"]["none"] is None
        assert payload["bool"] is True
class TestApiMessagesSurrogateRecovery:
    """Integration: the sanitizer must also clean an ``api_messages``-shaped list.

    Regression guard: a surrogate in ``reasoning_content`` (derived from
    ``reasoning`` when api_messages are built) breaks JSON serialization of
    the request. Per the original note, recovery used to sanitize only the
    canonical ``messages`` list, not ``api_messages``, so each retry resent
    the same broken payload.
    """

    def test_api_messages_reasoning_content_sanitized(self):
        """``reasoning_content`` on an assistant turn with tool calls gets cleaned."""
        system_turn = {"role": "system", "content": "sys"}
        assistant_turn = {
            "role": "assistant",
            "content": "response",
            "reasoning_content": "thought \udce2 trail",
            "tool_calls": [
                {
                    "id": "call_1",
                    "function": {"name": "tool", "arguments": "{}"},
                }
            ],
        }
        tool_turn = {"role": "tool", "content": "result", "tool_call_id": "call_1"}
        api_messages = [system_turn, assistant_turn, tool_turn]

        changed = _sanitize_messages_surrogates(api_messages)

        assert changed is True
        assert "\udce2" not in api_messages[1]["reasoning_content"]
        # Encoding raises if a lone surrogate is still present anywhere.
        json.dumps(api_messages, ensure_ascii=False).encode("utf-8")
class TestRunConversationSurrogateSanitization:
"""Integration: verify run_conversation sanitizes user_message."""

View file

@ -184,6 +184,8 @@ _HERMES_BEHAVIORAL_VARS = frozenset({
"HERMES_BACKGROUND_NOTIFICATIONS",
"HERMES_EXEC_ASK",
"HERMES_HOME_MODE",
"BROWSER_CDP_URL",
"CAMOFOX_URL",
})
@ -229,6 +231,15 @@ def _hermetic_environment(tmp_path, monkeypatch):
monkeypatch.setenv("LC_ALL", "C.UTF-8")
monkeypatch.setenv("PYTHONHASHSEED", "0")
# 4b. Disable AWS IMDS lookups. Without this, any test that ends up
# calling has_aws_credentials() / resolve_aws_auth_env_var()
# (e.g. provider auto-detect, status command, cron run_job) burns
# ~2s waiting for the metadata service at 169.254.169.254 to time
# out. Tests don't run on EC2 — IMDS is always unreachable here.
monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true")
monkeypatch.setenv("AWS_METADATA_SERVICE_TIMEOUT", "1")
monkeypatch.setenv("AWS_METADATA_SERVICE_NUM_ATTEMPTS", "1")
# 5. Reset plugin singleton so tests don't leak plugins from
# ~/.hermes/plugins/ (which, per step 3, is now empty — but the
# singleton might still be cached from a previous test).

View file

@ -7,6 +7,7 @@ from unittest.mock import patch
from gateway.channel_directory import (
build_channel_directory,
lookup_channel_type,
resolve_channel_name,
format_directory_for_display,
load_directory,
@ -285,3 +286,49 @@ class TestFormatDirectoryForDisplay:
assert "Discord (Server1):" in result
assert "Discord (Server2):" in result
assert "discord:#general" in result
class TestLookupChannelType:
    """``lookup_channel_type`` against a directory file written for each test."""

    def _setup(self, tmp_path, platforms):
        """Write ``platforms`` to a directory file and return a patcher pointing DIRECTORY_PATH at it."""
        directory_file = _write_directory(tmp_path, platforms)
        return patch("gateway.channel_directory.DIRECTORY_PATH", directory_file)

    def test_forum_channel(self, tmp_path):
        """An entry typed ``forum`` is reported as ``forum``."""
        forum_entry = {"id": "100", "name": "ideas", "guild": "Server1", "type": "forum"}
        with self._setup(tmp_path, {"discord": [forum_entry]}):
            found = lookup_channel_type("discord", "100")
            assert found == "forum"

    def test_regular_channel(self, tmp_path):
        """An entry typed ``channel`` is reported as ``channel``."""
        text_entry = {"id": "200", "name": "general", "guild": "Server1", "type": "channel"}
        with self._setup(tmp_path, {"discord": [text_entry]}):
            found = lookup_channel_type("discord", "200")
            assert found == "channel"

    def test_unknown_chat_id_returns_none(self, tmp_path):
        """An id absent from the platform's entries gives ``None``."""
        text_entry = {"id": "200", "name": "general", "guild": "Server1", "type": "channel"}
        with self._setup(tmp_path, {"discord": [text_entry]}):
            found = lookup_channel_type("discord", "999")
            assert found is None

    def test_unknown_platform_returns_none(self, tmp_path):
        """A platform with no entries at all gives ``None``."""
        with self._setup(tmp_path, {}):
            found = lookup_channel_type("discord", "100")
            assert found is None

    def test_channel_without_type_key_returns_none(self, tmp_path):
        """A matching entry that has no ``type`` key gives ``None``."""
        untyped_entry = {"id": "300", "name": "general", "guild": "Server1"}
        with self._setup(tmp_path, {"discord": [untyped_entry]}):
            found = lookup_channel_type("discord", "300")
            assert found is None

View file

@ -160,6 +160,30 @@ class TestCommandBypassActiveSession:
assert sk not in adapter._pending_messages
assert any("handled:status" in r for r in adapter.sent_responses)
@pytest.mark.asyncio
async def test_agents_bypasses_guard(self):
    """``/agents`` sent during an active session is handled, not queued."""
    adapter = _make_adapter()
    session = _session_key()
    # Mark the session as busy before the command arrives.
    adapter._active_sessions[session] = asyncio.Event()

    incoming = _make_event("/agents")
    await adapter.handle_message(incoming)

    assert session not in adapter._pending_messages
    handled = [reply for reply in adapter.sent_responses if "handled:agents" in reply]
    assert handled
@pytest.mark.asyncio
async def test_tasks_alias_bypasses_guard(self):
    """``/tasks`` sent during an active session is handled, not queued."""
    adapter = _make_adapter()
    session = _session_key()
    # Mark the session as busy before the command arrives.
    adapter._active_sessions[session] = asyncio.Event()

    incoming = _make_event("/tasks")
    await adapter.handle_message(incoming)

    assert session not in adapter._pending_messages
    handled = [reply for reply in adapter.sent_responses if "handled:tasks" in reply]
    assert handled
@pytest.mark.asyncio
async def test_background_bypasses_guard(self):
"""/background must bypass so it spawns a parallel task, not an interrupt."""
@ -176,6 +200,38 @@ class TestCommandBypassActiveSession:
"/background response was not sent back to the user"
)
@pytest.mark.asyncio
async def test_help_bypasses_guard(self):
    """``/help`` sent during an active session is dispatched and answered."""
    adapter = _make_adapter()
    session = _session_key()
    # Mark the session as busy before the command arrives.
    adapter._active_sessions[session] = asyncio.Event()

    await adapter.handle_message(_make_event("/help"))

    was_queued = session in adapter._pending_messages
    assert not was_queued, (
        "/help was queued as a pending message instead of being dispatched"
    )
    was_answered = any("handled:help" in reply for reply in adapter.sent_responses)
    assert was_answered, (
        "/help response was not sent back to the user"
    )
@pytest.mark.asyncio
async def test_update_bypasses_guard(self):
    """``/update`` sent during an active session is dispatched and answered."""
    adapter = _make_adapter()
    session = _session_key()
    # Mark the session as busy before the command arrives.
    adapter._active_sessions[session] = asyncio.Event()

    await adapter.handle_message(_make_event("/update"))

    was_queued = session in adapter._pending_messages
    assert not was_queued, (
        "/update was queued as a pending message instead of being dispatched"
    )
    was_answered = any("handled:update" in reply for reply in adapter.sent_responses)
    assert was_answered, (
        "/update response was not sent back to the user"
    )
@pytest.mark.asyncio
async def test_queue_bypasses_guard(self):
"""/queue must bypass so it can queue without interrupting."""

View file

@ -198,7 +198,7 @@ class TestSend:
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_response)
adapter._http_client = mock_client
adapter._session_webhooks["chat-123"] = "https://cached.example/webhook"
adapter._session_webhooks["chat-123"] = ("https://cached.example/webhook", 9999999999999)
result = await adapter.send("chat-123", "Hello!")
assert result.success is True
@ -681,3 +681,290 @@ class TestIncomingHandlerProcess:
processing_gate.set()
await asyncio.sleep(0.05)
# ---------------------------------------------------------------------------
# Text extraction — mention preservation + platform sanity
# ---------------------------------------------------------------------------
class TestExtractTextMentions:
    """Text extraction must leave ``@`` tokens alone; also checks the platform enum value."""

    def test_preserves_at_mentions_in_text(self):
        """Every sample containing ``@`` must come back from ``_extract_text`` unchanged.

        The samples cover bot mentions, an e-mail address, an SSH-style git
        URL, an inline handle and non-ASCII handles.
        """
        from gateway.platforms.dingtalk import DingTalkAdapter

        cases = [
            ("@bot hello", "@bot hello"),
            ("contact alice@example.com", "contact alice@example.com"),
            ("git@github.com:foo/bar.git", "git@github.com:foo/bar.git"),
            ("what does @openai think", "what does @openai think"),
            ("@机器人 转发给 @老王", "@机器人 转发给 @老王"),
        ]
        for text, expected in cases:
            # Only the plain-text field is populated; rich-text fields are absent.
            msg = MagicMock()
            msg.text = text
            msg.rich_text = None
            msg.rich_text_content = None

            extracted = DingTalkAdapter._extract_text(msg)
            assert extracted == expected, (
                f"mangled: {text!r} -> {extracted!r}"
            )

    def test_dingtalk_in_platform_enum(self):
        """``Platform.DINGTALK`` carries the string value ``dingtalk``."""
        assert Platform.DINGTALK.value == "dingtalk"
# ---------------------------------------------------------------------------
# Concurrency — chat-scoped message context
# ---------------------------------------------------------------------------
class TestMessageContextIsolation:
    """Message contexts are stored per chat id on the adapter."""

    def test_contexts_keyed_by_chat_id(self):
        """Storing contexts for two chats keeps each retrievable under its own key."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(PlatformConfig(enabled=True))
        contexts = {
            "chat-A": MagicMock(conversation_id="chat-A", sender_staff_id="user-A"),
            "chat-B": MagicMock(conversation_id="chat-B", sender_staff_id="user-B"),
        }
        for chat_id, message in contexts.items():
            adapter._message_contexts[chat_id] = message

        for chat_id, message in contexts.items():
            assert adapter._message_contexts[chat_id] is message
# ---------------------------------------------------------------------------
# Card lifecycle: finalize via metadata["streaming"]
# ---------------------------------------------------------------------------
class TestCardLifecycle:
    """Open/finalize behaviour of AI cards across ``send`` and ``edit_message``."""

    @pytest.fixture
    def adapter_with_card(self):
        """Adapter with a mocked card SDK, HTTP client and token, plus context for ``chat-1``."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        card_config = PlatformConfig(
            enabled=True,
            extra={"card_template_id": "tmpl-1"},
        )
        a = DingTalkAdapter(card_config)

        # Every SDK coroutine the tests inspect is replaced by an AsyncMock.
        a._card_sdk = MagicMock()
        for sdk_method in (
            "create_card_with_options_async",
            "deliver_card_with_options_async",
            "streaming_update_with_options_async",
        ):
            setattr(a._card_sdk, sdk_method, AsyncMock())
        a._http_client = AsyncMock()
        a._get_access_token = AsyncMock(return_value="token")

        # Smallest inbound-message stand-in needed for chat-1.
        inbound = MagicMock(
            conversation_id="chat-1",
            conversation_type="1",
            sender_staff_id="staff-1",
            message_id="user-msg-1",
        )
        a._message_contexts["chat-1"] = inbound
        a._session_webhooks["chat-1"] = (
            "https://api.dingtalk.com/x", 9999999999999,
        )
        return a

    @pytest.mark.asyncio
    async def test_final_reply_finalizes_card(self, adapter_with_card):
        """``send(reply_to=...)`` issues a finalized streaming update and tracks nothing."""
        a = adapter_with_card

        result = await a.send("chat-1", "Hello", reply_to="user-msg-1")

        assert result.success
        update_request = a._card_sdk.streaming_update_with_options_async.call_args.args[0]
        assert update_request.is_finalize is True
        # A finalized card must not be left in the open-card registry.
        assert "chat-1" not in a._streaming_cards

    @pytest.mark.asyncio
    async def test_intermediate_send_stays_streaming(self, adapter_with_card):
        """``send()`` without ``reply_to`` leaves the card open.

        This is the tool-progress / commentary / first-streaming-chunk path;
        keeping the card open avoids a closed-to-streaming flicker when
        ``edit_message`` follows.
        """
        a = adapter_with_card

        result = await a.send("chat-1", "💻 terminal: ls")

        assert result.success
        update_request = a._card_sdk.streaming_update_with_options_async.call_args.args[0]
        assert update_request.is_finalize is False
        # The open card is registered under its chat for later cleanup.
        open_cards = a._streaming_cards.get("chat-1", {})
        assert result.message_id in open_cards

    @pytest.mark.asyncio
    async def test_done_fires_only_when_reply_to_is_set(self, adapter_with_card):
        """The Done reaction fires for sends with ``reply_to`` and not for sends without it."""
        a = adapter_with_card
        fired: list[str] = []

        def _record_done(cid):
            fired.append(cid)

        a._fire_done_reaction = _record_done

        # No reply_to: nothing recorded.
        await a.send("chat-1", "tool line")
        assert fired == []

        # reply_to present: exactly one Done for this chat.
        await a.send("chat-1", "final", reply_to="user-msg-1")
        assert fired == ["chat-1"]

    @pytest.mark.asyncio
    async def test_edit_message_finalize_fires_done(self, adapter_with_card):
        """``edit_message(finalize=True)`` triggers the Done reaction."""
        a = adapter_with_card
        fired: list[str] = []

        def _record_done(cid):
            fired.append(cid)

        a._fire_done_reaction = _record_done

        await a.send("chat-1", "initial")
        # One non-final edit, followed by the finalizing edit.
        await a.edit_message(
            chat_id="chat-1", message_id="track-X",
            content="streaming...", finalize=False,
        )
        await a.edit_message(
            chat_id="chat-1", message_id="track-X",
            content="final", finalize=True,
        )

        assert "chat-1" in fired

    @pytest.mark.asyncio
    async def test_edit_message_finalize_false_tracks_sibling(self, adapter_with_card):
        """A non-final edit registers the card and its latest content as open."""
        a = adapter_with_card

        await a.edit_message(
            chat_id="chat-1", message_id="track-1",
            content="partial", finalize=False,
        )

        assert "chat-1" in a._streaming_cards
        tracked_content = a._streaming_cards["chat-1"].get("track-1")
        assert tracked_content == "partial"

    @pytest.mark.asyncio
    async def test_next_send_auto_closes_sibling_streaming_cards(
        self, adapter_with_card,
    ):
        """An open tool-progress card is finalized when the final-reply send arrives."""
        a = adapter_with_card
        streaming_update = a._card_sdk.streaming_update_with_options_async

        # Step 1: intermediate send leaves a card open.
        r1 = await a.send("chat-1", "💻 tool1")
        # Step 2: a non-final edit appends a second tool line to the same card.
        await a.edit_message(
            chat_id="chat-1", message_id=r1.message_id,
            content="💻 tool1\n💻 tool2", finalize=False,
        )
        assert r1.message_id in a._streaming_cards.get("chat-1", {})
        streaming_update.reset_mock()

        # Step 3: the final reply should close the open sibling first.
        await a.send("chat-1", "final answer", reply_to="user-msg")

        calls = streaming_update.call_args_list
        assert len(calls) >= 2
        # The earliest update after the reset targets the sibling and finalizes
        # it with the tool-progress content seen last.
        sibling_close = calls[0].args[0]
        assert sibling_close.out_track_id == r1.message_id
        assert sibling_close.is_finalize is True
        assert "tool1" in sibling_close.content
        # Nothing remains registered as open for the chat.
        assert "chat-1" not in a._streaming_cards

    @pytest.mark.asyncio
    async def test_edit_message_requires_message_id(self, adapter_with_card):
        """An empty ``message_id`` fails without touching the card SDK."""
        a = adapter_with_card

        result = await a.edit_message(
            chat_id="chat-1", message_id="", content="x", finalize=True,
        )

        assert result.success is False
        a._card_sdk.streaming_update_with_options_async.assert_not_called()

    def test_fire_done_reaction_is_idempotent(self, adapter_with_card):
        """Firing Done twice for one chat spawns a single background coroutine."""
        a = adapter_with_card
        spawned = []

        def _capture(coro):
            spawned.append(coro)

        a._spawn_bg = _capture

        for _ in range(2):
            a._fire_done_reaction("chat-1")

        assert len(spawned) == 1
        # Close the never-awaited coroutine captured above.
        spawned[0].close()
# ---------------------------------------------------------------------------
# AI Card Tests
# ---------------------------------------------------------------------------
class TestDingTalkAdapterAICards:
    """``send`` goes through the card SDK when a card template id is configured."""

    @pytest.fixture
    def config(self):
        """Platform config carrying client credentials and a card template id."""
        extra_settings = {
            "client_id": "test_id",
            "client_secret": "test_secret",
            "card_template_id": "test_card_template",
        }
        return PlatformConfig(enabled=True, extra=extra_settings)

    @pytest.fixture
    def mock_stream_client(self):
        """Stream client whose ``get_access_token`` returns a fixed token."""
        stream_client = MagicMock()
        stream_client.get_access_token = MagicMock(return_value="test_token")
        return stream_client

    @pytest.fixture
    def mock_http_client(self):
        """Async HTTP client stand-in."""
        return AsyncMock()

    @pytest.fixture
    def mock_message(self):
        """Inbound message mock with the attributes set below, created 'now' in UTC milliseconds."""
        created_ms = int(datetime.now(tz=timezone.utc).timestamp() * 1000)
        field_values = {
            "message_id": "test_msg_id",
            "conversation_id": "test_conv_id",
            "conversation_type": "1",
            "sender_id": "sender1",
            "sender_nick": "Test User",
            "sender_staff_id": "staff1",
            "text": MagicMock(content="Hello"),
            "session_webhook": "https://api.dingtalk.com/robot/sendBySession?session=test",
            "session_webhook_expired_time": 999999999999,
            "create_at": created_ms,
            "at_users": [],
        }
        msg = MagicMock()
        for field_name, field_value in field_values.items():
            setattr(msg, field_name, field_value)
        return msg

    @pytest.mark.asyncio
    async def test_send_uses_ai_card_if_configured(self, config, mock_stream_client, mock_http_client, mock_message):
        """One ``send`` call creates, delivers and stream-updates a card exactly once each."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(config)
        adapter._stream_client = mock_stream_client
        adapter._http_client = mock_http_client
        adapter._message_contexts["test_conv_id"] = mock_message
        adapter._session_webhooks = {"test_conv_id": ("https://api.dingtalk.com/robot/sendBySession?session=test", 9999999999999)}
        adapter._card_template_id = "test_card_template"

        # Card SDK double: the three coroutines used by send() are AsyncMocks.
        card_sdk = MagicMock()
        card_sdk.create_card_with_options_async = AsyncMock()
        card_sdk.deliver_card_with_options_async = AsyncMock()
        card_sdk.streaming_update_with_options_async = AsyncMock()
        adapter._card_sdk = card_sdk

        # Token retrieval is stubbed out as well.
        adapter._get_access_token = AsyncMock(return_value="test_token")

        result = await adapter.send("test_conv_id", "Hello World")

        card_sdk.create_card_with_options_async.assert_called_once()
        card_sdk.deliver_card_with_options_async.assert_called_once()
        card_sdk.streaming_update_with_options_async.assert_called_once()
        assert result.success is True

View file

@ -157,3 +157,232 @@ async def test_send_does_not_retry_on_unrelated_errors():
# Only the first attempt happens — no reference-retry replay.
assert channel.send.await_count == 1
assert send_calls[0]["reference"] is reference_obj
# ---------------------------------------------------------------------------
# Forum channel tests
# ---------------------------------------------------------------------------
import discord as _discord_mod # noqa: E402 — imported after _ensure_discord_mock
class TestIsForumParent:
    """``DiscordAdapter._is_forum_parent`` for ``None``, ForumChannel instances and raw type codes."""

    @staticmethod
    def _new_adapter():
        """Build the adapter configuration shared by every case in this class."""
        return DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    def test_none_returns_false(self):
        """``None`` is not a forum parent."""
        adapter = self._new_adapter()
        assert adapter._is_forum_parent(None) is False

    def test_forum_channel_class_instance(self):
        """An instance of the module's ``ForumChannel`` class is a forum parent."""
        adapter = self._new_adapter()
        forum_cls = getattr(_discord_mod, "ForumChannel", None)
        if forum_cls is None:
            # The discord stand-in lacks the class: create one and install it.
            forum_cls = type("ForumChannel", (), {})
            _discord_mod.ForumChannel = forum_cls
        channel = forum_cls()
        assert adapter._is_forum_parent(channel) is True

    def test_type_value_15(self):
        """An object whose ``type`` is 15 is a forum parent."""
        adapter = self._new_adapter()
        channel = SimpleNamespace(type=15)
        assert adapter._is_forum_parent(channel) is True

    def test_regular_channel_returns_false(self):
        """An object whose ``type`` is 0 is not a forum parent."""
        adapter = self._new_adapter()
        channel = SimpleNamespace(type=0)
        assert adapter._is_forum_parent(channel) is False

    def test_thread_returns_false(self):
        """An object whose ``type`` is 11 (public thread) is not a forum parent."""
        adapter = self._new_adapter()
        channel = SimpleNamespace(type=11)
        assert adapter._is_forum_parent(channel) is False
@pytest.mark.asyncio
async def test_send_to_forum_creates_thread_post():
    """Sending to a forum channel creates a thread and reports the starter message id."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    # Inner channel object that exposes ``send``.
    inner_thread = SimpleNamespace(
        id=555,
        send=AsyncMock(return_value=SimpleNamespace(id=600)),
    )
    # Value returned by create_thread: it has no ``send`` of its own and
    # carries the starter message plus the inner thread.
    created = SimpleNamespace(
        id=555,
        message=SimpleNamespace(id=500),
        thread=inner_thread,
    )

    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id = 999
    forum_channel.name = "ideas"
    forum_channel.create_thread = AsyncMock(return_value=created)

    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: forum_channel,
        fetch_channel=AsyncMock(),
    )

    result = await adapter.send("999", "Hello forum!")

    assert result.success is True
    assert result.message_id == "500"
    forum_channel.create_thread.assert_awaited_once()
@pytest.mark.asyncio
async def test_send_to_forum_sends_remaining_chunks():
    """Content longer than the message limit is continued inside the new thread."""
    starter_msg = SimpleNamespace(id=500)
    follow_up_msg = SimpleNamespace(id=501)

    inner_thread = SimpleNamespace(
        id=555,
        send=AsyncMock(return_value=follow_up_msg),
    )
    # thread object has no 'send' so _send_to_forum uses thread.thread
    created = SimpleNamespace(id=555, message=starter_msg, thread=inner_thread)

    forum = _discord_mod.ForumChannel()
    forum.id = 999
    forum.name = "ideas"
    forum.create_thread = AsyncMock(return_value=created)

    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    # Force a small max message length so the message splits
    adapter.MAX_MESSAGE_LENGTH = 20
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: forum,
        fetch_channel=AsyncMock(),
    )

    outcome = await adapter.send("999", "A" * 50)

    assert outcome.success is True
    assert outcome.message_id == "500"
    # Should have sent at least one follow-up chunk
    assert inner_thread.send.await_count >= 1
@pytest.mark.asyncio
async def test_send_to_forum_create_thread_failure():
    """A ``create_thread`` exception yields a failed result carrying the error text."""
    forum = _discord_mod.ForumChannel()
    forum.id = 999
    forum.name = "ideas"
    forum.create_thread = AsyncMock(side_effect=Exception("rate limited"))

    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: forum,
        fetch_channel=AsyncMock(),
    )

    outcome = await adapter.send("999", "Hello forum!")

    assert outcome.success is False
    assert "rate limited" in outcome.error
# ---------------------------------------------------------------------------
# Forum follow-up chunk failure reporting + media on forum paths
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_send_to_forum_follow_up_chunk_failures_collected_as_warnings():
    """Failed follow-up chunks are reported through raw_response['warnings']."""
    starter_msg = SimpleNamespace(id=500)
    # Every follow-up send raises, so each failed chunk should leave a warning.
    inner_thread = SimpleNamespace(
        id=555,
        send=AsyncMock(side_effect=Exception("rate limited")),
    )
    created = SimpleNamespace(id=555, message=starter_msg, thread=inner_thread)

    forum = _discord_mod.ForumChannel()
    forum.id = 999
    forum.name = "ideas"
    forum.create_thread = AsyncMock(return_value=created)

    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    adapter.MAX_MESSAGE_LENGTH = 20
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: forum,
        fetch_channel=AsyncMock(),
    )

    # Long enough to produce multiple chunks
    outcome = await adapter.send("999", "A" * 60)

    # The first chunk went out via create_thread, so the send counts as a
    # success overall; the failing follow-up chunks must still be reported.
    assert outcome.success is True
    assert outcome.message_id == "500"
    reported = (outcome.raw_response or {}).get("warnings") or []
    assert len(reported) >= 1
    assert all("rate limited" in entry for entry in reported)
@pytest.mark.asyncio
async def test_forum_post_file_creates_thread_with_attachment():
    """_forum_post_file passes the file and content through to create_thread."""
    inner_thread = SimpleNamespace(id=777, send=AsyncMock())
    created = SimpleNamespace(
        id=777,
        message=SimpleNamespace(id=800),
        thread=inner_thread,
    )

    forum = _discord_mod.ForumChannel()
    forum.id = 999
    forum.name = "ideas"
    forum.create_thread = AsyncMock(return_value=created)

    # Lightweight stand-in for a discord file object; only `filename` is set.
    attachment = SimpleNamespace(filename="photo.png")

    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    outcome = await adapter._forum_post_file(
        forum,
        content="here is a photo",
        file=attachment,
    )

    assert outcome.success is True
    assert outcome.message_id == "800"
    forum.create_thread.assert_awaited_once()

    passed = forum.create_thread.await_args.kwargs
    assert passed["file"] is attachment
    assert passed["content"] == "here is a photo"
    # Thread name derived from content's first line
    assert passed["name"] == "here is a photo"
@pytest.mark.asyncio
async def test_forum_post_file_uses_filename_when_no_content():
    """With empty content the thread is named after file.filename."""
    inner_thread = SimpleNamespace(id=1, send=AsyncMock())
    created = SimpleNamespace(id=1, message=SimpleNamespace(id=2), thread=inner_thread)

    forum = _discord_mod.ForumChannel()
    forum.id = 10
    forum.name = "forum"
    forum.create_thread = AsyncMock(return_value=created)

    attachment = SimpleNamespace(filename="voice-message.ogg")

    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    outcome = await adapter._forum_post_file(forum, content="", file=attachment)

    assert outcome.success is True
    passed = forum.create_thread.await_args.kwargs
    # Content was empty, so the thread name comes from the filename
    assert passed["name"] == "voice-message.ogg"
@pytest.mark.asyncio
async def test_forum_post_file_creation_failure():
    """A create_thread exception makes _forum_post_file return a failed result."""
    forum = _discord_mod.ForumChannel()
    forum.id = 999
    forum.create_thread = AsyncMock(side_effect=Exception("missing perms"))

    attachment = SimpleNamespace(filename="x.png")

    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    outcome = await adapter._forum_post_file(forum, content="hi", file=attachment)

    assert outcome.success is False
    assert "missing perms" in (outcome.error or "")

View file

@ -601,6 +601,10 @@ class TestAdapterBehavior(unittest.TestCase):
calls.append("message_recalled")
return self
def register_p2_customized_event(self, event_key, _handler):
calls.append(f"customized:{event_key}")
return self
def build(self):
calls.append("build")
return "handler"
@ -628,6 +632,7 @@ class TestAdapterBehavior(unittest.TestCase):
"bot_deleted",
"p2p_chat_entered",
"message_recalled",
"customized:drive.notice.comment_add_v1",
"build",
],
)

View file

@ -0,0 +1,261 @@
"""Tests for feishu_comment — event filtering, access control integration, wiki reverse lookup."""
import asyncio
import json
import unittest
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch
from gateway.platforms.feishu_comment import (
parse_drive_comment_event,
_ALLOWED_NOTICE_TYPES,
_sanitize_comment_text,
)
def _make_event(
comment_id="c1",
reply_id="r1",
notice_type="add_reply",
file_token="docx_token",
file_type="docx",
from_open_id="ou_user",
to_open_id="ou_bot",
is_mentioned=True,
):
"""Build a minimal drive comment event SimpleNamespace."""
return SimpleNamespace(event={
"event_id": "evt_1",
"comment_id": comment_id,
"reply_id": reply_id,
"is_mentioned": is_mentioned,
"timestamp": "1713200000",
"notice_meta": {
"file_token": file_token,
"file_type": file_type,
"notice_type": notice_type,
"from_user_id": {"open_id": from_open_id},
"to_user_id": {"open_id": to_open_id},
},
})
class TestParseEvent(unittest.TestCase):
    """Checks for ``parse_drive_comment_event``."""

    def test_parse_valid_event(self):
        """A well-formed event is parsed into a mapping of its key fields."""
        parsed = parse_drive_comment_event(_make_event())
        self.assertIsNotNone(parsed)
        expected_fields = {
            "comment_id": "c1",
            "file_type": "docx",
            "from_open_id": "ou_user",
            "to_open_id": "ou_bot",
        }
        for field, value in expected_fields.items():
            self.assertEqual(parsed[field], value)

    def test_parse_missing_event_attr(self):
        """A plain object without an ``event`` attribute parses to ``None``."""
        result = parse_drive_comment_event(object())
        self.assertIsNone(result)

    def test_parse_none_event(self):
        """An empty namespace parses to ``None``."""
        result = parse_drive_comment_event(SimpleNamespace())
        self.assertIsNone(result)
class TestEventFiltering(unittest.TestCase):
    """Test the filtering logic in handle_drive_comment_event."""

    def _run(self, coro):
        """Run *coro* to completion on a private event loop and return its result.

        ``asyncio.get_event_loop()`` is deprecated when no loop is running and
        fails once an earlier test has closed or cleared the thread's loop.
        A dedicated loop that is closed afterwards avoids both problems and
        leaves the thread's current-loop state untouched.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    @patch("gateway.platforms.feishu_comment_rules.load_config")
    @patch("gateway.platforms.feishu_comment_rules.resolve_rule")
    @patch("gateway.platforms.feishu_comment_rules.is_user_allowed")
    def test_self_reply_filtered(self, mock_allowed, mock_resolve, mock_load):
        """Events where from_open_id == self_open_id should be dropped."""
        from gateway.platforms.feishu_comment import handle_drive_comment_event

        evt = _make_event(from_open_id="ou_bot", to_open_id="ou_bot")
        self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
        # The config must never be loaded for a dropped event.
        mock_load.assert_not_called()

    @patch("gateway.platforms.feishu_comment_rules.load_config")
    @patch("gateway.platforms.feishu_comment_rules.resolve_rule")
    @patch("gateway.platforms.feishu_comment_rules.is_user_allowed")
    def test_wrong_receiver_filtered(self, mock_allowed, mock_resolve, mock_load):
        """Events where to_open_id != self_open_id should be dropped."""
        from gateway.platforms.feishu_comment import handle_drive_comment_event

        evt = _make_event(to_open_id="ou_other_bot")
        self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
        mock_load.assert_not_called()

    @patch("gateway.platforms.feishu_comment_rules.load_config")
    @patch("gateway.platforms.feishu_comment_rules.resolve_rule")
    @patch("gateway.platforms.feishu_comment_rules.is_user_allowed")
    def test_empty_to_open_id_filtered(self, mock_allowed, mock_resolve, mock_load):
        """Events with empty to_open_id should be dropped."""
        from gateway.platforms.feishu_comment import handle_drive_comment_event

        evt = _make_event(to_open_id="")
        self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
        mock_load.assert_not_called()

    @patch("gateway.platforms.feishu_comment_rules.load_config")
    @patch("gateway.platforms.feishu_comment_rules.resolve_rule")
    @patch("gateway.platforms.feishu_comment_rules.is_user_allowed")
    def test_invalid_notice_type_filtered(self, mock_allowed, mock_resolve, mock_load):
        """Events with unsupported notice_type should be dropped."""
        from gateway.platforms.feishu_comment import handle_drive_comment_event

        evt = _make_event(notice_type="resolve_comment")
        self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
        mock_load.assert_not_called()

    def test_allowed_notice_types(self):
        """Only comment/reply additions are in the allowed notice-type set."""
        self.assertIn("add_comment", _ALLOWED_NOTICE_TYPES)
        self.assertIn("add_reply", _ALLOWED_NOTICE_TYPES)
        self.assertNotIn("resolve_comment", _ALLOWED_NOTICE_TYPES)
class TestAccessControlIntegration(unittest.TestCase):
    """Checks how handle_drive_comment_event reacts to access-control decisions."""

    def _run(self, coro):
        """Run *coro* to completion on a private event loop and return its result.

        ``asyncio.get_event_loop()`` is deprecated when no loop is running and
        fails once an earlier test has closed or cleared the thread's loop.
        A dedicated loop that is closed afterwards avoids both problems and
        leaves the thread's current-loop state untouched.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    @patch("gateway.platforms.feishu_comment_rules.has_wiki_keys", return_value=False)
    @patch("gateway.platforms.feishu_comment_rules.is_user_allowed", return_value=False)
    @patch("gateway.platforms.feishu_comment_rules.resolve_rule")
    @patch("gateway.platforms.feishu_comment_rules.load_config")
    def test_denied_user_no_side_effects(self, mock_load, mock_resolve, mock_allowed, mock_wiki_keys):
        """Denied user should not trigger typing reaction or agent."""
        from gateway.platforms.feishu_comment import handle_drive_comment_event
        from gateway.platforms.feishu_comment_rules import ResolvedCommentRule

        # Rule is enabled, but is_user_allowed is patched to deny everyone.
        mock_resolve.return_value = ResolvedCommentRule(True, "allowlist", frozenset(), "top")
        mock_load.return_value = Mock()
        client = Mock()
        evt = _make_event()
        self._run(handle_drive_comment_event(client, evt, self_open_id="ou_bot"))
        # No API calls should be made for denied users
        client.request.assert_not_called()

    @patch("gateway.platforms.feishu_comment_rules.has_wiki_keys", return_value=False)
    @patch("gateway.platforms.feishu_comment_rules.is_user_allowed", return_value=False)
    @patch("gateway.platforms.feishu_comment_rules.resolve_rule")
    @patch("gateway.platforms.feishu_comment_rules.load_config")
    def test_disabled_comment_skipped(self, mock_load, mock_resolve, mock_allowed, mock_wiki_keys):
        """Disabled comments should return immediately."""
        from gateway.platforms.feishu_comment import handle_drive_comment_event
        from gateway.platforms.feishu_comment_rules import ResolvedCommentRule

        # First field False means the resolved rule is disabled.
        mock_resolve.return_value = ResolvedCommentRule(False, "allowlist", frozenset(), "top")
        mock_load.return_value = Mock()
        evt = _make_event()
        self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
        # The user check must not even be consulted for a disabled rule.
        mock_allowed.assert_not_called()
class TestSanitizeCommentText(unittest.TestCase):
    """Checks the escaping performed by ``_sanitize_comment_text``."""

    def test_angle_brackets_escaped(self):
        """``<`` and ``>`` become ``&lt;`` and ``&gt;``."""
        escaped = _sanitize_comment_text("List<String>")
        self.assertEqual(escaped, "List&lt;String&gt;")

    def test_ampersand_escaped_first(self):
        """A bare ampersand becomes ``&amp;``."""
        escaped = _sanitize_comment_text("a & b")
        self.assertEqual(escaped, "a &amp; b")

    def test_ampersand_not_double_escaped(self):
        """Entities produced for brackets are not escaped a second time."""
        escaped = _sanitize_comment_text("a < b & c > d")
        self.assertEqual(escaped, "a &lt; b &amp; c &gt; d")
        for doubly_escaped in ("&amp;lt;", "&amp;gt;"):
            self.assertNotIn(doubly_escaped, escaped)

    def test_plain_text_unchanged(self):
        """Text without special characters passes through untouched."""
        escaped = _sanitize_comment_text("hello world")
        self.assertEqual(escaped, "hello world")

    def test_empty_string(self):
        """The empty string stays empty."""
        escaped = _sanitize_comment_text("")
        self.assertEqual(escaped, "")

    def test_code_snippet(self):
        """A code-like snippet ends up with no raw angle brackets."""
        snippet = 'if (a < b && c > 0) { return "ok"; }'
        escaped = _sanitize_comment_text(snippet)
        for raw in ("<", ">"):
            self.assertNotIn(raw, escaped)
        for entity in ("&lt;", "&gt;"):
            self.assertIn(entity, escaped)
class TestWikiReverseLookup(unittest.TestCase):
    """Checks ``_reverse_lookup_wiki_token`` and its use by the event handler."""

    def _run(self, coro):
        """Run *coro* to completion on a private event loop and return its result.

        ``asyncio.get_event_loop()`` is deprecated when no loop is running and
        fails once an earlier test has closed or cleared the thread's loop.
        A dedicated loop that is closed afterwards avoids both problems and
        leaves the thread's current-loop state untouched.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    @patch("gateway.platforms.feishu_comment._exec_request")
    def test_reverse_lookup_success(self, mock_exec):
        """A zero status code yields the node token from the response."""
        from gateway.platforms.feishu_comment import _reverse_lookup_wiki_token

        mock_exec.return_value = (0, "Success", {
            "node": {"node_token": "WIKI_TOKEN_123", "obj_token": "docx_abc"},
        })
        result = self._run(_reverse_lookup_wiki_token(Mock(), "docx", "docx_abc"))
        self.assertEqual(result, "WIKI_TOKEN_123")
        # Verify correct API params (accept keyword or 4th positional argument)
        call_args = mock_exec.call_args
        queries = call_args[1].get("queries") or call_args[0][3]
        query_dict = dict(queries)
        self.assertEqual(query_dict["token"], "docx_abc")
        self.assertEqual(query_dict["obj_type"], "docx")

    @patch("gateway.platforms.feishu_comment._exec_request")
    def test_reverse_lookup_not_wiki(self, mock_exec):
        """A 131001 'not found' response yields ``None``."""
        from gateway.platforms.feishu_comment import _reverse_lookup_wiki_token

        mock_exec.return_value = (131001, "not found", {})
        result = self._run(_reverse_lookup_wiki_token(Mock(), "docx", "docx_abc"))
        self.assertIsNone(result)

    @patch("gateway.platforms.feishu_comment._exec_request")
    def test_reverse_lookup_service_error(self, mock_exec):
        """A 500 'internal error' response yields ``None``."""
        from gateway.platforms.feishu_comment import _reverse_lookup_wiki_token

        mock_exec.return_value = (500, "internal error", {})
        result = self._run(_reverse_lookup_wiki_token(Mock(), "docx", "docx_abc"))
        self.assertIsNone(result)

    @patch("gateway.platforms.feishu_comment._reverse_lookup_wiki_token", new_callable=AsyncMock)
    @patch("gateway.platforms.feishu_comment_rules.has_wiki_keys", return_value=True)
    @patch("gateway.platforms.feishu_comment_rules.is_user_allowed", return_value=True)
    @patch("gateway.platforms.feishu_comment_rules.resolve_rule")
    @patch("gateway.platforms.feishu_comment_rules.load_config")
    @patch("gateway.platforms.feishu_comment.add_comment_reaction", new_callable=AsyncMock)
    @patch("gateway.platforms.feishu_comment.batch_query_comment", new_callable=AsyncMock)
    @patch("gateway.platforms.feishu_comment.query_document_meta", new_callable=AsyncMock)
    def test_wiki_lookup_triggered_when_no_exact_match(
        self, mock_meta, mock_batch, mock_reaction,
        mock_load, mock_resolve, mock_allowed, mock_wiki_keys, mock_lookup,
    ):
        """Wiki reverse lookup should fire when rule falls to wildcard/top and wiki keys exist."""
        from gateway.platforms.feishu_comment import handle_drive_comment_event
        from gateway.platforms.feishu_comment_rules import ResolvedCommentRule

        # First resolve returns wildcard (no exact match), second returns exact wiki match
        mock_resolve.side_effect = [
            ResolvedCommentRule(True, "allowlist", frozenset(), "wildcard"),
            ResolvedCommentRule(True, "allowlist", frozenset(), "exact:wiki:WIKI123"),
        ]
        mock_load.return_value = Mock()
        mock_lookup.return_value = "WIKI123"
        mock_meta.return_value = {"title": "Test", "url": ""}
        mock_batch.return_value = {"is_whole": False, "quote": ""}
        evt = _make_event()
        # Deliberate best-effort: the handler may fail after access control
        # because later stages are not mocked; only the lookup is under test.
        try:
            self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
        except Exception:
            pass
        mock_lookup.assert_called_once_with(unittest.mock.ANY, "docx", "docx_token")
        self.assertEqual(mock_resolve.call_count, 2)
        # Second call should include wiki_token (keyword or 4th positional argument)
        second_call = mock_resolve.call_args_list[1]
        self.assertEqual(second_call[1].get("wiki_token") or second_call[0][3], "WIKI123")
# Allow running this test module directly with the unittest runner.
if __name__ == "__main__":
    unittest.main()

View file

@ -0,0 +1,320 @@
"""Tests for feishu_comment_rules — 3-tier access control rule engine."""
import json
import os
import tempfile
import time
import unittest
from pathlib import Path
from unittest.mock import patch
from gateway.platforms.feishu_comment_rules import (
CommentsConfig,
CommentDocumentRule,
ResolvedCommentRule,
_MtimeCache,
_parse_document_rule,
has_wiki_keys,
is_user_allowed,
load_config,
pairing_add,
pairing_list,
pairing_remove,
resolve_rule,
)
class TestCommentDocumentRuleParsing(unittest.TestCase):
    """Checks how ``_parse_document_rule`` turns raw dicts into rule objects."""

    def _assert_all_unset(self, parsed, *fields):
        """Assert that every named attribute of *parsed* is ``None``."""
        for field in fields:
            self.assertIsNone(getattr(parsed, field))

    def test_parse_full_rule(self):
        """All three fields are carried over when present."""
        raw = {
            "enabled": False,
            "policy": "allowlist",
            "allow_from": ["ou_a", "ou_b"],
        }
        parsed = _parse_document_rule(raw)
        self.assertFalse(parsed.enabled)
        self.assertEqual(parsed.policy, "allowlist")
        self.assertEqual(parsed.allow_from, frozenset(["ou_a", "ou_b"]))

    def test_parse_partial_rule(self):
        """Fields missing from the input are left as ``None``."""
        parsed = _parse_document_rule({"policy": "allowlist"})
        self.assertIsNone(parsed.enabled)
        self.assertEqual(parsed.policy, "allowlist")
        self.assertIsNone(parsed.allow_from)

    def test_parse_empty_rule(self):
        """An empty dict yields a rule with every field unset."""
        parsed = _parse_document_rule({})
        self._assert_all_unset(parsed, "enabled", "policy", "allow_from")

    def test_invalid_policy_ignored(self):
        """An unrecognised policy string is dropped."""
        parsed = _parse_document_rule({"policy": "invalid_value"})
        self._assert_all_unset(parsed, "policy")
class TestResolveRule(unittest.TestCase):
    """Checks the exact / wildcard / top-level resolution done by ``resolve_rule``."""

    def test_exact_match(self):
        """A ``type:token`` key matching the document wins."""
        documents = {"docx:abc": CommentDocumentRule(policy="allowlist")}
        config = CommentsConfig(
            policy="pairing",
            allow_from=frozenset(["ou_top"]),
            documents=documents,
        )
        resolved = resolve_rule(config, "docx", "abc")
        self.assertEqual(resolved.policy, "allowlist")
        self.assertTrue(resolved.match_source.startswith("exact:"))

    def test_wildcard_match(self):
        """The ``*`` entry applies when no exact key matches."""
        documents = {"*": CommentDocumentRule(policy="allowlist")}
        config = CommentsConfig(policy="pairing", documents=documents)
        resolved = resolve_rule(config, "docx", "unknown")
        self.assertEqual(resolved.policy, "allowlist")
        self.assertEqual(resolved.match_source, "wildcard")

    def test_top_level_fallback(self):
        """Without document entries the top-level settings are used."""
        config = CommentsConfig(policy="pairing", allow_from=frozenset(["ou_top"]))
        resolved = resolve_rule(config, "docx", "whatever")
        self.assertEqual(resolved.policy, "pairing")
        self.assertEqual(resolved.allow_from, frozenset(["ou_top"]))
        self.assertEqual(resolved.match_source, "top")

    def test_exact_overrides_wildcard(self):
        """An exact entry takes precedence over the wildcard entry."""
        documents = {
            "*": CommentDocumentRule(policy="pairing"),
            "docx:abc": CommentDocumentRule(policy="allowlist"),
        }
        config = CommentsConfig(policy="pairing", documents=documents)
        resolved = resolve_rule(config, "docx", "abc")
        self.assertEqual(resolved.policy, "allowlist")
        self.assertTrue(resolved.match_source.startswith("exact:"))

    def test_field_by_field_fallback(self):
        """Exact sets policy, wildcard sets allow_from, enabled from top."""
        documents = {
            "*": CommentDocumentRule(allow_from=frozenset(["ou_wildcard"])),
            "docx:abc": CommentDocumentRule(policy="allowlist"),
        }
        config = CommentsConfig(
            enabled=True,
            policy="pairing",
            allow_from=frozenset(["ou_top"]),
            documents=documents,
        )
        resolved = resolve_rule(config, "docx", "abc")
        self.assertEqual(resolved.policy, "allowlist")
        self.assertEqual(resolved.allow_from, frozenset(["ou_wildcard"]))
        self.assertTrue(resolved.enabled)

    def test_explicit_empty_allow_from_does_not_fall_through(self):
        """allow_from=[] on exact should NOT inherit from wildcard or top."""
        exact_rule = CommentDocumentRule(
            policy="allowlist",
            allow_from=frozenset(),
        )
        documents = {
            "*": CommentDocumentRule(allow_from=frozenset(["ou_wildcard"])),
            "docx:abc": exact_rule,
        }
        config = CommentsConfig(
            allow_from=frozenset(["ou_top"]),
            documents=documents,
        )
        resolved = resolve_rule(config, "docx", "abc")
        self.assertEqual(resolved.allow_from, frozenset())

    def test_wiki_token_match(self):
        """A ``wiki:<token>`` key matches via the ``wiki_token`` argument."""
        documents = {"wiki:WIKI123": CommentDocumentRule(policy="allowlist")}
        config = CommentsConfig(policy="pairing", documents=documents)
        resolved = resolve_rule(config, "docx", "obj_token", wiki_token="WIKI123")
        self.assertEqual(resolved.policy, "allowlist")
        self.assertTrue(resolved.match_source.startswith("exact:wiki:"))

    def test_exact_takes_priority_over_wiki(self):
        """The document's own key beats a matching wiki key."""
        documents = {
            "docx:abc": CommentDocumentRule(policy="allowlist"),
            "wiki:WIKI123": CommentDocumentRule(policy="pairing"),
        }
        config = CommentsConfig(documents=documents)
        resolved = resolve_rule(config, "docx", "abc", wiki_token="WIKI123")
        self.assertEqual(resolved.policy, "allowlist")
        self.assertTrue(resolved.match_source.startswith("exact:docx:"))

    def test_default_config(self):
        """A default config resolves to enabled / pairing / empty allow list."""
        resolved = resolve_rule(CommentsConfig(), "docx", "anything")
        self.assertTrue(resolved.enabled)
        self.assertEqual(resolved.policy, "pairing")
        self.assertEqual(resolved.allow_from, frozenset())
class TestHasWikiKeys(unittest.TestCase):
    """Checks ``has_wiki_keys`` against configs with and without wiki entries."""

    def test_no_wiki_keys(self):
        """Only ``docx:`` and ``*`` keys: no wiki keys reported."""
        documents = {
            "docx:abc": CommentDocumentRule(policy="allowlist"),
            "*": CommentDocumentRule(policy="pairing"),
        }
        config = CommentsConfig(documents=documents)
        self.assertFalse(has_wiki_keys(config))

    def test_has_wiki_keys(self):
        """A ``wiki:`` key is detected."""
        documents = {"wiki:WIKI123": CommentDocumentRule(policy="allowlist")}
        config = CommentsConfig(documents=documents)
        self.assertTrue(has_wiki_keys(config))

    def test_empty_documents(self):
        """A default config reports no wiki keys."""
        self.assertFalse(has_wiki_keys(CommentsConfig()))
class TestIsUserAllowed(unittest.TestCase):
    """Checks ``is_user_allowed`` under the allowlist and pairing policies."""

    @staticmethod
    def _enabled_rule(policy, members=()):
        """Build an enabled top-level rule with the given policy and allow list."""
        return ResolvedCommentRule(True, policy, frozenset(members), "top")

    def test_allowlist_allows_listed(self):
        """A listed user passes the allowlist policy."""
        rule = self._enabled_rule("allowlist", ["ou_a"])
        self.assertTrue(is_user_allowed(rule, "ou_a"))

    def test_allowlist_denies_unlisted(self):
        """An unlisted user is rejected by the allowlist policy."""
        rule = self._enabled_rule("allowlist", ["ou_a"])
        self.assertFalse(is_user_allowed(rule, "ou_b"))

    def test_allowlist_empty_denies_all(self):
        """An empty allowlist rejects everyone."""
        rule = self._enabled_rule("allowlist")
        self.assertFalse(is_user_allowed(rule, "ou_anyone"))

    def test_pairing_allows_in_allow_from(self):
        """Under pairing, a user present in allow_from passes."""
        rule = self._enabled_rule("pairing", ["ou_a"])
        self.assertTrue(is_user_allowed(rule, "ou_a"))

    def test_pairing_checks_store(self):
        """Under pairing, users returned by the approved-store loader pass."""
        rule = self._enabled_rule("pairing")
        store_patch = patch(
            "gateway.platforms.feishu_comment_rules._load_pairing_approved",
            return_value={"ou_approved"},
        )
        with store_patch:
            self.assertTrue(is_user_allowed(rule, "ou_approved"))
            self.assertFalse(is_user_allowed(rule, "ou_unknown"))
class TestMtimeCache(unittest.TestCase):
    """Checks loading and reloading behaviour of ``_MtimeCache``."""

    @staticmethod
    def _dump_to_temp(payload):
        """Write *payload* as JSON to a persistent temp file and return its path."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
            json.dump(payload, handle)
            handle.flush()
            return Path(handle.name)

    def test_returns_empty_dict_for_missing_file(self):
        """A nonexistent path loads as an empty dict."""
        missing = _MtimeCache(Path("/nonexistent/path.json"))
        self.assertEqual(missing.load(), {})

    def test_reads_file_and_caches(self):
        """Two consecutive loads return the file's content."""
        path = self._dump_to_temp({"key": "value"})
        try:
            cache = _MtimeCache(path)
            first = cache.load()
            self.assertEqual(first, {"key": "value"})
            # Second load happens with an unchanged mtime
            second = cache.load()
            self.assertEqual(second, {"key": "value"})
        finally:
            path.unlink()

    def test_reloads_on_mtime_change(self):
        """New content is returned after the file's mtime moves forward."""
        path = self._dump_to_temp({"v": 1})
        try:
            cache = _MtimeCache(path)
            self.assertEqual(cache.load(), {"v": 1})
            # Rewrite the file with different content
            time.sleep(0.05)
            with open(path, "w") as rewritten:
                json.dump({"v": 2}, rewritten)
            # Push the timestamps into the future so the change is unmistakable
            os.utime(path, (time.time() + 1, time.time() + 1))
            self.assertEqual(cache.load(), {"v": 2})
        finally:
            path.unlink()
class TestLoadConfig(unittest.TestCase):
    """Checks ``load_config`` with a rules file present and absent."""

    def test_load_with_documents(self):
        """Top-level fields and per-document entries are read from the rules file."""
        rules = {
            "enabled": True,
            "policy": "allowlist",
            "allow_from": ["ou_a"],
            "documents": {
                "*": {"policy": "pairing"},
                "docx:abc": {"policy": "allowlist", "allow_from": ["ou_b"]},
            },
        }
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
            json.dump(rules, handle)
        rules_path = Path(handle.name)
        try:
            # Point both the file constant and the module cache at the temp file.
            with patch("gateway.platforms.feishu_comment_rules.RULES_FILE", rules_path), \
                    patch("gateway.platforms.feishu_comment_rules._rules_cache", _MtimeCache(rules_path)):
                loaded = load_config()
            self.assertTrue(loaded.enabled)
            self.assertEqual(loaded.policy, "allowlist")
            self.assertEqual(loaded.allow_from, frozenset(["ou_a"]))
            for key in ("*", "docx:abc"):
                self.assertIn(key, loaded.documents)
            self.assertEqual(loaded.documents["docx:abc"].policy, "allowlist")
        finally:
            rules_path.unlink()

    def test_load_missing_file_returns_defaults(self):
        """With a cache over a nonexistent path, defaults are returned."""
        empty_cache = _MtimeCache(Path("/nonexistent"))
        with patch("gateway.platforms.feishu_comment_rules._rules_cache", empty_cache):
            loaded = load_config()
        self.assertTrue(loaded.enabled)
        self.assertEqual(loaded.policy, "pairing")
        self.assertEqual(loaded.allow_from, frozenset())
        self.assertEqual(loaded.documents, {})
class TestPairingStore(unittest.TestCase):
    """Checks pairing_add / pairing_list / pairing_remove against a temp store."""

    def setUp(self):
        """Create an empty pairing file and point the module at it.

        Every resource is released through ``addCleanup``, so a failure
        partway through ``setUp`` cannot leak an active patch, and the
        directory is removed even if extra files were left inside it.
        """
        # Registered first so it runs last (cleanups execute in LIFO order).
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self._tmpdir = tmp.name

        self._pairing_file = Path(self._tmpdir) / "pairing.json"
        with open(self._pairing_file, "w") as f:
            json.dump({"approved": {}}, f)

        self._patcher_file = patch(
            "gateway.platforms.feishu_comment_rules.PAIRING_FILE",
            self._pairing_file,
        )
        self._patcher_cache = patch(
            "gateway.platforms.feishu_comment_rules._pairing_cache",
            _MtimeCache(self._pairing_file),
        )
        # Cache patch is stopped before the file patch (reverse of start order).
        self._patcher_file.start()
        self.addCleanup(self._patcher_file.stop)
        self._patcher_cache.start()
        self.addCleanup(self._patcher_cache.stop)

    def test_add_and_list(self):
        """A newly added id is reported by pairing_list."""
        self.assertTrue(pairing_add("ou_new"))
        approved = pairing_list()
        self.assertIn("ou_new", approved)

    def test_add_duplicate(self):
        """Adding the same id twice returns False the second time."""
        pairing_add("ou_a")
        self.assertFalse(pairing_add("ou_a"))

    def test_remove(self):
        """Removing an existing id returns True and drops it from the list."""
        pairing_add("ou_a")
        self.assertTrue(pairing_remove("ou_a"))
        self.assertNotIn("ou_a", pairing_list())

    def test_remove_nonexistent(self):
        """Removing an unknown id returns False."""
        self.assertFalse(pairing_remove("ou_nobody"))
# Allow running this test module directly with the unittest runner.
if __name__ == "__main__":
    unittest.main()

View file

@ -179,7 +179,7 @@ class TestVoiceAttachmentSSRFProtection:
from gateway.platforms.qqbot import QQAdapter, _ssrf_redirect_guard
client = mock.AsyncMock()
with mock.patch("gateway.platforms.qqbot.httpx.AsyncClient", return_value=client) as async_client_cls:
with mock.patch("gateway.platforms.qqbot.adapter.httpx.AsyncClient", return_value=client) as async_client_cls:
adapter = QQAdapter(_make_config(app_id="a", client_secret="b"))
adapter._ensure_token = mock.AsyncMock(side_effect=RuntimeError("stop after client creation"))

View file

@ -202,3 +202,120 @@ async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_p
assert ok is True
assert calls == [(42, False), (42, True)]
@pytest.mark.asyncio
async def test_start_gateway_replace_writes_takeover_marker_before_sigterm(
    monkeypatch, tmp_path
):
    """--replace must write a takeover marker BEFORE sending SIGTERM.

    The marker lets the target's shutdown handler identify the signal as a
    planned takeover (exit 0) rather than an unexpected kill (exit 1).
    Without the marker, PR #5646's signal-recovery path would revive the
    target via systemd Restart=on-failure, starting a flap loop.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    marker_file = tmp_path / ".gateway-takeover.json"

    # Record the ORDER of marker-write + terminate_pid calls
    events: list[str] = []
    # One entry per write request: True when the marker was still absent
    # at the moment the write was requested.
    marker_absent_before_write: list[bool] = []

    def record_write_marker(target_pid: int) -> bool:
        events.append(f"write_marker(target_pid={target_pid})")
        marker_absent_before_write.append(not marker_file.exists())
        # Actually write the marker so we can verify cleanup later
        from gateway.status import _get_takeover_marker_path, _write_json_file

        _write_json_file(_get_takeover_marker_path(), {
            "target_pid": target_pid,
            "target_start_time": 0,
            "replacer_pid": 100,
            "written_at": "2026-04-17T00:00:00+00:00",
        })
        return True

    def record_terminate(pid, force=False):
        events.append(f"terminate_pid(pid={pid}, force={force})")

    def process_already_gone(pid, sig):
        # Stand-in for os.kill: report that the target PID no longer exists.
        raise ProcessLookupError()

    class _CleanExitRunner:
        """Minimal GatewayRunner replacement that starts and stops cleanly."""

        def __init__(self, config):
            self.config = config
            self.should_exit_cleanly = True
            self.exit_reason = None
            self.adapters = {}

        async def start(self):
            return True

        async def stop(self):
            return None

    monkeypatch.setattr("gateway.status.get_running_pid", lambda: 42)
    monkeypatch.setattr("gateway.status.remove_pid_file", lambda: None)
    monkeypatch.setattr("gateway.status.release_all_scoped_locks", lambda: 0)
    monkeypatch.setattr("gateway.status.write_takeover_marker", record_write_marker)
    monkeypatch.setattr("gateway.status.terminate_pid", record_terminate)
    monkeypatch.setattr("gateway.run.os.getpid", lambda: 100)
    # Simulate old process exiting on first check so we don't loop into force-kill
    monkeypatch.setattr("gateway.run.os.kill", process_already_gone)
    monkeypatch.setattr("time.sleep", lambda _: None)
    monkeypatch.setattr("tools.skills_sync.sync_skills", lambda quiet=True: None)
    monkeypatch.setattr("hermes_logging.setup_logging", lambda hermes_home, mode: tmp_path)
    monkeypatch.setattr("hermes_logging._add_rotating_handler", lambda *args, **kwargs: None)
    monkeypatch.setattr("gateway.run.GatewayRunner", _CleanExitRunner)

    from gateway.run import start_gateway

    ok = await start_gateway(config=GatewayConfig(), replace=True, verbosity=None)
    assert ok is True

    # Ordering: marker written BEFORE SIGTERM
    assert events[0] == "write_marker(target_pid=42)"
    assert any(e.startswith("terminate_pid(pid=42") for e in events[1:])
    # No stale marker was lying around when the first write was requested
    assert marker_absent_before_write[0] is True
    # Marker file cleanup: replacer cleans it after loop completes
    assert not marker_file.exists()
@pytest.mark.asyncio
async def test_start_gateway_replace_clears_marker_on_permission_denied(
    monkeypatch, tmp_path
):
    """When terminating the existing PID raises PermissionError, the takeover
    marker must be removed so it cannot affect an unrelated later shutdown."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    def write_marker(target_pid: int) -> bool:
        from gateway.status import _get_takeover_marker_path, _write_json_file

        marker_payload = {
            "target_pid": target_pid,
            "target_start_time": 0,
            "replacer_pid": 100,
            "written_at": "2026-04-17T00:00:00+00:00",
        }
        _write_json_file(_get_takeover_marker_path(), marker_payload)
        return True

    def raise_permission(pid, force=False):
        raise PermissionError("simulated EPERM")

    replacements = {
        "gateway.status.get_running_pid": lambda: 42,
        "gateway.status.write_takeover_marker": write_marker,
        "gateway.status.terminate_pid": raise_permission,
        "gateway.run.os.getpid": lambda: 100,
        "tools.skills_sync.sync_skills": lambda quiet=True: None,
        "hermes_logging.setup_logging": lambda hermes_home, mode: tmp_path,
        "hermes_logging._add_rotating_handler": lambda *args, **kwargs: None,
    }
    for target, stub in replacements.items():
        monkeypatch.setattr(target, stub)

    from gateway.run import start_gateway

    # The permission error must surface as a False return value
    started = await start_gateway(config=GatewayConfig(), replace=True, verbosity=None)
    assert started is False

    # Marker must NOT be left behind
    leftover_marker = tmp_path / ".gateway-takeover.json"
    assert not leftover_marker.exists()

View file

@ -288,6 +288,38 @@ async def test_command_messages_do_not_leave_sentinel():
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("command_text", "handler_attr", "handler_result"),
    [
        ("/help", "_handle_help_command", "Help text"),
        ("/commands", "_handle_commands_command", "Commands text"),
        ("/update", "_handle_update_command", "Update text"),
        ("/profile", "_handle_profile_command", "Profile text"),
    ],
)
async def test_active_session_bypass_commands_dispatch_without_interrupt(
    command_text,
    handler_attr,
    handler_result,
):
    """Bypass commands return their handler's result while an agent is running.

    The running agent must not be interrupted and nothing may be queued as a
    pending message for the session.
    """
    runner = _make_runner()
    event = _make_event(text=command_text)
    session_key = build_session_key(event.source)

    # Register an agent for the session so it counts as active.
    active_agent = MagicMock()
    active_agent.get_activity_summary.return_value = {"seconds_since_activity": 0}
    runner._running_agents[session_key] = active_agent

    # Stub the handler under test with a canned reply.
    stub_handler = AsyncMock(return_value=handler_result)
    setattr(runner, handler_attr, stub_handler)

    reply = await runner._handle_message(event)

    assert reply == handler_result
    active_agent.interrupt.assert_not_called()
    pending = runner.adapters[Platform.TELEGRAM]._pending_messages
    assert session_key not in pending
# ------------------------------------------------------------------
# Test 6: /stop during sentinel force-cleans and unlocks session
# ------------------------------------------------------------------

View file

@ -0,0 +1,231 @@
"""Regression tests for _release_running_agent_state and SessionDB shutdown.
Before this change, running-agent state lived in three dicts that drifted
out of sync:
    self._running_agents     — AIAgent instance per session key
    self._running_agents_ts  — start timestamp per session key
    self._busy_ack_ts        — last busy-ack timestamp per session key
Six cleanup sites did ``del self._running_agents[key]`` without touching
the other two; one site only popped ``_running_agents`` and
``_running_agents_ts``; and only the stale-eviction site cleaned all
three. Each missed entry was a small persistent leak.
Also: SessionDB connections were never closed on gateway shutdown,
leaving WAL locks in place until Python actually exited.
"""
import threading
from unittest.mock import MagicMock
import pytest
def _make_runner():
    """Return a ``GatewayRunner`` created without ``__init__``.

    Only the three per-session dicts that the release helper works on are
    attached; each starts out empty.
    """
    from gateway.run import GatewayRunner

    bare = GatewayRunner.__new__(GatewayRunner)
    for attr in ("_running_agents", "_running_agents_ts", "_busy_ack_ts"):
        # A fresh dict per attribute — they must not alias each other.
        setattr(bare, attr, {})
    return bare
class TestReleaseRunningAgentStateUnit:
    """Unit checks for ``GatewayRunner._release_running_agent_state``."""

    def test_pops_all_three_dicts(self):
        """Releasing a key drops it from every tracking dict."""
        gw = _make_runner()
        gw._running_agents["k"] = MagicMock()
        gw._running_agents_ts["k"] = 123.0
        gw._busy_ack_ts["k"] = 456.0

        gw._release_running_agent_state("k")

        for tracked in (gw._running_agents, gw._running_agents_ts, gw._busy_ack_ts):
            assert "k" not in tracked

    def test_idempotent_on_missing_key(self):
        """Releasing an absent key, even repeatedly, must not raise."""
        gw = _make_runner()
        for _ in range(2):
            gw._release_running_agent_state("missing")

    def test_noop_on_empty_session_key(self):
        """An empty-string key is ignored rather than popped."""
        gw = _make_runner()
        gw._running_agents[""] = "guard"

        gw._release_running_agent_state("")

        # The sentinel stored under "" is still there, so nothing was popped.
        assert gw._running_agents[""] == "guard"

    def test_preserves_other_sessions(self):
        """Only the released key disappears; sibling sessions stay put."""
        gw = _make_runner()
        for key in ("a", "b", "c"):
            gw._running_agents[key] = MagicMock()
            gw._running_agents_ts[key] = 1.0
            gw._busy_ack_ts[key] = 1.0

        gw._release_running_agent_state("b")

        survivors = {"a", "c"}
        assert set(gw._running_agents) == survivors
        assert set(gw._running_agents_ts) == survivors
        assert set(gw._busy_ack_ts) == survivors

    def test_handles_missing_busy_ack_attribute(self):
        """A runner without ``_busy_ack_ts`` can still release a key."""
        gw = _make_runner()
        delattr(gw, "_busy_ack_ts")  # mimic a runner that never had the dict
        gw._running_agents["k"] = MagicMock()
        gw._running_agents_ts["k"] = 1.0

        gw._release_running_agent_state("k")  # must not raise

        assert "k" not in gw._running_agents
        assert "k" not in gw._running_agents_ts

    def test_concurrent_release_is_safe(self):
        """Five threads release disjoint key sets; all dicts end up empty."""
        gw = _make_runner()
        for idx in range(50):
            key = f"s{idx}"
            gw._running_agents[key] = MagicMock()
            gw._running_agents_ts[key] = float(idx)
            gw._busy_ack_ts[key] = float(idx)

        def release_all(keys):
            for key in keys:
                gw._release_running_agent_state(key)

        # Shard the 50 keys by stride so no two threads touch the same key.
        workers = []
        for offset in range(5):
            shard = [f"s{idx}" for idx in range(offset, 50, 5)]
            workers.append(threading.Thread(target=release_all, args=(shard,)))

        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(timeout=5)
            assert not worker.is_alive()

        assert gw._running_agents == {}
        assert gw._running_agents_ts == {}
        assert gw._busy_ack_ts == {}
class TestNoMoreBareDeleteSites:
    """Regression guard against bare ``del self._running_agents[key]``.

    All such sites were converted to use the release helper. If a future
    contributor reverts one, this test flags it. Docstrings and comments
    that merely mention the old pattern are allowed.
    """

    @staticmethod
    def _find_bare_del_sites(source):
        """Return ``(line_number, line)`` for each bare delete in *source*.

        Line numbers are 1-based and lines are right-stripped. Two kinds
        of text are ignored:

        * Triple-quoted blocks whose opening delimiter starts a line
          (i.e. docstrings). A line holding both delimiters is treated
          as a one-line docstring.
        * Everything after the first ``#`` on a line, so both full-line
          and trailing comments may quote the old pattern. This is a
          simple split and does not recognise ``#`` inside string
          literals.
        """
        import re

        bare_del = re.compile(r"\bdel\s+self\._running_agents\[")
        offenders = []
        open_delim = None  # delimiter of the docstring we are inside, if any
        for lineno, line in enumerate(source.splitlines(), start=1):
            stripped = line.strip()
            if open_delim is not None:
                # Inside a multi-line docstring: only look for its end.
                if open_delim in stripped:
                    open_delim = None
                continue
            if stripped.startswith(('"""', "'''")):
                delim = stripped[:3]
                if stripped.count(delim) < 2:
                    open_delim = delim  # opens here, closes on a later line
                continue
            code_part = line.split("#", 1)[0]
            if bare_del.search(code_part):
                offenders.append((lineno, line.rstrip()))
        return offenders

    def test_no_bare_del_of_running_agents_in_gateway_run(self):
        """gateway/run.py must contain no bare delete outside docs/comments."""
        from pathlib import Path

        gateway_run = Path(__file__).parent.parent.parent / "gateway" / "run.py"
        # Explicit encoding: do not depend on the platform's locale default.
        offenders = self._find_bare_del_sites(gateway_run.read_text(encoding="utf-8"))
        assert offenders == [], (
            "Found bare `del self._running_agents[...]` sites in gateway/run.py. "
            "Use self._release_running_agent_state(session_key) instead so "
            "_running_agents_ts and _busy_ack_ts are popped in lockstep.\n"
            + "\n".join(f" line {n}: {text}" for n, text in offenders)
        )
class TestSessionDbCloseOnShutdown:
    """Shutdown should close the session DB handles the runner holds.

    Each test replays the close loop attributed to ``_stop_impl`` against
    a bare runner: the ``_db`` attribute on the runner and on
    ``runner.session_store`` should each receive ``.close()`` so SQLite
    WAL locks are released before a ``--replace`` successor opens the
    same file.

    NOTE(review): the loop is a copy kept inside this test class, not a
    call into ``_stop_impl`` — confirm it still mirrors production.
    """

    @staticmethod
    def _close_session_dbs(runner, *, swallow_errors=False):
        """Close ``_db`` on *runner* and on its ``session_store``, if present.

        Holders that are missing or falsy, and ``_db`` values that are
        ``None`` or lack ``close``, are skipped. With ``swallow_errors``
        an exception from one ``close()`` does not stop the next one.
        """
        for holder in (runner, getattr(runner, "session_store", None)):
            db = getattr(holder, "_db", None) if holder else None
            if db is None or not hasattr(db, "close"):
                continue
            try:
                db.close()
            except Exception:
                if not swallow_errors:
                    raise

    def test_stop_impl_closes_both_session_dbs(self):
        """Both the runner's and the session store's DB get closed once."""
        from gateway.run import GatewayRunner

        gw = GatewayRunner.__new__(GatewayRunner)
        own_db, store_db = MagicMock(), MagicMock()
        gw._db = own_db
        gw.session_store = MagicMock()
        gw.session_store._db = store_db

        self._close_session_dbs(gw)

        own_db.close.assert_called_once()
        store_db.close.assert_called_once()

    def test_shutdown_tolerates_missing_session_store(self):
        """A runner lacking ``session_store`` must not break the loop."""
        from gateway.run import GatewayRunner

        gw = GatewayRunner.__new__(GatewayRunner)
        gw._db = MagicMock()
        # session_store is intentionally never assigned.

        self._close_session_dbs(gw)

        gw._db.close.assert_called_once()

    def test_shutdown_tolerates_close_raising(self):
        """A failing ``close()`` must not prevent the next DB from closing."""
        from gateway.run import GatewayRunner

        gw = GatewayRunner.__new__(GatewayRunner)
        failing_db = MagicMock()
        failing_db.close.side_effect = RuntimeError("simulated lock error")
        working_db = MagicMock()
        gw._db = failing_db
        gw.session_store = MagicMock()
        gw.session_store._db = working_db

        # Each close() is individually guarded, as in the original loop.
        self._close_session_dbs(gw, swallow_errors=True)

        failing_db.close.assert_called_once()
        working_db.close.assert_called_once()

View file

@ -0,0 +1,270 @@
"""Tests for SessionStore.prune_old_entries and the gateway watcher that calls it.
The SessionStore in-memory dict (and its backing sessions.json) grew
unbounded — every unique (platform, chat_id, thread_id, user_id) tuple
ever seen was kept forever, regardless of how stale it became. These
tests pin the prune behaviour:
* Entries older than max_age_days (by updated_at) are removed
* Entries marked ``suspended`` are preserved (user-paused)
* Entries with an active process attached are preserved
* max_age_days <= 0 disables pruning entirely
* sessions.json is rewritten with the post-prune dict
* The ``updated_at`` field — not ``created_at`` — drives the decision
  (so a long-running-but-still-active session isn't pruned)
"""
import json
import threading
from datetime import datetime, timedelta
from unittest.mock import patch
import pytest
from gateway.config import GatewayConfig, Platform, SessionResetPolicy
from gateway.session import SessionEntry, SessionStore
def _make_store(tmp_path, max_age_days: int = 90, has_active_processes_fn=None):
    """Return a ``SessionStore`` rooted at *tmp_path* with loading stubbed out.

    ``SessionStore._ensure_loaded`` is patched while the store is built.
    The store is then flagged as loaded and its ``_db`` is set to
    ``None``, so tests can populate ``_entries`` by hand.
    """
    cfg = GatewayConfig(
        default_reset_policy=SessionResetPolicy(mode="none"),
        session_store_max_age_days=max_age_days,
    )
    store_kwargs = {
        "sessions_dir": tmp_path,
        "config": cfg,
        "has_active_processes_fn": has_active_processes_fn,
    }
    with patch("gateway.session.SessionStore._ensure_loaded"):
        session_store = SessionStore(**store_kwargs)
        session_store._db = None
        session_store._loaded = True
    return session_store
def _entry(key: str, age_days: float, *, suspended: bool = False,
           session_id: str | None = None) -> SessionEntry:
    """Build a Telegram DM ``SessionEntry`` last updated *age_days* ago.

    ``created_at`` is placed a further 30 days back. When *session_id*
    is falsy the id defaults to ``"sid_<key>"``.
    """
    reference = datetime.now()
    last_update = reference - timedelta(days=age_days)
    creation = reference - timedelta(days=age_days + 30)  # arbitrary older
    resolved_id = session_id or f"sid_{key}"
    return SessionEntry(
        session_key=key,
        session_id=resolved_id,
        created_at=creation,
        updated_at=last_update,
        platform=Platform.TELEGRAM,
        chat_type="dm",
        suspended=suspended,
    )
class TestPruneBasics:
    """Core behaviour of ``SessionStore.prune_old_entries``."""

    def test_prune_removes_entries_past_max_age(self, tmp_path):
        """An entry older than the cutoff goes; a recent one stays."""
        sessions = _make_store(tmp_path)
        sessions._entries["old"] = _entry("old", age_days=100)
        sessions._entries["fresh"] = _entry("fresh", age_days=5)

        pruned = sessions.prune_old_entries(max_age_days=90)

        assert pruned == 1
        assert "old" not in sessions._entries
        assert "fresh" in sessions._entries

    def test_prune_uses_updated_at_not_created_at(self, tmp_path):
        """Old creation date plus recent update means the entry is kept."""
        sessions = _make_store(tmp_path)
        reference = datetime.now()
        sessions._entries["long-lived"] = SessionEntry(
            session_key="long-lived",
            session_id="sid",
            created_at=reference - timedelta(days=365),  # created a year ago
            updated_at=reference - timedelta(days=3),  # touched 3 days ago
            platform=Platform.TELEGRAM,
            chat_type="dm",
        )

        pruned = sessions.prune_old_entries(max_age_days=30)

        assert pruned == 0
        assert "long-lived" in sessions._entries

    def test_prune_disabled_when_max_age_is_zero(self, tmp_path):
        """A zero max age removes nothing."""
        sessions = _make_store(tmp_path, max_age_days=0)
        for idx in range(5):
            name = f"s{idx}"
            sessions._entries[name] = _entry(name, age_days=365)

        assert sessions.prune_old_entries(0) == 0
        assert len(sessions._entries) == 5

    def test_prune_disabled_when_max_age_is_negative(self, tmp_path):
        """A negative max age removes nothing."""
        sessions = _make_store(tmp_path)
        sessions._entries["s"] = _entry("s", age_days=365)

        assert sessions.prune_old_entries(-1) == 0
        assert "s" in sessions._entries

    def test_prune_skips_suspended_entries(self, tmp_path):
        """Suspended entries survive a prune no matter how old they are."""
        sessions = _make_store(tmp_path)
        sessions._entries["suspended"] = _entry(
            "suspended", age_days=1000, suspended=True
        )
        sessions._entries["idle"] = _entry("idle", age_days=1000)

        pruned = sessions.prune_old_entries(max_age_days=90)

        assert pruned == 1
        assert "suspended" in sessions._entries
        assert "idle" not in sessions._entries

    def test_prune_skips_entries_with_active_processes(self, tmp_path):
        """An old entry whose session id is reported active is kept."""
        busy_ids = {"sid_active"}

        def _has_active(session_id: str) -> bool:
            return session_id in busy_ids

        sessions = _make_store(tmp_path, has_active_processes_fn=_has_active)
        sessions._entries["active"] = _entry(
            "active", age_days=1000, session_id="sid_active"
        )
        sessions._entries["idle"] = _entry(
            "idle", age_days=1000, session_id="sid_idle"
        )

        pruned = sessions.prune_old_entries(max_age_days=90)

        assert pruned == 1
        assert "active" in sessions._entries
        assert "idle" not in sessions._entries

    def test_prune_does_not_write_disk_when_no_removals(self, tmp_path):
        """With nothing to remove, ``_save`` is never invoked."""
        sessions = _make_store(tmp_path)
        sessions._entries["fresh1"] = _entry("fresh1", age_days=1)
        sessions._entries["fresh2"] = _entry("fresh2", age_days=2)
        recorded = []

        def _record_save():
            recorded.append(1)

        sessions._save = _record_save

        assert sessions.prune_old_entries(max_age_days=90) == 0
        assert recorded == []

    def test_prune_writes_disk_after_removal(self, tmp_path):
        """Removing an entry triggers exactly one ``_save`` call."""
        sessions = _make_store(tmp_path)
        sessions._entries["stale"] = _entry("stale", age_days=500)
        sessions._entries["fresh"] = _entry("fresh", age_days=1)
        recorded = []

        def _record_save():
            recorded.append(1)

        sessions._save = _record_save

        sessions.prune_old_entries(max_age_days=90)

        assert recorded == [1]

    def test_prune_is_thread_safe(self, tmp_path):
        """One pruner runs next to readers that take ``_lock``; counts hold."""
        sessions = _make_store(tmp_path)
        for idx in range(20):
            # Even indices are stale, odd indices are fresh.
            days = 1000 if idx % 2 == 0 else 1
            sessions._entries[f"s{idx}"] = _entry(f"s{idx}", age_days=days)
        outcomes = []

        def _pruner():
            outcomes.append(sessions.prune_old_entries(max_age_days=90))

        def _reader():
            # Snapshot the keys while holding the store lock.
            with sessions._lock:
                list(sessions._entries.keys())

        workers = [threading.Thread(target=_pruner)]
        workers.extend(threading.Thread(target=_reader) for _ in range(4))
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(timeout=5)
            assert not worker.is_alive()

        # A single pruner ran and dropped the 10 stale entries.
        assert outcomes == [10]
        assert len(sessions._entries) == 10
        for idx in range(1, 20, 2):  # the fresh (odd) ones
            assert f"s{idx}" in sessions._entries
class TestPrunePersistsToDisk:
    """Pruning must be reflected in the on-disk ``sessions.json``."""

    def test_prune_rewrites_sessions_json(self, tmp_path):
        """The JSON file lists both keys before the prune, one after."""
        json_path = tmp_path / "sessions.json"
        cfg = GatewayConfig(
            default_reset_policy=SessionResetPolicy(mode="none"),
            session_store_max_age_days=90,
        )
        sessions = SessionStore(sessions_dir=tmp_path, config=cfg)
        sessions._db = None
        # Entries are injected directly instead of going through get_or_create.
        sessions._entries["stale"] = _entry("stale", age_days=500)
        sessions._entries["fresh"] = _entry("fresh", age_days=1)
        sessions._loaded = True
        sessions._save()

        # Baseline: both entries were written.
        before = json.loads(json_path.read_text())
        assert set(before.keys()) == {"stale", "fresh"}

        sessions.prune_old_entries(max_age_days=90)

        # After pruning only the recent entry remains on disk.
        after = json.loads(json_path.read_text())
        assert set(after.keys()) == {"fresh"}
class TestGatewayConfigSerialization:
    """``session_store_max_age_days`` defaults, round-trips and coercion."""

    def test_session_store_max_age_days_defaults_to_90(self):
        """A default-constructed config uses 90 days."""
        assert GatewayConfig().session_store_max_age_days == 90

    def test_session_store_max_age_days_roundtrips(self):
        """``to_dict`` followed by ``from_dict`` preserves the value."""
        original = GatewayConfig(session_store_max_age_days=30)
        as_dict = original.to_dict()
        rebuilt = GatewayConfig.from_dict(as_dict)
        assert rebuilt.session_store_max_age_days == 30

    def test_session_store_max_age_days_missing_defaults_90(self):
        """A dict without the field yields the 90-day default."""
        rebuilt = GatewayConfig.from_dict({})
        assert rebuilt.session_store_max_age_days == 90

    def test_session_store_max_age_days_negative_coerced_to_zero(self):
        """A negative value in the dict is loaded as 0."""
        rebuilt = GatewayConfig.from_dict({"session_store_max_age_days": -5})
        assert rebuilt.session_store_max_age_days == 0

    def test_session_store_max_age_days_bad_type_falls_back(self):
        """A non-numeric string in the dict is loaded as 90, without raising."""
        rebuilt = GatewayConfig.from_dict({"session_store_max_age_days": "nope"})
        assert rebuilt.session_store_max_age_days == 90
class TestGatewayWatcherCallsPrune:
    """Checks the hourly time gate used to decide when to prune.

    NOTE(review): both tests recompute the comparison locally rather than
    driving ``_session_expiry_watcher`` — confirm it matches production.
    """

    def test_prune_gate_fires_on_first_tick(self):
        """With a last-prune timestamp of 0 the gate is open."""
        import time as _clock

        interval_s = 3600.0
        previous = 0.0
        current = _clock.time()
        # Same shape as the gate attributed to _session_expiry_watcher.
        gate_open = (current - previous) > interval_s
        assert gate_open is True

    def test_prune_gate_suppresses_within_interval(self):
        """A prune ten minutes ago keeps the gate closed."""
        import time as _clock

        previous = _clock.time() - 600  # 10 minutes ago
        interval_s = 3600.0
        current = _clock.time()
        gate_open = (current - previous) > interval_s
        assert gate_open is False

View file

@ -63,6 +63,24 @@ class TestGatewayPidState:
assert status.get_running_pid() == os.getpid()
def test_get_running_pid_accepts_explicit_pid_path_without_cleanup(self, tmp_path, monkeypatch):
    """An explicit PID file path is honoured and left in place.

    The PID record lives in a separate "profile" directory. With
    ``cleanup_stale=False`` the call must return the recorded PID and
    the file must still exist afterwards.
    """
    profile_home = tmp_path / "profile-home"
    profile_home.mkdir()
    explicit_pid_file = profile_home / "gateway.pid"
    record = {
        "pid": os.getpid(),
        "kind": "hermes-gateway",
        "argv": ["python", "-m", "hermes_cli.main", "gateway"],
        "start_time": 123,
    }
    explicit_pid_file.write_text(json.dumps(record))

    # Stub the process probes; the start time matches the record.
    monkeypatch.setattr(status.os, "kill", lambda pid, sig: None)
    monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 123)
    monkeypatch.setattr(status, "_read_process_cmdline", lambda pid: None)

    found = status.get_running_pid(explicit_pid_file, cleanup_stale=False)

    assert found == os.getpid()
    assert explicit_pid_file.exists()
class TestGatewayRuntimeStatus:
def test_write_runtime_status_overwrites_stale_pid_on_restart(self, tmp_path, monkeypatch):
@ -246,3 +264,181 @@ class TestScopedLocks:
status.release_scoped_lock("telegram-bot-token", "secret")
assert not lock_path.exists()
class TestTakeoverMarker:
    """Tests for the ``--replace`` takeover marker file.

    The marker breaks the post-#5646 flap loop between two gateway
    services fighting for the same bot token. The replacer writes a file
    naming the target PID and start time; the target's shutdown handler
    sees it and exits 0 instead of 1, so systemd's Restart=on-failure
    doesn't revive it.
    """

    # File name of the marker inside HERMES_HOME.
    _MARKER_NAME = ".gateway-takeover.json"

    def test_write_marker_records_target_identity(self, tmp_path, monkeypatch):
        """The written JSON names the target, its start time and the writer."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 42)

        written = status.write_takeover_marker(target_pid=12345)

        assert written is True
        marker_file = tmp_path / self._MARKER_NAME
        assert marker_file.exists()
        data = json.loads(marker_file.read_text())
        assert data["target_pid"] == 12345
        assert data["target_start_time"] == 42
        assert data["replacer_pid"] == os.getpid()
        assert "written_at" in data

    def test_consume_returns_true_when_marker_names_self(self, tmp_path, monkeypatch):
        """A marker that targets this very process is recognised and removed."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        # Make the current process the takeover target.
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 100)
        written = status.write_takeover_marker(target_pid=os.getpid())
        assert written is True

        # What the target would do right after being signalled.
        consumed = status.consume_takeover_marker_for_self()

        assert consumed is True
        # Consumption deletes the file.
        assert not (tmp_path / self._MARKER_NAME).exists()

    def test_consume_returns_false_for_different_pid(self, tmp_path, monkeypatch):
        """A marker aimed at another PID is not treated as ours."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 100)
        # Target some PID that is not this process.
        foreign_pid = os.getpid() + 9999
        written = status.write_takeover_marker(target_pid=foreign_pid)
        assert written is True

        consumed = status.consume_takeover_marker_for_self()

        assert consumed is False
        # The file is removed even though it did not match: it has been
        # read, and leaving it behind would interfere with a later check.
        assert not (tmp_path / self._MARKER_NAME).exists()

    def test_consume_returns_false_on_start_time_mismatch(self, tmp_path, monkeypatch):
        """Same PID but a different start time (PID reuse) is rejected."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        # The marker records our PID with start time 100 ...
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 100)
        status.write_takeover_marker(target_pid=os.getpid())
        # ... but by consume time the process reports a different start time.
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 9999)

        consumed = status.consume_takeover_marker_for_self()

        assert consumed is False

    def test_consume_returns_false_when_marker_missing(self, tmp_path, monkeypatch):
        """No marker file at all simply yields ``False``."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        assert status.consume_takeover_marker_for_self() is False

    def test_consume_returns_false_for_stale_marker(self, tmp_path, monkeypatch):
        """A marker stamped two minutes ago is ignored and deleted."""
        from datetime import datetime, timezone, timedelta

        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        marker_file = tmp_path / self._MARKER_NAME
        # Build the marker by hand with an old timestamp.
        two_minutes_ago = datetime.now(timezone.utc) - timedelta(minutes=2)
        stale_record = {
            "target_pid": os.getpid(),
            "target_start_time": 123,
            "replacer_pid": 99999,
            "written_at": two_minutes_ago.isoformat(),
        }
        marker_file.write_text(json.dumps(stale_record))
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 123)

        consumed = status.consume_takeover_marker_for_self()

        assert consumed is False
        # The stale file is cleaned up as well.
        assert not marker_file.exists()

    def test_consume_handles_malformed_marker_gracefully(self, tmp_path, monkeypatch):
        """Invalid JSON in the marker yields ``False`` instead of raising."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        marker_file = tmp_path / self._MARKER_NAME
        marker_file.write_text("not valid json{")

        # The call itself must not raise.
        consumed = status.consume_takeover_marker_for_self()

        assert consumed is False

    def test_consume_handles_marker_with_missing_fields(self, tmp_path, monkeypatch):
        """Valid JSON lacking the expected keys yields ``False`` and is removed."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        marker_file = tmp_path / self._MARKER_NAME
        marker_file.write_text(json.dumps({"only_replacer_pid": 99999}))

        consumed = status.consume_takeover_marker_for_self()

        assert consumed is False
        # The incomplete marker does not linger.
        assert not marker_file.exists()

    def test_clear_takeover_marker_is_idempotent(self, tmp_path, monkeypatch):
        """Clearing works with no file, with a file, and again afterwards."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        marker_file = tmp_path / self._MARKER_NAME

        # No marker yet: clearing is a harmless no-op.
        status.clear_takeover_marker()

        # Create one, then clear it.
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 100)
        status.write_takeover_marker(target_pid=12345)
        assert marker_file.exists()
        status.clear_takeover_marker()
        assert not marker_file.exists()

        # A second clear still must not raise.
        status.clear_takeover_marker()

    def test_write_marker_returns_false_on_write_failure(self, tmp_path, monkeypatch):
        """An ``OSError`` from the JSON writer is reported as ``False``."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))

        def failing_writer(*args, **kwargs):
            raise OSError("simulated write failure")

        monkeypatch.setattr(status, "_write_json_file", failing_writer)

        written = status.write_takeover_marker(target_pid=12345)

        assert written is False

    def test_consume_ignores_marker_for_different_process_and_prevents_stale_grief(
        self, tmp_path, monkeypatch
    ):
        """Regression: a fresh marker naming another PID must not match us.

        A leftover marker from a dead replacer naming a dead target must
        not make an unrelated later gateway exit 0 on a legitimate
        SIGTERM. Here the start time agrees with ours but the PID does
        not, and the result must still be ``False``.
        """
        from datetime import datetime, timezone

        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        marker_file = tmp_path / self._MARKER_NAME
        # Recent timestamp, but the target is an unrelated PID.
        foreign_record = {
            "target_pid": os.getpid() + 10000,
            "target_start_time": 42,
            "replacer_pid": 99999,
            "written_at": datetime.now(timezone.utc).isoformat(),
        }
        marker_file.write_text(json.dumps(foreign_record))
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 42)

        consumed = status.consume_takeover_marker_for_self()

        # Not the target, so this is not a planned takeover for us.
        assert consumed is False

View file

@ -1,6 +1,7 @@
"""Tests for gateway /status behavior and token persistence."""
from datetime import datetime
import time
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
@ -111,6 +112,75 @@ async def test_status_command_includes_session_title_when_present():
assert "**Title:** My titled session" in result
@pytest.mark.asyncio
async def test_agents_command_reports_active_agents_and_processes(monkeypatch):
    """``/agents`` lists one active agent and one background process.

    The running agent registered for the session must not be interrupted
    by the command.
    """
    key = build_session_key(_make_source())
    entry = SessionEntry(
        session_key=key,
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
        total_tokens=0,
    )
    gateway = _make_runner(entry)

    # An agent that started 8 seconds ago and is currently active.
    active_agent = SimpleNamespace(
        session_id="sess-running",
        model="openrouter/test-model",
        interrupt=MagicMock(),
        get_activity_summary=lambda: {"seconds_since_activity": 0},
    )
    gateway._running_agents[key] = active_agent
    gateway._running_agents_ts = {key: time.time() - 8}
    gateway._background_tasks = set()

    class _FakeRegistry:
        """Process registry stub reporting a single running process."""

        def list_sessions(self):
            process = {
                "session_id": "proc-1",
                "status": "running",
                "uptime_seconds": 17,
                "command": "sleep 30",
            }
            return [process]

    monkeypatch.setattr("tools.process_registry.process_registry", _FakeRegistry())

    reply = await gateway._handle_message(_make_event("/agents"))

    assert "**Active agents:** 1" in reply
    assert "**Running background processes:** 1" in reply
    assert "proc-1" in reply
    active_agent.interrupt.assert_not_called()
@pytest.mark.asyncio
async def test_tasks_alias_routes_to_agents_command(monkeypatch):
    """``/tasks`` produces the "Active Agents & Tasks" reply."""
    key = build_session_key(_make_source())
    entry = SessionEntry(
        session_key=key,
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
        total_tokens=0,
    )
    gateway = _make_runner(entry)
    gateway._background_tasks = set()

    class _FakeRegistry:
        """Process registry stub with no processes."""

        def list_sessions(self):
            return []

    monkeypatch.setattr("tools.process_registry.process_registry", _FakeRegistry())

    reply = await gateway._handle_message(_make_event("/tasks"))

    assert "Active Agents & Tasks" in reply
@pytest.mark.asyncio
async def test_handle_message_persists_agent_token_counts(monkeypatch):
import gateway.run as gateway_run

View file

@ -88,6 +88,51 @@ class TestCleanForDisplay:
# ── Integration: _send_or_edit strips MEDIA: ─────────────────────────────
class TestFinalizeCapabilityGate:
    """``REQUIRES_EDIT_FINALIZE`` decides whether a redundant final edit fires.

    Platforms that don't need an explicit finalize signal (Telegram,
    Slack, Matrix, …) should skip the final edit when the text is
    unchanged. Platforms that do need it (DingTalk AI Cards) must always
    receive a ``finalize=True`` edit at the end of the stream.
    """

    @pytest.mark.asyncio
    async def test_identical_text_skip_respects_adapter_flag(self):
        """Identical text skips the edit only when the flag is ``False``."""
        first_send = SimpleNamespace(success=True, message_id="m1")

        # Case 1: the adapter does not require finalize, so no edit is issued.
        relaxed = MagicMock(REQUIRES_EDIT_FINALIZE=False, MAX_MESSAGE_LENGTH=4096)
        relaxed.send = AsyncMock(return_value=first_send)
        relaxed.edit_message = AsyncMock()
        relaxed_consumer = GatewayStreamConsumer(relaxed, "chat_1")

        await relaxed_consumer._send_or_edit("hello")  # initial send
        await relaxed_consumer._send_or_edit("hello", finalize=True)  # same text

        relaxed.edit_message.assert_not_called()

        # Case 2: the adapter requires finalize, so the edit must still happen.
        strict = MagicMock(REQUIRES_EDIT_FINALIZE=True, MAX_MESSAGE_LENGTH=4096)
        strict.send = AsyncMock(
            return_value=SimpleNamespace(success=True, message_id="m1")
        )
        strict.edit_message = AsyncMock(
            return_value=SimpleNamespace(success=True, message_id="m1")
        )
        strict_consumer = GatewayStreamConsumer(strict, "chat_1")

        await strict_consumer._send_or_edit("hello")
        await strict_consumer._send_or_edit("hello", finalize=True)

        # Exactly one edit, and it carries finalize=True.
        strict.edit_message.assert_called_once()
        assert strict.edit_message.call_args.kwargs["finalize"] is True
class TestSendOrEditMediaStripping:
"""Verify _send_or_edit strips MEDIA: before sending to the platform."""

View file

@ -34,7 +34,12 @@ def _ensure_telegram_mock():
_ensure_telegram_mock()
from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2 # noqa: E402
from gateway.platforms.telegram import ( # noqa: E402
TelegramAdapter,
_escape_mdv2,
_strip_mdv2,
_wrap_markdown_tables,
)
# ---------------------------------------------------------------------------
@ -535,6 +540,152 @@ class TestStripMdv2:
assert _strip_mdv2("||hidden text||") == "hidden text"
# =========================================================================
# Markdown table auto-wrap
# =========================================================================
class TestWrapMarkdownTables:
    """``_wrap_markdown_tables`` fences GFM pipe tables in ``` blocks.

    Inside a fence Telegram shows the table as monospace preformatted
    text instead of the noisy backslash-pipe output MarkdownV2 produces.
    """

    def test_basic_table_wrapped(self):
        """A table between two paragraphs is fenced; the prose is untouched."""
        source = (
            "Scores:\n\n"
            "| Player | Score |\n"
            "|--------|-------|\n"
            "| Alice | 150 |\n"
            "| Bob | 120 |\n"
            "\nEnd."
        )
        wrapped = _wrap_markdown_tables(source)
        # The fence opens before the header row and closes after the last row.
        assert "```\n| Player | Score |" in wrapped
        assert "| Bob | 120 |\n```" in wrapped
        # Text before and after the table survives.
        assert wrapped.startswith("Scores:")
        assert wrapped.endswith("End.")

    def test_bare_pipe_table_wrapped(self):
        """Tables written without outer pipes are detected too."""
        source = "head1 | head2\n--- | ---\na | b\nc | d"
        wrapped = _wrap_markdown_tables(source)
        assert wrapped.startswith("```\n")
        assert wrapped.rstrip().endswith("```")
        assert "head1 | head2" in wrapped

    def test_alignment_separators(self):
        """Separator rows using :--- / ---: / :---: still count as tables."""
        source = (
            "| Name | Age | City |\n"
            "|:-----|----:|:----:|\n"
            "| Ada | 30 | NYC |"
        )
        wrapped = _wrap_markdown_tables(source)
        assert wrapped.count("```") == 2

    def test_two_consecutive_tables_wrapped_separately(self):
        """Two tables split by a blank line get a fence pair each."""
        source = (
            "| A | B |\n"
            "|---|---|\n"
            "| 1 | 2 |\n"
            "\n"
            "| X | Y |\n"
            "|---|---|\n"
            "| 9 | 8 |"
        )
        wrapped = _wrap_markdown_tables(source)
        # An opening and a closing fence for each of the two tables.
        assert wrapped.count("```") == 4

    def test_plain_text_with_pipes_not_wrapped(self):
        """A lone pipe inside a sentence leaves the text unchanged."""
        source = "Use the | pipe operator to chain commands."
        assert _wrap_markdown_tables(source) == source

    def test_horizontal_rule_not_wrapped(self):
        """A standalone '---' rule is not taken for a separator row."""
        source = "Section A\n\n---\n\nSection B"
        assert _wrap_markdown_tables(source) == source

    def test_existing_code_block_with_pipes_left_alone(self):
        """A table that is already fenced is returned as-is."""
        source = (
            "```\n"
            "| a | b |\n"
            "|---|---|\n"
            "| 1 | 2 |\n"
            "```"
        )
        assert _wrap_markdown_tables(source) == source

    def test_no_pipe_character_short_circuits(self):
        """Text without any pipe comes back unchanged."""
        source = "Plain **bold** text with no table."
        assert _wrap_markdown_tables(source) == source

    def test_no_dash_short_circuits(self):
        """Pipes without any '-' separator row come back unchanged."""
        source = "a | b\nc | d"
        assert _wrap_markdown_tables(source) == source

    def test_single_column_separator_not_matched(self):
        """Single-column tables are deliberately not detected.

        The separator row must contain at least one internal pipe, which
        avoids false positives on formatting rules.
        """
        source = "| a |\n| - |\n| b |"
        assert _wrap_markdown_tables(source) == source
class TestFormatMessageTables:
    """End-to-end: ``format_message`` keeps pipe tables intact in a fence.

    Pipes and dashes inside the fenced block must not pick up MarkdownV2
    escaping, while text outside the table is still formatted.
    """

    def test_table_rendered_as_code_block(self, adapter):
        """The fenced table contains no escaped pipes or dashes."""
        source = (
            "Data:\n\n"
            "| Col1 | Col2 |\n"
            "|------|------|\n"
            "| A | B |\n"
        )
        rendered = adapter.format_message(source)
        # The table opens inside a fence, unescaped.
        assert "```\n| Col1 | Col2 |" in rendered
        fenced = rendered.split("```")[1]
        assert "\\|" not in fenced
        # The separator row's dashes are unescaped as well.
        assert "\\-" not in fenced

    def test_text_after_table_still_formatted(self, adapter):
        """Prose following a table still gets bold conversion and escaping."""
        source = (
            "| A | B |\n"
            "|---|---|\n"
            "| 1 | 2 |\n"
            "\n"
            "Nice **work** team!"
        )
        rendered = adapter.format_message(source)
        # Bold is converted to MarkdownV2 outside the table.
        assert "*work*" in rendered
        # The exclamation mark outside the fence is escaped.
        assert "\\!" in rendered

    def test_multiple_tables_in_single_message(self, adapter):
        """Two tables in one message yield two fenced blocks."""
        source = (
            "First:\n"
            "| A | B |\n"
            "|---|---|\n"
            "| 1 | 2 |\n"
            "\n"
            "Second:\n"
            "| X | Y |\n"
            "|---|---|\n"
            "| 9 | 8 |\n"
        )
        rendered = adapter.format_message(source)
        # Four fence markers, i.e. two complete blocks.
        assert rendered.count("```") == 4
@pytest.mark.asyncio
async def test_send_escapes_chunk_indicator_for_markdownv2(adapter):
adapter.MAX_MESSAGE_LENGTH = 80

View file

@ -119,7 +119,7 @@ class TestWeComConnect:
class TestWeComReplyMode:
@pytest.mark.asyncio
async def test_send_uses_passive_reply_stream_when_reply_context_exists(self):
async def test_send_uses_passive_reply_markdown_when_reply_context_exists(self):
from gateway.platforms.wecom import WeComAdapter
adapter = WeComAdapter(PlatformConfig(enabled=True))
@ -134,9 +134,10 @@ class TestWeComReplyMode:
adapter._send_reply_request.assert_awaited_once()
args = adapter._send_reply_request.await_args.args
assert args[0] == "req-1"
assert args[1]["msgtype"] == "stream"
assert args[1]["stream"]["finish"] is True
assert args[1]["stream"]["content"] == "hello from reply"
# msgtype: stream triggers WeCom errcode 600039 on many mobile clients
# (unsupported type). Markdown renders everywhere.
assert args[1]["msgtype"] == "markdown"
assert args[1]["markdown"]["content"] == "hello from reply"
@pytest.mark.asyncio
async def test_send_image_file_uses_passive_reply_media_when_reply_context_exists(self):
@ -593,3 +594,193 @@ class TestInboundMessages:
await adapter._on_message(payload)
adapter.handle_message.assert_not_awaited()
class TestWeComZombieSessionFix:
    """Regression tests for PR #11572: device_id, markdown replies, and the
    per-chat req_id fallback used for group sends."""

    def test_adapter_generates_stable_device_id_per_instance(self):
        """Reading ``_device_id`` twice on one adapter yields the same string."""
        from gateway.platforms.wecom import WeComAdapter

        adapter = WeComAdapter(PlatformConfig(enabled=True))
        assert isinstance(adapter._device_id, str)
        assert len(adapter._device_id) > 0
        # Two reads on the same adapter must agree; only a brand-new adapter
        # instance may carry a different device_id (a fresh id on every
        # reconnect is the zombie-session problem being fixed).
        first_read, second_read = adapter._device_id, adapter._device_id
        assert first_read == second_read

    def test_different_adapter_instances_get_distinct_device_ids(self):
        """Two separately constructed adapters do not share a device_id."""
        from gateway.platforms.wecom import WeComAdapter

        left = WeComAdapter(PlatformConfig(enabled=True))
        right = WeComAdapter(PlatformConfig(enabled=True))
        assert left._device_id != right._device_id

    @pytest.mark.asyncio
    async def test_open_connection_includes_device_id_in_subscribe(self):
        """The single frame sent on connect is a subscribe carrying bot_id,
        secret and the adapter's device_id."""
        from gateway.platforms.wecom import APP_CMD_SUBSCRIBE, WeComAdapter

        adapter = WeComAdapter(PlatformConfig(enabled=True))
        adapter._bot_id = "test-bot"
        adapter._secret = "test-secret"
        frames = []

        class _StubSocket:
            closed = False

            async def send_json(self, payload):
                # Record every outbound frame for the assertions below.
                frames.append(payload)

            async def close(self):
                return None

        class _StubSession:
            def __init__(self, *args, **kwargs):
                pass

            async def ws_connect(self, *args, **kwargs):
                return _StubSocket()

            async def close(self):
                return None

        async def _noop_cleanup():
            return None

        async def _ok_handshake(req_id):
            return {"errcode": 0, "headers": {"req_id": req_id}}

        adapter._cleanup_ws = _noop_cleanup
        adapter._wait_for_handshake = _ok_handshake
        with patch("gateway.platforms.wecom.aiohttp.ClientSession", _StubSession):
            await adapter._open_connection()

        assert len(frames) == 1
        subscribe = frames[0]
        assert subscribe["cmd"] == APP_CMD_SUBSCRIBE
        body = subscribe["body"]
        assert body["bot_id"] == "test-bot"
        assert body["secret"] == "test-secret"
        assert body["device_id"] == adapter._device_id

    @pytest.mark.asyncio
    async def test_on_message_caches_last_req_id_per_chat(self):
        """An inbound group message stores its req_id under the chat id."""
        from gateway.platforms.wecom import WeComAdapter

        adapter = WeComAdapter(PlatformConfig(enabled=True))
        adapter._text_batch_delay_seconds = 0
        adapter.handle_message = AsyncMock()
        adapter._extract_media = AsyncMock(return_value=([], []))
        message_body = {
            "msgid": "msg-1",
            "chatid": "group-1",
            "chattype": "group",
            "from": {"userid": "user-1"},
            "msgtype": "text",
            "text": {"content": "hi"},
        }
        payload = {
            "cmd": "aibot_msg_callback",
            "headers": {"req_id": "req-abc"},
            "body": message_body,
        }

        await adapter._on_message(payload)

        assert adapter._last_chat_req_ids["group-1"] == "req-abc"

    @pytest.mark.asyncio
    async def test_on_message_does_not_cache_blocked_sender_req_id(self):
        """A chat rejected by the group allowlist leaves the req_id cache untouched."""
        from gateway.platforms.wecom import WeComAdapter

        allowlist_only = PlatformConfig(
            enabled=True,
            extra={"group_policy": "allowlist", "group_allow_from": ["group-ok"]},
        )
        adapter = WeComAdapter(allowlist_only)
        adapter.handle_message = AsyncMock()
        adapter._extract_media = AsyncMock(return_value=([], []))
        message_body = {
            "msgid": "msg-1",
            "chatid": "group-blocked",
            "chattype": "group",
            "from": {"userid": "user-1"},
            "msgtype": "text",
            "text": {"content": "hi"},
        }
        payload = {
            "cmd": "aibot_msg_callback",
            "headers": {"req_id": "req-abc"},
            "body": message_body,
        }

        await adapter._on_message(payload)

        adapter.handle_message.assert_not_awaited()
        assert "group-blocked" not in adapter._last_chat_req_ids

    def test_remember_chat_req_id_is_bounded(self):
        """The cache never exceeds DEDUP_MAX_SIZE and keeps the newest chat."""
        from gateway.platforms.wecom import DEDUP_MAX_SIZE, WeComAdapter

        adapter = WeComAdapter(PlatformConfig(enabled=True))
        total = DEDUP_MAX_SIZE + 50
        for idx in range(total):
            adapter._remember_chat_req_id(f"chat-{idx}", f"req-{idx}")

        assert len(adapter._last_chat_req_ids) <= DEDUP_MAX_SIZE
        # The entry remembered last has to still be there.
        newest = total - 1
        assert adapter._last_chat_req_ids[f"chat-{newest}"] == f"req-{newest}"

    def test_remember_chat_req_id_ignores_empty_values(self):
        """Empty or whitespace-only chat ids / req ids are not stored."""
        from gateway.platforms.wecom import WeComAdapter

        adapter = WeComAdapter(PlatformConfig(enabled=True))
        rejected_pairs = [("", "req-1"), ("chat-1", ""), (" ", " ")]
        for chat_id, req_id in rejected_pairs:
            adapter._remember_chat_req_id(chat_id, req_id)

        assert adapter._last_chat_req_ids == {}

    @pytest.mark.asyncio
    async def test_proactive_group_send_falls_back_to_cached_req_id(self):
        """A group send without reply_to reuses the last cached req_id via
        APP_CMD_RESPONSE — WeCom AI Bots cannot initiate APP_CMD_SEND in
        group chats (errcode 600039)."""
        from gateway.platforms.wecom import WeComAdapter

        adapter = WeComAdapter(PlatformConfig(enabled=True))
        adapter._last_chat_req_ids["group-1"] = "inbound-req-42"
        reply_ack = {"headers": {"req_id": "inbound-req-42"}, "errcode": 0}
        send_ack = {"headers": {"req_id": "new"}, "errcode": 0}
        adapter._send_reply_request = AsyncMock(return_value=reply_ack)
        adapter._send_request = AsyncMock(return_value=send_ack)

        result = await adapter.send("group-1", "ping", reply_to=None)

        assert result.success is True
        # The message has to travel the reply path, never the proactive one.
        adapter._send_reply_request.assert_awaited_once()
        adapter._send_request.assert_not_awaited()
        reply_args = adapter._send_reply_request.await_args.args
        assert reply_args[0] == "inbound-req-42"
        reply_body = reply_args[1]
        assert reply_body["msgtype"] == "markdown"
        assert reply_body["markdown"]["content"] == "ping"

    @pytest.mark.asyncio
    async def test_proactive_send_without_cached_req_id_uses_app_cmd_send(self):
        """With no cached req_id for the chat, the send goes out as APP_CMD_SEND."""
        from gateway.platforms.wecom import APP_CMD_SEND, WeComAdapter

        adapter = WeComAdapter(PlatformConfig(enabled=True))
        send_ack = {"headers": {"req_id": "new"}, "errcode": 0}
        adapter._send_request = AsyncMock(return_value=send_ack)

        result = await adapter.send("fresh-dm-chat", "ping", reply_to=None)

        assert result.success is True
        adapter._send_request.assert_awaited_once()
        sent_cmd = adapter._send_request.await_args.args[0]
        assert sent_cmd == APP_CMD_SEND

View file

@ -33,6 +33,7 @@ class TestProviderRegistry:
("huggingface", "Hugging Face", "api_key"),
("zai", "Z.AI / GLM", "api_key"),
("xai", "xAI", "api_key"),
("nvidia", "NVIDIA NIM", "api_key"),
("kimi-coding", "Kimi / Moonshot", "api_key"),
("minimax", "MiniMax", "api_key"),
("minimax-cn", "MiniMax (China)", "api_key"),
@ -57,6 +58,12 @@ class TestProviderRegistry:
assert pconfig.base_url_env_var == "XAI_BASE_URL"
assert pconfig.inference_base_url == "https://api.x.ai/v1"
def test_nvidia_env_vars(self):
    """The ``nvidia`` registry entry exposes the expected env vars and endpoint."""
    nvidia = PROVIDER_REGISTRY["nvidia"]
    expected_fields = (
        ("api_key_env_vars", ("NVIDIA_API_KEY",)),
        ("base_url_env_var", "NVIDIA_BASE_URL"),
        ("inference_base_url", "https://integrate.api.nvidia.com/v1"),
    )
    for field_name, expected in expected_fields:
        assert getattr(nvidia, field_name) == expected
def test_copilot_env_vars(self):
pconfig = PROVIDER_REGISTRY["copilot"]
assert pconfig.api_key_env_vars == ("COPILOT_GITHUB_TOKEN", "GH_TOKEN", "GITHUB_TOKEN")

View file

@ -141,13 +141,93 @@ def test_auth_add_nous_oauth_persists_pool_entry(tmp_path, monkeypatch):
auth_add_command(_Args())
payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
# Pool has exactly one canonical `device_code` entry — not a duplicate
# pair of `manual:device_code` + `device_code` (the latter would be
# materialised by _seed_from_singletons on every load_pool).
entries = payload["credential_pool"]["nous"]
entry = next(item for item in entries if item["source"] == "manual:device_code")
assert entry["label"] == "nous@example.com"
assert entry["source"] == "manual:device_code"
device_code_entries = [
item for item in entries if item["source"] == "device_code"
]
assert len(device_code_entries) == 1, entries
assert not any(item["source"] == "manual:device_code" for item in entries)
entry = device_code_entries[0]
assert entry["source"] == "device_code"
assert entry["agent_key"] == "ak-test"
assert entry["portal_base_url"] == "https://portal.example.com"
# `hermes auth add nous` must also populate providers.nous so the
# 401-recovery path (resolve_nous_runtime_credentials) can mint a fresh
# agent_key when the 24h TTL expires. If this mirror is missing, recovery
# raises "Hermes is not logged into Nous Portal" and the agent dies.
singleton = payload["providers"]["nous"]
assert singleton["access_token"] == token
assert singleton["refresh_token"] == "refresh-token"
assert singleton["agent_key"] == "ak-test"
assert singleton["portal_base_url"] == "https://portal.example.com"
assert singleton["inference_base_url"] == "https://inference.example.com/v1"
def test_auth_add_nous_oauth_honors_custom_label(tmp_path, monkeypatch):
    """``hermes auth add nous --type oauth --label <name>`` keeps the label.

    The label must show up both on the pool entry and in ``providers.nous``;
    an earlier cut of the persist helper dropped it because ``--label`` was
    never passed along.
    """
    hermes_home = tmp_path / "hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))
    _write_auth_store(tmp_path, {"version": 1, "providers": {}})
    token = _jwt_with_email("nous@example.com")

    def _fake_device_code_login(**kwargs):
        # Stand-in for the real device-code flow: a fresh state dict per call.
        return {
            "portal_base_url": "https://portal.example.com",
            "inference_base_url": "https://inference.example.com/v1",
            "client_id": "hermes-cli",
            "scope": "inference:mint_agent_key",
            "token_type": "Bearer",
            "access_token": token,
            "refresh_token": "refresh-token",
            "obtained_at": "2026-03-23T10:00:00+00:00",
            "expires_at": "2026-03-23T11:00:00+00:00",
            "expires_in": 3600,
            "agent_key": "ak-test",
            "agent_key_id": "ak-id",
            "agent_key_expires_at": "2026-03-23T10:30:00+00:00",
            "agent_key_expires_in": 1800,
            "agent_key_reused": False,
            "agent_key_obtained_at": "2026-03-23T10:00:10+00:00",
            "tls": {"insecure": False, "ca_bundle": None},
        }

    monkeypatch.setattr(
        "hermes_cli.auth._nous_device_code_login", _fake_device_code_login
    )

    from hermes_cli.auth_commands import auth_add_command

    class _Args:
        # Mimics the parsed CLI namespace; only ``label`` is non-default here.
        provider = "nous"
        auth_type = "oauth"
        api_key = None
        label = "my-nous"
        portal_url = None
        inference_url = None
        client_id = None
        scope = None
        no_browser = False
        timeout = None
        insecure = False
        ca_bundle = None

    auth_add_command(_Args())

    stored = json.loads((hermes_home / "auth.json").read_text())

    # The custom label is on the pool entry ...
    first_pool_entry = stored["credential_pool"]["nous"][0]
    assert first_pool_entry["source"] == "device_code"
    assert first_pool_entry["label"] == "my-nous"
    # ... and is also recorded under providers.nous.
    assert stored["providers"]["nous"]["label"] == "my-nous"
def test_auth_add_codex_oauth_persists_pool_entry(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))

View file

@ -456,3 +456,258 @@ class TestLoginNousSkipKeepsCurrent:
assert "nous" in auth_after.get("providers", {})
# =============================================================================
# persist_nous_credentials: shared helper for CLI + web dashboard login paths
# =============================================================================
def _full_state_fixture() -> dict:
"""Shape of the dict returned by _nous_device_code_login /
refresh_nous_oauth_from_state. Used as helper input."""
return {
"portal_base_url": "https://portal.example.com",
"inference_base_url": "https://inference.example.com/v1",
"client_id": "hermes-cli",
"scope": "inference:mint_agent_key",
"token_type": "Bearer",
"access_token": "access-tok",
"refresh_token": "refresh-tok",
"obtained_at": "2026-04-17T22:00:00+00:00",
"expires_at": "2026-04-17T22:15:00+00:00",
"expires_in": 900,
"agent_key": "agent-key-value",
"agent_key_id": "ak-id",
"agent_key_expires_at": "2026-04-18T22:00:00+00:00",
"agent_key_expires_in": 86400,
"agent_key_reused": False,
"agent_key_obtained_at": "2026-04-17T22:00:10+00:00",
"tls": {"insecure": False, "ca_bundle": None},
}
def test_persist_nous_credentials_writes_both_pool_and_providers(tmp_path, monkeypatch):
    """The helper fills in ``credential_pool.nous`` AND ``providers.nous``.

    Regression guard: ``hermes auth add nous`` used to write only the pool.
    Once the agent_key's 24h TTL ran out, the 401-recovery path read
    ``providers.nous``, found nothing, and the agent failed with
    "Non-retryable client error". Both stores have to be written together.
    """
    from hermes_cli.auth import persist_nous_credentials, NOUS_DEVICE_CODE_SOURCE

    home_dir = tmp_path / "hermes"
    home_dir.mkdir(parents=True, exist_ok=True)
    auth_file = home_dir / "auth.json"
    auth_file.write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    entry = persist_nous_credentials(_full_state_fixture())

    assert entry is not None
    assert entry.provider == "nous"
    assert entry.source == NOUS_DEVICE_CODE_SOURCE

    stored = json.loads(auth_file.read_text())

    # providers.nous carries the token state.
    provider_state = stored["providers"]["nous"]
    assert provider_state["access_token"] == "access-tok"
    assert provider_state["refresh_token"] == "refresh-tok"
    assert provider_state["agent_key"] == "agent-key-value"
    assert provider_state["agent_key_expires_at"] == "2026-04-18T22:00:00+00:00"

    # The pool holds a single entry under the canonical device_code source.
    pool_rows = stored["credential_pool"]["nous"]
    assert len(pool_rows) == 1, pool_rows
    only_row = pool_rows[0]
    assert only_row["source"] == NOUS_DEVICE_CODE_SOURCE
    assert only_row["agent_key"] == "agent-key-value"
    assert only_row["inference_base_url"] == "https://inference.example.com/v1"
def test_persist_nous_credentials_allows_recovery_from_401(tmp_path, monkeypatch):
    """After persisting, ``resolve_nous_runtime_credentials`` can mint a key.

    This mirrors what the agent does after a Nous 401. Before the fix the
    call raised "Hermes is not logged into Nous Portal" because
    ``providers.nous`` had never been written.
    """
    from hermes_cli.auth import persist_nous_credentials, resolve_nous_runtime_credentials

    home_dir = tmp_path / "hermes"
    home_dir.mkdir(parents=True, exist_ok=True)
    auth_file = home_dir / "auth.json"
    auth_file.write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    persist_nous_credentials(_full_state_fixture())

    # Both network-facing steps are replaced with canned responses; what is
    # under test is that the stored state is found, not the portal itself.
    def _fake_refresh_access_token(*, client, portal_base_url, client_id, refresh_token):
        refreshed = {
            "access_token": "access-new",
            "refresh_token": "refresh-new",
            "expires_in": 900,
            "token_type": "Bearer",
        }
        return refreshed

    def _fake_mint_agent_key(*, client, portal_base_url, access_token, min_ttl_seconds):
        return _mint_payload(api_key="new-agent-key")

    replacements = {
        "hermes_cli.auth._refresh_access_token": _fake_refresh_access_token,
        "hermes_cli.auth._mint_agent_key": _fake_mint_agent_key,
    }
    for target, stub in replacements.items():
        monkeypatch.setattr(target, stub)

    creds = resolve_nous_runtime_credentials(min_key_ttl_seconds=300, force_mint=True)
    assert creds["api_key"] == "new-agent-key"
def test_persist_nous_credentials_idempotent_no_duplicate_pool_entries(tmp_path, monkeypatch):
    """Persisting twice updates the single pool row instead of adding another.

    Regression guard for the review comment on PR #11858: the helper used to
    write ``manual:device_code`` while ``_seed_from_singletons`` wrote
    ``device_code``, so a duplicate row appeared on every ``load_pool()``.
    Two persists must leave exactly one row under the canonical source.
    """
    from hermes_cli.auth import persist_nous_credentials, NOUS_DEVICE_CODE_SOURCE

    home_dir = tmp_path / "hermes"
    home_dir.mkdir(parents=True, exist_ok=True)
    auth_file = home_dir / "auth.json"
    auth_file.write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    persist_nous_credentials(_full_state_fixture())

    updated_state = _full_state_fixture()
    updated_state.update(
        {"access_token": "access-second", "agent_key": "agent-key-second"}
    )
    persist_nous_credentials(updated_state)

    stored = json.loads(auth_file.read_text())

    # providers.nous shows the most recent write.
    provider_state = stored["providers"]["nous"]
    assert provider_state["access_token"] == "access-second"
    assert provider_state["agent_key"] == "agent-key-second"

    # Still exactly one pool row, and it carries the newest agent_key.
    pool_entries = stored["credential_pool"]["nous"]
    assert len(pool_entries) == 1, pool_entries
    assert pool_entries[0]["source"] == NOUS_DEVICE_CODE_SOURCE
    assert pool_entries[0]["agent_key"] == "agent-key-second"

    # No leftover rows from any `manual:` source.
    assert all(not row["source"].startswith("manual:") for row in pool_entries)
def test_persist_nous_credentials_reloads_pool_after_singleton_write(tmp_path, monkeypatch):
    """The helper returns a populated entry under the canonical source.

    The returned entry is expected to reflect the seeded pool state, so it
    must carry the access token and agent key that were just persisted.
    """
    from hermes_cli.auth import persist_nous_credentials, NOUS_DEVICE_CODE_SOURCE

    home_dir = tmp_path / "hermes"
    home_dir.mkdir(parents=True, exist_ok=True)
    auth_file = home_dir / "auth.json"
    auth_file.write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    returned = persist_nous_credentials(_full_state_fixture())

    assert returned is not None
    assert returned.source == NOUS_DEVICE_CODE_SOURCE
    # The label is auto-derived here, so its exact value is not pinned;
    # checking the credentials is enough to prove a real entry came back.
    assert returned.access_token == "access-tok"
    assert returned.agent_key == "agent-key-value"
def test_persist_nous_credentials_embeds_custom_label(tmp_path, monkeypatch):
    """A caller-supplied label reaches both the pool entry and ``providers.nous``.

    ``hermes auth add nous --type oauth --label <name>`` used to lose the
    label: the helper ignored it and seeding always auto-derived one. The
    label is now stored inside ``providers.nous`` as well.
    """
    from hermes_cli.auth import persist_nous_credentials, NOUS_DEVICE_CODE_SOURCE

    home_dir = tmp_path / "hermes"
    home_dir.mkdir(parents=True, exist_ok=True)
    auth_file = home_dir / "auth.json"
    auth_file.write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    returned = persist_nous_credentials(_full_state_fixture(), label="my-personal")

    assert returned is not None
    assert returned.source == NOUS_DEVICE_CODE_SOURCE
    assert returned.label == "my-personal"

    # The label is written into providers.nous on disk too.
    stored = json.loads(auth_file.read_text())
    assert stored["providers"]["nous"]["label"] == "my-personal"
def test_persist_nous_credentials_custom_label_survives_reseed(tmp_path, monkeypatch):
    """Loading the pool again after persisting keeps the user-chosen label."""
    from hermes_cli.auth import persist_nous_credentials
    from agent.credential_pool import load_pool

    home_dir = tmp_path / "hermes"
    home_dir.mkdir(parents=True, exist_ok=True)
    auth_file = home_dir / "auth.json"
    auth_file.write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    persist_nous_credentials(_full_state_fixture(), label="work-acct")

    # A second load re-runs seeding; before the fix that replaced the label
    # with the one derived from the access token.
    reloaded_entries = load_pool("nous").entries()
    assert len(reloaded_entries) == 1
    assert reloaded_entries[0].label == "work-acct"
def test_persist_nous_credentials_no_label_uses_auto_derived(tmp_path, monkeypatch):
    """Without a ``label`` argument the entry still gets a non-empty label.

    Guards the unchanged default: some label is derived automatically, and
    no ``label`` key is written into ``providers.nous``.
    """
    from hermes_cli.auth import persist_nous_credentials

    home_dir = tmp_path / "hermes"
    home_dir.mkdir(parents=True, exist_ok=True)
    auth_file = home_dir / "auth.json"
    auth_file.write_text(json.dumps({"version": 1, "providers": {}}))
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    returned = persist_nous_credentials(_full_state_fixture())

    assert returned is not None
    # The exact derived value is not pinned: it only has to be non-empty
    # and must not be a user string that was never supplied.
    assert returned.label
    assert returned.label != "my-personal"

    # Nothing was passed, so providers.nous must not contain a "label" key.
    stored = json.loads(auth_file.read_text())
    assert "label" not in stored["providers"]["nous"]

View file

@ -0,0 +1,294 @@
"""Tests for the auxiliary-model configuration UI in ``hermes model``.
Covers the helper functions:
- ``_save_aux_choice`` writes to config.yaml without touching main model config
- ``_reset_aux_to_auto`` clears routing fields but preserves timeouts
- ``_format_aux_current`` renders current task config for the menu
- ``_AUX_TASKS`` stays in sync with ``DEFAULT_CONFIG["auxiliary"]``
These are pure-function tests; the interactive menu loops are not covered
here (they're stdin-driven curses prompts).
"""
from __future__ import annotations
import pytest
from hermes_cli.config import DEFAULT_CONFIG, load_config
from hermes_cli.main import (
_AUX_TASKS,
_format_aux_current,
_reset_aux_to_auto,
_save_aux_choice,
)
# ── Default config ──────────────────────────────────────────────────────────
def test_title_generation_present_in_default_config():
    """``title_generation`` has an entry under ``DEFAULT_CONFIG["auxiliary"]``.

    Regression for a gap: the title generator calls the LLM with
    ``task="title_generation"``, but the task was absent from the defaults,
    so config-backed timeout/provider overrides never applied to it.
    """
    auxiliary = DEFAULT_CONFIG["auxiliary"]
    assert "title_generation" in auxiliary

    title_cfg = auxiliary["title_generation"]
    assert title_cfg["provider"] == "auto"
    assert title_cfg["model"] == ""
    assert title_cfg["timeout"] > 0
def test_aux_tasks_keys_all_exist_in_default_config():
    """Each task key listed in ``_AUX_TASKS`` is present in the default config."""
    menu_keys = {key for key, _label, _blurb in _AUX_TASKS}
    missing = menu_keys.difference(DEFAULT_CONFIG["auxiliary"].keys())
    assert not missing, (
        f"_AUX_TASKS references tasks not in DEFAULT_CONFIG.auxiliary: {missing}"
    )
# ── _format_aux_current ─────────────────────────────────────────────────────
@pytest.mark.parametrize(
    ("task_cfg", "expected"),
    [
        # No provider, or an explicit "auto", renders as plain "auto".
        ({}, "auto"),
        ({"provider": "", "model": ""}, "auto"),
        ({"provider": "auto", "model": ""}, "auto"),
        ({"provider": "auto", "model": "gpt-4o"}, "auto · gpt-4o"),
        # Named providers, with and without a model.
        ({"provider": "openrouter", "model": ""}, "openrouter"),
        (
            {"provider": "openrouter", "model": "google/gemini-2.5-flash"},
            "openrouter · google/gemini-2.5-flash",
        ),
        ({"provider": "nous", "model": "gemini-3-flash"}, "nous · gemini-3-flash"),
        # Custom endpoints show the base URL without scheme or trailing slash.
        (
            {"provider": "custom", "base_url": "http://localhost:11434/v1", "model": ""},
            "custom (localhost:11434/v1)",
        ),
        (
            {
                "provider": "custom",
                "base_url": "http://localhost:11434/v1/",
                "model": "qwen2.5:32b",
            },
            "custom (localhost:11434/v1) · qwen2.5:32b",
        ),
    ],
)
def test_format_aux_current(task_cfg, expected):
    """``_format_aux_current`` renders each task config as the expected label."""
    rendered = _format_aux_current(task_cfg)
    assert rendered == expected
def test_format_aux_current_handles_non_dict():
    """Values that are not dicts are rendered as "auto"."""
    for not_a_dict in (None, "string"):
        assert _format_aux_current(not_a_dict) == "auto"
# ── _save_aux_choice ────────────────────────────────────────────────────────
def test_save_aux_choice_persists_to_config_yaml(tmp_path, monkeypatch):
    """A saved choice is readable back from ``auxiliary.<task>``.

    Provider and model hold the given values; ``base_url`` and ``api_key``
    come back as empty strings when they were not supplied.
    """
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    _save_aux_choice(
        "vision",
        provider="openrouter",
        model="google/gemini-2.5-flash",
    )

    vision = load_config()["auxiliary"]["vision"]
    assert vision["provider"] == "openrouter"
    assert vision["model"] == "google/gemini-2.5-flash"
    assert vision["base_url"] == ""
    assert vision["api_key"] == ""
def test_save_aux_choice_preserves_timeout(tmp_path, monkeypatch):
    """Saving a provider/model choice leaves the task's timeouts as they were."""
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    # The vision task starts out with a 120-second timeout.
    default_timeout = load_config()["auxiliary"]["vision"]["timeout"]
    assert default_timeout == 120

    _save_aux_choice("vision", provider="nous", model="gemini-3-flash")

    vision_after = load_config()["auxiliary"]["vision"]
    assert vision_after["timeout"] == default_timeout
    # The vision-specific download_timeout is kept as well.
    assert vision_after.get("download_timeout") == 30
def test_save_aux_choice_does_not_touch_main_model(tmp_path, monkeypatch):
    """Saving an auxiliary choice leaves the main ``model`` section alone."""
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    # Start from a config that already has a main model set up.
    from hermes_cli.config import save_config

    initial = load_config()
    initial["model"] = {
        "default": "claude-sonnet-4.6",
        "provider": "anthropic",
        "base_url": "",
    }
    save_config(initial)

    _save_aux_choice(
        "compression",
        provider="custom",
        base_url="http://localhost:11434/v1",
        model="qwen2.5:32b",
    )

    reloaded = load_config()
    # The main model section is exactly as it was saved.
    main_model = reloaded["model"]
    assert main_model["default"] == "claude-sonnet-4.6"
    assert main_model["provider"] == "anthropic"
    # The auxiliary task holds the new routing.
    compression = reloaded["auxiliary"]["compression"]
    assert compression["provider"] == "custom"
    assert compression["model"] == "qwen2.5:32b"
    assert compression["base_url"] == "http://localhost:11434/v1"
def test_save_aux_choice_creates_missing_task_entry(tmp_path, monkeypatch):
    """A task entry removed from the saved config is readable again after a save."""
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    # Drop the vision entry from the config before saving a choice for it.
    from hermes_cli.config import save_config

    stripped = load_config()
    stripped.setdefault("auxiliary", {}).pop("vision", None)
    save_config(stripped)

    _save_aux_choice("vision", provider="nous", model="gemini-3-flash")

    vision = load_config()["auxiliary"]["vision"]
    assert vision["provider"] == "nous"
    assert vision["model"] == "gemini-3-flash"
# ── _reset_aux_to_auto ──────────────────────────────────────────────────────
def test_reset_aux_to_auto_clears_routing_preserves_timeouts(tmp_path, monkeypatch):
    """Reset returns routing fields to auto/empty and keeps timeout values."""
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    # Point two tasks at explicit providers, then raise one timeout by hand.
    _save_aux_choice("vision", provider="openrouter", model="gpt-4o")
    _save_aux_choice("compression", provider="nous", model="gemini-3-flash")
    from hermes_cli.config import save_config

    tuned = load_config()
    tuned["auxiliary"]["vision"]["timeout"] = 300  # user-tuned
    save_config(tuned)

    changed = _reset_aux_to_auto()
    assert changed == 2  # both tasks were non-auto

    after_reset = load_config()["auxiliary"]
    for task_name in ("vision", "compression"):
        routing = after_reset[task_name]
        assert routing["provider"] == "auto"
        assert routing["model"] == ""
        assert routing["base_url"] == ""
        assert routing["api_key"] == ""

    # The hand-tuned vision timeout is still in place ...
    assert after_reset["vision"]["timeout"] == 300
    # ... and compression keeps its default timeout.
    assert after_reset["compression"]["timeout"] == 120
def test_reset_aux_to_auto_idempotent(tmp_path, monkeypatch):
    """Reset reports 0 when nothing is configured and 1 after a single change."""
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    # Nothing to reset on a pristine config.
    nothing_changed = _reset_aux_to_auto()
    assert nothing_changed == 0

    _save_aux_choice("vision", provider="nous", model="gemini-3-flash")
    one_changed = _reset_aux_to_auto()
    assert one_changed == 1

    # Running it again right away finds nothing left to change.
    repeat = _reset_aux_to_auto()
    assert repeat == 0
# ── Menu dispatch ───────────────────────────────────────────────────────────
def test_select_provider_and_model_dispatches_to_aux_menu(tmp_path, monkeypatch):
    """Choosing the 'Configure auxiliary models' entry calls ``_aux_config_menu``."""
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    from hermes_cli import main as main_mod

    called = {"aux": 0, "flow": 0}

    def fake_prompt(choices, *, default=0):
        # Select the aux-config entry by label text, wherever it is listed.
        position = next(
            (
                i
                for i, label in enumerate(choices)
                if "Configure auxiliary models" in label
            ),
            None,
        )
        if position is None:
            raise AssertionError("aux entry not in provider list")
        return position

    def count_aux_menu():
        called["aux"] = called["aux"] + 1

    def count_openrouter_flow(*a, **kw):
        called["flow"] = called["flow"] + 1

    monkeypatch.setattr(main_mod, "_prompt_provider_choice", fake_prompt)
    monkeypatch.setattr(main_mod, "_aux_config_menu", count_aux_menu)
    # A main provider flow running here would show up in the "flow" counter.
    monkeypatch.setattr(main_mod, "_model_flow_openrouter", count_openrouter_flow)

    main_mod.select_provider_and_model()

    assert called["aux"] == 1, "aux menu not invoked"
    assert called["flow"] == 0, "main provider flow should not run"
def test_leave_unchanged_replaces_cancel_label(tmp_path, monkeypatch):
    """The provider menu offers 'Leave unchanged' and no 'Cancel' entry."""
    from pathlib import Path

    hermes_dir = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    hermes_dir.mkdir(exist_ok=True)

    from hermes_cli import main as main_mod

    captured: list[list[str]] = []

    def fake_prompt(choices, *, default=0):
        shown = list(choices)
        captured.append(shown)
        # Choose 'Leave unchanged' so the flow ends without side effects.
        if "Leave unchanged" not in shown:
            raise AssertionError("Leave unchanged not in provider list")
        return shown.index("Leave unchanged")

    monkeypatch.setattr(main_mod, "_prompt_provider_choice", fake_prompt)

    main_mod.select_provider_and_model()

    assert captured, "provider menu never rendered"
    labels = captured[0]
    assert "Leave unchanged" in labels
    assert "Cancel" not in labels, "Cancel label should be replaced"
    assert any("Configure auxiliary models" in entry for entry in labels)

View file

@ -106,6 +106,49 @@ class TestCmdUpdateBranchFallback:
pull_cmds = [c for c in commands if "pull" in c]
assert len(pull_cmds) == 0
@patch("shutil.which")
@patch("subprocess.run")
def test_update_refreshes_repo_and_tui_node_dependencies(
    self, mock_run, mock_which, mock_args
):
    """``cmd_update`` runs the same quiet ``npm install`` in the project
    root and then in ``ui-tui``."""
    tool_paths = {"uv": "/usr/bin/uv", "npm": "/usr/bin/npm"}
    mock_which.side_effect = tool_paths.get
    mock_run.side_effect = _make_run_side_effect(
        branch="main", verify_ok=True, commit_count="1"
    )

    cmd_update(mock_args)

    # Keep only the subprocess invocations whose executable is npm.
    npm_calls = []
    for recorded in mock_run.call_args_list:
        if recorded.args and recorded.args[0][0] == "/usr/bin/npm":
            npm_calls.append((recorded.args[0], recorded.kwargs.get("cwd")))

    quiet_install = [
        "/usr/bin/npm",
        "install",
        "--silent",
        "--no-fund",
        "--no-audit",
        "--progress=false",
    ]
    expected_calls = [
        (quiet_install, PROJECT_ROOT),
        (quiet_install, PROJECT_ROOT / "ui-tui"),
    ]
    assert npm_calls == expected_calls
def test_update_non_interactive_skips_migration_prompt(self, mock_args, capsys):
"""When stdin/stdout aren't TTYs, config migration prompt is skipped."""
with patch("shutil.which", return_value=None), patch(

View file

@ -93,6 +93,8 @@ class TestResolveCommand:
def test_canonical_name_resolves(self):
    """Each canonical command name resolves to a command carrying that same name."""
    for canonical in ("help", "background", "copy", "agents"):
        assert resolve_command(canonical).name == canonical
def test_alias_resolves_to_canonical(self):
assert resolve_command("bg").name == "background"
@ -102,6 +104,7 @@ class TestResolveCommand:
assert resolve_command("gateway").name == "platforms"
assert resolve_command("set-home").name == "sethome"
assert resolve_command("reload_mcp").name == "reload-mcp"
assert resolve_command("tasks").name == "agents"
def test_leading_slash_stripped(self):
assert resolve_command("/help").name == "help"

View file

@ -0,0 +1,169 @@
import textwrap
from hermes_cli.config import load_config, save_config
def _write_config(tmp_path, body: str):
(tmp_path / "config.yaml").write_text(textwrap.dedent(body), encoding="utf-8")
def _read_config(tmp_path) -> str:
return (tmp_path / "config.yaml").read_text(encoding="utf-8")
# Regression test: after changing an unrelated key and saving, the file must
# still contain the original ``${VAR}`` templates and none of the environment
# values those variables currently hold.
def test_save_config_preserves_env_refs_on_unrelated_change(monkeypatch, tmp_path):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
# Both referenced variables are set, so a leaked expansion would be visible
# in the saved file as one of these literal values.
monkeypatch.setenv("TU_ZI_API_KEY", "sk-realsecret")
monkeypatch.setenv("ALT_SECRET", "alt-secret")
_write_config(
tmp_path,
"""\
custom_providers:
- name: tuzi
base_url: https://api.tu-zi.com
api_key: ${TU_ZI_API_KEY}
headers:
Authorization: Bearer ${ALT_SECRET}
model: claude-opus-4-6
model:
default: claude-opus-4-6
""",
)
config = load_config()
# Only the default model is edited; the provider entry is left as loaded.
config["model"]["default"] = "doubao-pro"
save_config(config)
saved = _read_config(tmp_path)
# Covers a whole-value template (api_key) and one embedded in a longer
# string (Authorization header).
assert "api_key: ${TU_ZI_API_KEY}" in saved
assert "Authorization: Bearer ${ALT_SECRET}" in saved
assert "sk-realsecret" not in saved
assert "alt-secret" not in saved
# A ``${VAR}`` reference whose variable is absent from the environment must be
# written back unchanged after an unrelated edit.
def test_save_config_preserves_unresolved_env_refs(monkeypatch, tmp_path):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
# Make sure the referenced variable is not set while the test runs.
monkeypatch.delenv("MISSING_SECRET", raising=False)
_write_config(
tmp_path,
"""\
custom_providers:
- name: unresolved
api_key: ${MISSING_SECRET}
model: claude-opus-4-6
model:
default: claude-opus-4-6
""",
)
config = load_config()
# NOTE(review): ``display`` is not in the file written above, so this relies
# on load_config() supplying that section (presumably from defaults) — confirm.
config["display"]["compact"] = True
save_config(config)
assert "api_key: ${MISSING_SECRET}" in _read_config(tmp_path)
# When the caller overwrites a templated value with a new literal, the new
# literal is what gets saved and the old ``${VAR}`` template is gone.
def test_save_config_allows_intentional_secret_value_change(monkeypatch, tmp_path):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
monkeypatch.setenv("TU_ZI_API_KEY", "sk-old-secret")
_write_config(
tmp_path,
"""\
custom_providers:
- name: tuzi
api_key: ${TU_ZI_API_KEY}
model: claude-opus-4-6
model:
default: claude-opus-4-6
""",
)
config = load_config()
# Deliberate edit of the templated field itself (differs from the env value).
config["custom_providers"][0]["api_key"] = "sk-new-secret"
save_config(config)
saved = _read_config(tmp_path)
assert "api_key: sk-new-secret" in saved
assert "${TU_ZI_API_KEY}" not in saved
# If the environment variable changes between load_config() and save_config(),
# the template must still be written back, with neither the old nor the new
# variable value appearing in the file.
def test_save_config_preserves_template_when_env_rotates_after_load(monkeypatch, tmp_path):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
monkeypatch.setenv("TU_ZI_API_KEY", "sk-old-secret")
_write_config(
tmp_path,
"""\
custom_providers:
- name: tuzi
api_key: ${TU_ZI_API_KEY}
model: claude-opus-4-6
model:
default: claude-opus-4-6
""",
)
config = load_config()
# Rotate the secret after loading, before saving.
monkeypatch.setenv("TU_ZI_API_KEY", "sk-rotated-secret")
# The only edit made to the loaded config is to an unrelated key.
config["model"]["default"] = "doubao-pro"
save_config(config)
saved = _read_config(tmp_path)
assert "api_key: ${TU_ZI_API_KEY}" in saved
assert "sk-old-secret" not in saved
assert "sk-rotated-secret" not in saved
# A value that originally embedded a template inside a longer string
# ("Bearer ${ALT_SECRET}") and is then edited by the caller must be saved as the
# edited literal, not restored to the template.
def test_save_config_keeps_edited_partial_template_strings_literal(monkeypatch, tmp_path):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
monkeypatch.setenv("ALT_SECRET", "alt-secret")
_write_config(
tmp_path,
"""\
custom_providers:
- name: tuzi
headers:
Authorization: Bearer ${ALT_SECRET}
model: claude-opus-4-6
model:
default: claude-opus-4-6
""",
)
config = load_config()
# The new value contains the env value "alt-secret" but under a different
# prefix ("Token" instead of "Bearer").
config["custom_providers"][0]["headers"]["Authorization"] = "Token alt-secret"
save_config(config)
saved = _read_config(tmp_path)
assert "Authorization: Token alt-secret" in saved
assert "Authorization: Bearer ${ALT_SECRET}" not in saved
# Two providers share the same ``name`` and reference different variables;
# after an unrelated edit both entries must keep their own template.
# NOTE(review): the test name implies save_config pairs entries by list
# position when names are ambiguous — the matching itself is not visible here.
def test_save_config_falls_back_to_positional_matching_for_duplicate_names(monkeypatch, tmp_path):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
monkeypatch.setenv("FIRST_SECRET", "first-secret")
monkeypatch.setenv("SECOND_SECRET", "second-secret")
_write_config(
tmp_path,
"""\
custom_providers:
- name: duplicate
api_key: ${FIRST_SECRET}
model: claude-opus-4-6
- name: duplicate
api_key: ${SECOND_SECRET}
model: doubao-pro
model:
default: claude-opus-4-6
""",
)
config = load_config()
# Unrelated edit; the provider list is left as loaded.
config["display"]["compact"] = True
save_config(config)
saved = _read_config(tmp_path)
# Both entries are still present, each with its own template and no leaked value.
assert saved.count("name: duplicate") == 2
assert "api_key: ${FIRST_SECRET}" in saved
assert "api_key: ${SECOND_SECRET}" in saved
assert "first-secret" not in saved
assert "second-secret" not in saved

View file

@ -501,40 +501,272 @@ class TestDeletePaste:
class TestScheduleAutoDelete:
def test_spawns_detached_process(self):
"""``_schedule_auto_delete`` used to spawn a detached Python subprocess
per call (one per paste URL batch). Those subprocesses slept 6 hours
and accumulated forever under repeated use 15+ orphaned interpreters
were observed in production.
The new implementation is stateless: it records pending deletions to
``~/.hermes/pastes/pending.json`` and lets ``_sweep_expired_pastes``
handle the DELETE requests synchronously on the next ``hermes debug``
invocation.
"""
def test_does_not_spawn_subprocess(self, hermes_home):
"""Regression guard: _schedule_auto_delete must NEVER spawn subprocesses.
We assert this structurally rather than by mocking Popen: the new
implementation doesn't even import ``subprocess`` at module scope,
so a mock patch wouldn't find it.
"""
import ast
import inspect
from hermes_cli.debug import _schedule_auto_delete
with patch("subprocess.Popen") as mock_popen:
_schedule_auto_delete(
["https://paste.rs/abc", "https://paste.rs/def"],
delay_seconds=10,
)
# Strip the docstring before scanning so the regression-rationale
# prose inside it doesn't trigger our banned-word checks.
source = inspect.getsource(_schedule_auto_delete)
tree = ast.parse(source)
func_node = tree.body[0]
if (
func_node.body
and isinstance(func_node.body[0], ast.Expr)
and isinstance(func_node.body[0].value, ast.Constant)
and isinstance(func_node.body[0].value.value, str)
):
func_node.body = func_node.body[1:]
code_only = ast.unparse(func_node)
mock_popen.assert_called_once()
call_args = mock_popen.call_args
# Verify detached
assert call_args[1]["start_new_session"] is True
# Verify the script references both URLs
script = call_args[0][0][2] # [python, -c, script]
assert "paste.rs/abc" in script
assert "paste.rs/def" in script
assert "time.sleep(10)" in script
assert "Popen" not in code_only, (
"_schedule_auto_delete must not spawn subprocesses — "
"use pending.json + _sweep_expired_pastes instead"
)
assert "subprocess" not in code_only, (
"_schedule_auto_delete must not reference subprocess at all"
)
assert "time.sleep" not in code_only, (
"Regression: sleeping in _schedule_auto_delete is the bug being fixed"
)
def test_skips_non_paste_rs_urls(self):
from hermes_cli.debug import _schedule_auto_delete
# And verify that calling it doesn't produce any orphaned children
# (it should just write pending.json synchronously).
import os as _os
before = set(_os.listdir("/proc")) if _os.path.exists("/proc") else None
_schedule_auto_delete(
["https://paste.rs/abc", "https://paste.rs/def"],
delay_seconds=10,
)
if before is not None:
after = set(_os.listdir("/proc"))
new = after - before
# Filter to only integer-named entries (process PIDs)
new_pids = [p for p in new if p.isdigit()]
# It's fine if unrelated processes appeared — we just need to make
# sure we didn't spawn a long-sleeping one. The old bug spawned
# a python interpreter whose cmdline contained "time.sleep".
for pid in new_pids:
try:
with open(f"/proc/{pid}/cmdline", "rb") as f:
cmdline = f.read().decode("utf-8", errors="replace")
assert "time.sleep" not in cmdline, (
f"Leaked sleeper subprocess PID {pid}: {cmdline}"
)
except OSError:
pass # process exited already
with patch("subprocess.Popen") as mock_popen:
_schedule_auto_delete(["https://dpaste.com/something"])
def test_records_pending_to_json(self, hermes_home):
"""Scheduled URLs are persisted to pending.json with expiration."""
from hermes_cli.debug import _schedule_auto_delete, _pending_file
import json
mock_popen.assert_not_called()
_schedule_auto_delete(
["https://paste.rs/abc", "https://paste.rs/def"],
delay_seconds=10,
)
def test_handles_popen_failure_gracefully(self):
from hermes_cli.debug import _schedule_auto_delete
pending_path = _pending_file()
assert pending_path.exists()
with patch("subprocess.Popen",
side_effect=OSError("no such file")):
# Should not raise
_schedule_auto_delete(["https://paste.rs/abc"])
entries = json.loads(pending_path.read_text())
assert len(entries) == 2
urls = {e["url"] for e in entries}
assert urls == {"https://paste.rs/abc", "https://paste.rs/def"}
# expire_at is ~now + delay_seconds
import time
for e in entries:
assert e["expire_at"] > time.time()
assert e["expire_at"] <= time.time() + 15
def test_skips_non_paste_rs_urls(self, hermes_home):
    """dpaste.com URLs auto-expire, so scheduling one must not create pending.json."""
    from hermes_cli.debug import _pending_file, _schedule_auto_delete

    foreign_urls = ["https://dpaste.com/something"]
    _schedule_auto_delete(foreign_urls)

    # pending.json should not be created for non-paste.rs URLs
    pending_path = _pending_file()
    assert not pending_path.exists()
def test_merges_with_existing_pending(self, hermes_home):
    """A second scheduling call adds to pending.json rather than replacing it."""
    from hermes_cli.debug import _load_pending, _schedule_auto_delete

    for slug in ("first", "second"):
        _schedule_auto_delete([f"https://paste.rs/{slug}"], delay_seconds=10)

    recorded_urls = set()
    for entry in _load_pending():
        recorded_urls.add(entry["url"])
    assert recorded_urls == {"https://paste.rs/first", "https://paste.rs/second"}
def test_dedupes_same_url(self, hermes_home):
    """Same URL recorded twice → one entry with the later expire_at.

    The URL is scheduled first with a 10 s delay and then with a 100 s delay.
    Exactly one entry must remain, and its deadline must be the later one.
    """
    import time

    from hermes_cli.debug import _schedule_auto_delete, _load_pending

    _schedule_auto_delete(["https://paste.rs/dup"], delay_seconds=10)
    _schedule_auto_delete(["https://paste.rs/dup"], delay_seconds=100)

    entries = _load_pending()
    assert len(entries) == 1
    assert entries[0]["url"] == "https://paste.rs/dup"
    # Previously unasserted: the surviving deadline must come from the 100 s
    # call (~now + 100), not the 10 s one (~now + 10). The 50 s threshold
    # separates the two with plenty of slack for test runtime.
    assert entries[0]["expire_at"] > time.time() + 50
class TestSweepExpiredPastes:
    """Test the opportunistic sweep that replaces the sleeping subprocess."""

    def test_sweep_empty_is_noop(self, hermes_home):
        """With nothing pending, the sweep reports zero deleted and zero remaining."""
        from hermes_cli.debug import _sweep_expired_pastes

        deleted, remaining = _sweep_expired_pastes()

        assert (deleted, remaining) == (0, 0)

    def test_sweep_deletes_expired_entries(self, hermes_home):
        """Only the already-expired entry is deleted; the future one is kept."""
        import time

        from hermes_cli.debug import (
            _load_pending,
            _save_pending,
            _sweep_expired_pastes,
        )

        # Seed pending.json with one expired + one future entry
        expired_entry = {"url": "https://paste.rs/expired", "expire_at": time.time() - 100}
        future_entry = {"url": "https://paste.rs/future", "expire_at": time.time() + 3600}
        _save_pending([expired_entry, future_entry])

        delete_calls = []

        def fake_delete(url):
            # Record each URL the sweep tries to delete and report success.
            delete_calls.append(url)
            return True

        with patch("hermes_cli.debug.delete_paste", side_effect=fake_delete):
            deleted, remaining = _sweep_expired_pastes()

        assert delete_calls == ["https://paste.rs/expired"]
        assert deleted == 1
        assert remaining == 1

        surviving_urls = {entry["url"] for entry in _load_pending()}
        assert surviving_urls == {"https://paste.rs/future"}

    def test_sweep_leaves_future_entries_alone(self, hermes_home):
        """Entries whose expiry is still ahead trigger no delete call."""
        import time

        from hermes_cli.debug import _save_pending, _sweep_expired_pastes

        pending = [
            {"url": "https://paste.rs/future1", "expire_at": time.time() + 3600},
            {"url": "https://paste.rs/future2", "expire_at": time.time() + 7200},
        ]
        _save_pending(pending)

        with patch("hermes_cli.debug.delete_paste") as mock_delete:
            deleted, remaining = _sweep_expired_pastes()

        mock_delete.assert_not_called()
        assert deleted == 0
        assert remaining == 2

    def test_sweep_survives_network_failure(self, hermes_home):
        """Failed DELETEs stay in pending.json until the 24h grace window."""
        import time

        from hermes_cli.debug import (
            _load_pending,
            _save_pending,
            _sweep_expired_pastes,
        )

        flaky_entry = {"url": "https://paste.rs/flaky", "expire_at": time.time() - 100}
        _save_pending([flaky_entry])

        failing_delete = patch(
            "hermes_cli.debug.delete_paste",
            side_effect=Exception("network down"),
        )
        with failing_delete:
            deleted, remaining = _sweep_expired_pastes()

        # Failure within 24h grace → kept for retry
        assert deleted == 0
        assert remaining == 1
        assert len(_load_pending()) == 1

    def test_sweep_drops_entries_past_grace_window(self, hermes_home):
        """After 24h past expiration, give up even on network failures."""
        import time

        from hermes_cli.debug import (
            _load_pending,
            _save_pending,
            _sweep_expired_pastes,
        )

        # Expired 25 hours ago → past the 24h grace window
        very_old = time.time() - (25 * 3600)
        ancient_entry = {"url": "https://paste.rs/ancient", "expire_at": very_old}
        _save_pending([ancient_entry])

        failing_delete = patch(
            "hermes_cli.debug.delete_paste",
            side_effect=Exception("network down"),
        )
        with failing_delete:
            deleted, remaining = _sweep_expired_pastes()

        assert deleted == 1
        assert remaining == 0
        assert _load_pending() == []
class TestRunDebugSweepsOnInvocation:
    """``run_debug`` must sweep expired pastes on every invocation."""

    def test_run_debug_calls_sweep(self, hermes_home):
        """The sweep is called exactly once per ``run_debug`` call."""
        from hermes_cli.debug import run_debug

        # default → prints help
        args = MagicMock(debug_command=None)

        with patch("hermes_cli.debug._sweep_expired_pastes") as mock_sweep:
            run_debug(args)

        mock_sweep.assert_called_once()

    def test_run_debug_survives_sweep_failure(self, hermes_home, capsys):
        """If the sweep throws, the subcommand still runs."""
        from hermes_cli.debug import run_debug

        args = MagicMock(debug_command=None)

        exploding_sweep = patch(
            "hermes_cli.debug._sweep_expired_pastes",
            side_effect=RuntimeError("boom"),
        )
        with exploding_sweep:
            run_debug(args)  # must not raise

        # Default subcommand still printed help
        printed = capsys.readouterr().out
        assert "Usage: hermes debug" in printed
class TestRunDebugDelete:

View file

@ -39,6 +39,76 @@ class TestSystemdLingerStatus:
assert gateway.get_systemd_linger_status() == (None, "not supported in Termux")
class TestContainerSystemdSupport:
    """``supports_systemd_services`` on a containerised Linux host with ``systemctl`` on PATH."""

    @staticmethod
    def _patch_container_host(monkeypatch, systemd_operational):
        """Fake a Linux container (not Termux, not WSL) and install the given
        ``_systemd_operational`` stand-in."""
        monkeypatch.setattr(gateway, "is_linux", lambda: True)
        for platform_check in ("is_termux", "is_wsl"):
            monkeypatch.setattr(gateway, platform_check, lambda: False)
        monkeypatch.setattr(gateway, "is_container", lambda: True)
        monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/systemctl")
        monkeypatch.setattr(gateway, "_systemd_operational", systemd_operational)

    def test_supports_systemd_services_in_container_with_user_manager(self, monkeypatch):
        # Only the user-scope manager reports operational.
        self._patch_container_host(monkeypatch, lambda system=False: not system)

        assert gateway.supports_systemd_services() is True

    def test_supports_systemd_services_in_container_with_system_manager(self, monkeypatch):
        # Only the system-scope manager reports operational.
        self._patch_container_host(monkeypatch, lambda system=False: system)

        assert gateway.supports_systemd_services() is True

    def test_supports_systemd_services_in_container_without_systemd(self, monkeypatch):
        # Neither scope reports operational.
        self._patch_container_host(monkeypatch, lambda system=False: False)

        assert gateway.supports_systemd_services() is False
def test_gateway_install_in_container_with_operational_systemd_uses_systemd(monkeypatch):
    """With systemd reported as supported, ``gateway install`` calls ``systemd_install``."""
    monkeypatch.setattr(gateway, "supports_systemd_services", lambda: True)
    for negative_check in ("is_wsl", "is_macos", "is_managed"):
        monkeypatch.setattr(gateway, negative_check, lambda: False)

    calls = []

    def record_install(force=False, system=False, run_as_user=None):
        # Capture the arguments the dispatcher forwards to the installer.
        calls.append((force, system, run_as_user))

    monkeypatch.setattr(gateway, "systemd_install", record_install)

    install_args = SimpleNamespace(
        gateway_command="install",
        force=False,
        system=False,
        run_as_user=None,
    )
    gateway.gateway_command(install_args)

    assert calls == [(False, False, None)]
def test_gateway_start_in_container_with_operational_systemd_uses_systemd(monkeypatch):
    """With systemd reported as supported, ``gateway start`` calls ``systemd_start``."""
    monkeypatch.setattr(gateway, "supports_systemd_services", lambda: True)
    for negative_check in ("is_wsl", "is_macos"):
        monkeypatch.setattr(gateway, negative_check, lambda: False)

    calls = []

    def record_start(system=False):
        # Capture the scope flag forwarded by the dispatcher.
        calls.append(system)

    monkeypatch.setattr(gateway, "systemd_start", record_start)

    start_args = SimpleNamespace(gateway_command="start", system=False, all=False)
    gateway.gateway_command(start_args)

    assert calls == [False]
def test_systemd_status_warns_when_linger_disabled(monkeypatch, tmp_path, capsys):
unit_path = tmp_path / "hermes-gateway.service"
unit_path.write_text("[Unit]\n")
@ -179,6 +249,21 @@ def test_install_linux_gateway_from_setup_system_choice_as_root_installs(monkeyp
assert calls == [(True, True, "alice")]
def test_find_gateway_pids_falls_back_to_pid_file_when_process_scan_fails(monkeypatch):
    """When the ``ps`` scan exits non-zero, the PID from ``get_running_pid`` is returned."""

    def fake_run(cmd, **kwargs):
        # Only the process-scan command is expected; anything else is a bug.
        if cmd[:4] != ["ps", "-A", "eww", "-o"]:
            raise AssertionError(f"Unexpected command: {cmd}")
        return SimpleNamespace(returncode=1, stdout="", stderr="ps failed")

    monkeypatch.setattr(gateway, "_get_service_pids", lambda: set())
    monkeypatch.setattr(gateway, "is_windows", lambda: False)
    monkeypatch.setattr("gateway.status.get_running_pid", lambda: 321)
    monkeypatch.setattr(gateway.subprocess, "run", fake_run)

    found = gateway.find_gateway_pids()

    assert found == [321]
# ---------------------------------------------------------------------------
# _wait_for_gateway_exit
# ---------------------------------------------------------------------------

View file

@ -450,7 +450,6 @@ class TestGatewayServiceDetection:
assert gateway_cli._is_service_running() is False
class TestGatewaySystemServiceRouting:
def test_systemd_restart_self_requests_graceful_restart_and_waits(self, monkeypatch, capsys):
calls = []
@ -554,6 +553,38 @@ class TestGatewaySystemServiceRouting:
assert calls == [(False, False)]
def test_gateway_status_reports_manual_process_when_service_is_stopped(self, monkeypatch, capsys):
    """``gateway status`` prints the service status and also reports a gateway
    process whose PID appears in the runtime snapshot while the service is not running."""
    user_unit = SimpleNamespace(exists=lambda: True)
    system_unit = SimpleNamespace(exists=lambda: False)

    def fake_unit_path(system=False):
        if system:
            return system_unit
        return user_unit

    def fake_systemd_status(deep=False, system=False):
        print("service stopped")

    def fake_snapshot(system=False):
        # Service installed but stopped, with one gateway PID present.
        return gateway_cli.GatewayRuntimeSnapshot(
            manager="systemd (user)",
            service_installed=True,
            service_running=False,
            gateway_pids=(4321,),
            service_scope="user",
        )

    monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)
    monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
    monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
    monkeypatch.setattr(gateway_cli, "get_systemd_unit_path", fake_unit_path)
    monkeypatch.setattr(gateway_cli, "systemd_status", fake_systemd_status)
    monkeypatch.setattr(gateway_cli, "get_gateway_runtime_snapshot", fake_snapshot)

    status_args = SimpleNamespace(gateway_command="status", deep=False, system=False)
    gateway_cli.gateway_command(status_args)

    out = capsys.readouterr().out
    for expected in (
        "service stopped",
        "Gateway process is running for this profile",
        "PID(s): 4321",
    ):
        assert expected in out
def test_gateway_status_on_termux_shows_manual_guidance(self, monkeypatch, capsys):
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: False)
monkeypatch.setattr(gateway_cli, "is_termux", lambda: True)
@ -1146,3 +1177,556 @@ class TestDockerAwareGateway:
out = capsys.readouterr().out
assert "docker" in out.lower()
assert "hermes gateway run" in out
class TestLegacyHermesUnitDetection:
    """Tests for _find_legacy_hermes_units / has_legacy_hermes_units.

    These guard against the scenario that tripped Luis in April 2026: an
    older install left a ``hermes.service`` unit behind when the service was
    renamed to ``hermes-gateway.service``. After PR #5646 (signal recovery
    via systemd), the two services began SIGTERM-flapping over the same
    Telegram bot token in a 30-second cycle.

    The detector must flag ``hermes.service`` ONLY when it actually runs our
    gateway, and must NEVER flag profile units
    (``hermes-gateway-<profile>.service``) or unrelated third-party services.
    """

    # Minimal ExecStart that looks like our gateway
    _OUR_UNIT_TEXT = (
        "[Unit]\nDescription=Hermes Gateway\n[Service]\n"
        "ExecStart=/usr/bin/python -m hermes_cli.main gateway run --replace\n"
    )

    @staticmethod
    def _setup_search_paths(tmp_path, monkeypatch):
        """Redirect the legacy search to user_dir + system_dir under tmp_path.

        Returns ``(user_dir, system_dir)``; both directories are created empty.
        """
        user_dir = tmp_path / "user"
        system_dir = tmp_path / "system"
        user_dir.mkdir()
        system_dir.mkdir()
        monkeypatch.setattr(
            gateway_cli,
            "_legacy_unit_search_paths",
            lambda: [(False, user_dir), (True, system_dir)],
        )
        return user_dir, system_dir

    def test_detects_legacy_hermes_service_in_user_scope(self, tmp_path, monkeypatch):
        user_dir, _ = self._setup_search_paths(tmp_path, monkeypatch)
        legacy = user_dir / "hermes.service"
        legacy.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

        results = gateway_cli._find_legacy_hermes_units()

        assert len(results) == 1
        name, path, is_system = results[0]
        assert name == "hermes.service"
        assert path == legacy
        assert is_system is False
        assert gateway_cli.has_legacy_hermes_units() is True

    def test_detects_legacy_hermes_service_in_system_scope(self, tmp_path, monkeypatch):
        _, system_dir = self._setup_search_paths(tmp_path, monkeypatch)
        legacy = system_dir / "hermes.service"
        legacy.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

        results = gateway_cli._find_legacy_hermes_units()

        assert len(results) == 1
        name, path, is_system = results[0]
        assert name == "hermes.service"
        assert path == legacy
        assert is_system is True

    def test_ignores_profile_unit_hermes_gateway_coder(self, tmp_path, monkeypatch):
        """CRITICAL: profile units must NOT be flagged as legacy.

        Teknium's concern — ``hermes-gateway-coder.service`` is our standard
        naming for the ``coder`` profile. The legacy detector is an explicit
        allowlist, not a glob, so profile units are safe.
        """
        user_dir, system_dir = self._setup_search_paths(tmp_path, monkeypatch)
        # Drop profile units in BOTH scopes with our ExecStart
        for base in (user_dir, system_dir):
            (base / "hermes-gateway-coder.service").write_text(
                self._OUR_UNIT_TEXT, encoding="utf-8"
            )
            (base / "hermes-gateway-orcha.service").write_text(
                self._OUR_UNIT_TEXT, encoding="utf-8"
            )
            (base / "hermes-gateway.service").write_text(
                self._OUR_UNIT_TEXT, encoding="utf-8"
            )

        results = gateway_cli._find_legacy_hermes_units()

        assert results == []
        assert gateway_cli.has_legacy_hermes_units() is False

    def test_ignores_unrelated_hermes_service(self, tmp_path, monkeypatch):
        """Third-party ``hermes.service`` that isn't ours stays untouched.

        If a user has some other package named ``hermes`` installed as a
        service, we must not flag it.
        """
        user_dir, _ = self._setup_search_paths(tmp_path, monkeypatch)
        (user_dir / "hermes.service").write_text(
            "[Unit]\nDescription=Some Other Hermes\n[Service]\n"
            "ExecStart=/opt/other-hermes/bin/daemon --foreground\n",
            encoding="utf-8",
        )

        results = gateway_cli._find_legacy_hermes_units()

        assert results == []
        assert gateway_cli.has_legacy_hermes_units() is False

    def test_returns_empty_when_no_legacy_files_exist(self, tmp_path, monkeypatch):
        self._setup_search_paths(tmp_path, monkeypatch)

        assert gateway_cli._find_legacy_hermes_units() == []
        assert gateway_cli.has_legacy_hermes_units() is False

    def test_detects_both_scopes_simultaneously(self, tmp_path, monkeypatch):
        """When a user has BOTH user-scope and system-scope legacy units,
        both are reported so the migration step can remove them together."""
        user_dir, system_dir = self._setup_search_paths(tmp_path, monkeypatch)
        (user_dir / "hermes.service").write_text(self._OUR_UNIT_TEXT, encoding="utf-8")
        (system_dir / "hermes.service").write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

        results = gateway_cli._find_legacy_hermes_units()

        scopes = sorted(is_system for _, _, is_system in results)
        assert scopes == [False, True]

    def test_accepts_alternate_execstart_formats(self, tmp_path, monkeypatch):
        """Older installs may have used different python invocations.

        ExecStart variants we've seen in the wild:
        - python -m hermes_cli.main gateway run
        - python path/to/hermes_cli/main.py gateway run
        - hermes gateway run (direct binary)
        - python path/to/gateway/run.py
        """
        user_dir, _ = self._setup_search_paths(tmp_path, monkeypatch)
        variants = [
            "ExecStart=/venv/bin/python -m hermes_cli.main gateway run --replace",
            "ExecStart=/venv/bin/python /opt/hermes/hermes_cli/main.py gateway run",
            "ExecStart=/usr/local/bin/hermes gateway run --replace",
            "ExecStart=/venv/bin/python /opt/hermes/gateway/run.py",
        ]
        for i, execstart in enumerate(variants):
            # Every variant is written to the same hermes.service path, so
            # each iteration overwrites the previous one and is tested fresh.
            (user_dir / "hermes.service").write_text(
                f"[Unit]\nDescription=Old Hermes\n[Service]\n{execstart}\n",
                encoding="utf-8",
            )
            results = gateway_cli._find_legacy_hermes_units()
            assert len(results) == 1, f"Variant {i} not detected: {execstart!r}"

    def test_print_legacy_unit_warning_is_noop_when_empty(self, tmp_path, monkeypatch, capsys):
        self._setup_search_paths(tmp_path, monkeypatch)

        gateway_cli.print_legacy_unit_warning()

        out = capsys.readouterr().out
        assert out == ""

    def test_print_legacy_unit_warning_shows_migration_hint(self, tmp_path, monkeypatch, capsys):
        user_dir, _ = self._setup_search_paths(tmp_path, monkeypatch)
        (user_dir / "hermes.service").write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

        gateway_cli.print_legacy_unit_warning()

        out = capsys.readouterr().out
        assert "Legacy" in out
        assert "hermes.service" in out
        assert "hermes gateway migrate-legacy" in out

    def test_handles_unreadable_unit_file_gracefully(self, tmp_path, monkeypatch):
        """A permission error reading a unit file must not crash detection."""
        user_dir, _ = self._setup_search_paths(tmp_path, monkeypatch)
        unreadable = user_dir / "hermes.service"
        unreadable.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

        # Simulate a read failure — monkeypatch Path.read_text to raise
        original_read_text = gateway_cli.Path.read_text

        def raising_read_text(path, *args, **kwargs):
            # Fail only for the legacy unit; every other path reads normally.
            if path == unreadable:
                raise PermissionError("simulated")
            return original_read_text(path, *args, **kwargs)

        monkeypatch.setattr(gateway_cli.Path, "read_text", raising_read_text)

        # Should not raise
        results = gateway_cli._find_legacy_hermes_units()
        assert results == []
class TestRemoveLegacyHermesUnits:
"""Tests for remove_legacy_hermes_units (the migration action)."""
_OUR_UNIT_TEXT = (
"[Unit]\nDescription=Hermes Gateway\n[Service]\n"
"ExecStart=/usr/bin/python -m hermes_cli.main gateway run --replace\n"
)
@staticmethod
def _setup(tmp_path, monkeypatch, as_root=False):
user_dir = tmp_path / "user"
system_dir = tmp_path / "system"
user_dir.mkdir()
system_dir.mkdir()
monkeypatch.setattr(
gateway_cli,
"_legacy_unit_search_paths",
lambda: [(False, user_dir), (True, system_dir)],
)
# Mock systemctl — return success for everything
systemctl_calls: list[list[str]] = []
def fake_run(cmd, **kwargs):
systemctl_calls.append(cmd)
return SimpleNamespace(returncode=0, stdout="", stderr="")
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
monkeypatch.setattr(gateway_cli.os, "geteuid", lambda: 0 if as_root else 1000)
return user_dir, system_dir, systemctl_calls
def test_returns_zero_when_no_legacy_units(self, tmp_path, monkeypatch, capsys):
self._setup(tmp_path, monkeypatch)
removed, remaining = gateway_cli.remove_legacy_hermes_units(interactive=False)
assert removed == 0
assert remaining == []
assert "No legacy" in capsys.readouterr().out
def test_dry_run_lists_without_removing(self, tmp_path, monkeypatch, capsys):
user_dir, _, calls = self._setup(tmp_path, monkeypatch)
legacy = user_dir / "hermes.service"
legacy.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")
removed, remaining = gateway_cli.remove_legacy_hermes_units(
interactive=False, dry_run=True
)
assert removed == 0
assert remaining == [legacy]
assert legacy.exists() # Not removed
assert calls == [] # No systemctl invocations
out = capsys.readouterr().out
assert "dry-run" in out
def test_removes_user_scope_legacy_unit(self, tmp_path, monkeypatch, capsys):
user_dir, _, calls = self._setup(tmp_path, monkeypatch)
legacy = user_dir / "hermes.service"
legacy.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")
removed, remaining = gateway_cli.remove_legacy_hermes_units(interactive=False)
assert removed == 1
assert remaining == []
assert not legacy.exists()
# Must have invoked stop → disable → daemon-reload on user scope
cmds_joined = [" ".join(c) for c in calls]
assert any("--user stop hermes.service" in c for c in cmds_joined)
assert any("--user disable hermes.service" in c for c in cmds_joined)
assert any("--user daemon-reload" in c for c in cmds_joined)
def test_system_scope_without_root_defers_removal(self, tmp_path, monkeypatch, capsys):
_, system_dir, calls = self._setup(tmp_path, monkeypatch, as_root=False)
legacy = system_dir / "hermes.service"
legacy.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")
removed, remaining = gateway_cli.remove_legacy_hermes_units(interactive=False)
assert removed == 0
assert remaining == [legacy]
assert legacy.exists() # Not removed — requires sudo
out = capsys.readouterr().out
assert "sudo hermes gateway migrate-legacy" in out
def test_system_scope_with_root_removes(self, tmp_path, monkeypatch, capsys):
_, system_dir, calls = self._setup(tmp_path, monkeypatch, as_root=True)
legacy = system_dir / "hermes.service"
legacy.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")
removed, remaining = gateway_cli.remove_legacy_hermes_units(interactive=False)
assert removed == 1
assert remaining == []
assert not legacy.exists()
cmds_joined = [" ".join(c) for c in calls]
# System-scope uses plain "systemctl" (no --user)
assert any(
c.startswith("systemctl stop hermes.service") for c in cmds_joined
)
assert any(
c.startswith("systemctl disable hermes.service") for c in cmds_joined
)
def test_removes_both_scopes_with_root(self, tmp_path, monkeypatch, capsys):
user_dir, system_dir, _ = self._setup(tmp_path, monkeypatch, as_root=True)
user_legacy = user_dir / "hermes.service"
system_legacy = system_dir / "hermes.service"
user_legacy.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")
system_legacy.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")
removed, remaining = gateway_cli.remove_legacy_hermes_units(interactive=False)
assert removed == 2
assert remaining == []
assert not user_legacy.exists()
assert not system_legacy.exists()
def test_does_not_touch_profile_units_during_migration(
self, tmp_path, monkeypatch, capsys
):
"""Teknium's constraint: profile units (hermes-gateway-coder.service)
must survive a migration call, even if we somehow include them in the
search dir."""
user_dir, _, _ = self._setup(tmp_path, monkeypatch, as_root=True)
profile_unit = user_dir / "hermes-gateway-coder.service"
profile_unit.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")
default_unit = user_dir / "hermes-gateway.service"
default_unit.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")
removed, remaining = gateway_cli.remove_legacy_hermes_units(interactive=False)
assert removed == 0
assert remaining == []
# Both the profile unit and the current default unit must survive
assert profile_unit.exists()
assert default_unit.exists()
def test_interactive_prompt_no_skips_removal(self, tmp_path, monkeypatch, capsys):
    """Answering "no" to the interactive prompt leaves the legacy unit alone."""
    user_dir, _system_dir, _calls = self._setup(tmp_path, monkeypatch)
    unit_file = user_dir / "hermes.service"
    unit_file.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")

    def _always_decline(*_args, **_kwargs):
        return False

    monkeypatch.setattr(gateway_cli, "prompt_yes_no", _always_decline)

    removed, remaining = gateway_cli.remove_legacy_hermes_units(interactive=True)

    assert removed == 0
    # The declined unit is reported back as still present...
    assert remaining == [unit_file]
    # ...and is in fact still on disk.
    assert unit_file.exists()
class TestMigrateLegacyCommand:
    """Tests for the `hermes gateway migrate-legacy` subcommand dispatch."""

    @staticmethod
    def _record_remove_calls(monkeypatch):
        """Swap ``remove_legacy_hermes_units`` for a recorder.

        Returns the dict the recorder fills with the ``interactive`` and
        ``dry_run`` values it was called with.
        """
        called = {}

        def fake_remove(interactive=True, dry_run=False):
            called["interactive"] = interactive
            called["dry_run"] = dry_run
            return 0, []

        monkeypatch.setattr(gateway_cli, "remove_legacy_hermes_units", fake_remove)
        return called

    @staticmethod
    def _stub_platform(monkeypatch, *, systemd):
        """Pin the platform probes: never macOS, systemd support as requested."""
        monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: systemd)
        monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)

    def test_migrate_legacy_subparser_accepts_dry_run_and_yes(self):
        """The subcommand must be advertised by ``hermes gateway --help``.

        The check runs the CLI in a subprocess and inspects its help output.
        Only registration of the subcommand is verified here; the
        ``--dry-run`` / ``--yes`` flags themselves are not parsed.
        """
        import subprocess
        import sys

        import hermes_cli.main as cli_main

        # PROJECT_ROOT may live on either module; prefer main, fall back to
        # the gateway module.
        project_root = getattr(cli_main, "PROJECT_ROOT", None)
        if project_root is None:
            import hermes_cli.gateway as gw

            project_root = gw.PROJECT_ROOT

        result = subprocess.run(
            [sys.executable, "-m", "hermes_cli.main", "gateway", "--help"],
            cwd=str(project_root),
            capture_output=True,
            text=True,
            timeout=15,
        )
        assert result.returncode == 0
        assert "migrate-legacy" in result.stdout

    def test_gateway_command_migrate_legacy_dispatches(
        self, tmp_path, monkeypatch, capsys
    ):
        """gateway_command(args) with subcmd='migrate-legacy' calls the helper."""
        called = self._record_remove_calls(monkeypatch)
        self._stub_platform(monkeypatch, systemd=True)

        args = SimpleNamespace(
            gateway_command="migrate-legacy", dry_run=False, yes=True
        )
        gateway_cli.gateway_command(args)

        # yes=True must turn the interactive prompt off.
        assert called == {"interactive": False, "dry_run": False}

    def test_gateway_command_migrate_legacy_dry_run_passes_through(
        self, monkeypatch
    ):
        """``dry_run=True`` is forwarded; without ``yes`` the call stays interactive."""
        called = self._record_remove_calls(monkeypatch)
        self._stub_platform(monkeypatch, systemd=True)

        args = SimpleNamespace(
            gateway_command="migrate-legacy", dry_run=True, yes=False
        )
        gateway_cli.gateway_command(args)

        assert called == {"interactive": True, "dry_run": True}

    def test_migrate_legacy_on_unsupported_platform_prints_message(
        self, monkeypatch, capsys
    ):
        """Without systemd support the command prints an explanatory message."""
        self._stub_platform(monkeypatch, systemd=False)

        args = SimpleNamespace(
            gateway_command="migrate-legacy", dry_run=False, yes=True
        )
        gateway_cli.gateway_command(args)

        out = capsys.readouterr().out
        assert "only applies to systemd" in out
class TestSystemdInstallOffersLegacyRemoval:
    """Verify that systemd_install prompts to remove legacy units first."""

    @staticmethod
    def _stub_install_flow(tmp_path, monkeypatch):
        """Stub everything ``systemd_install`` touches after the legacy check.

        The unit path is redirected into *tmp_path*, unit generation returns
        a fixed text, every ``subprocess.run`` call reports success, and the
        linger step is a no-op. Returns the redirected unit path.
        """
        unit_path = tmp_path / "hermes-gateway.service"
        monkeypatch.setattr(
            gateway_cli, "get_systemd_unit_path", lambda system=False: unit_path
        )
        monkeypatch.setattr(
            gateway_cli,
            "generate_systemd_unit",
            lambda system=False, run_as_user=None: "unit text\n",
        )
        monkeypatch.setattr(
            gateway_cli.subprocess,
            "run",
            lambda cmd, **kw: SimpleNamespace(returncode=0, stdout="", stderr=""),
        )
        monkeypatch.setattr(gateway_cli, "_ensure_linger_enabled", lambda: None)
        return unit_path

    def test_install_offers_removal_when_legacy_detected(
        self, tmp_path, monkeypatch, capsys
    ):
        """When legacy units exist, install flow should call the removal
        helper before writing the new unit."""
        remove_called = {}

        def fake_remove(interactive=True, dry_run=False):
            remove_called["invoked"] = True
            remove_called["interactive"] = interactive
            return 1, []

        # Legacy units are reported as present.
        monkeypatch.setattr(gateway_cli, "has_legacy_hermes_units", lambda: True)
        monkeypatch.setattr(gateway_cli, "remove_legacy_hermes_units", fake_remove)
        monkeypatch.setattr(gateway_cli, "print_legacy_unit_warning", lambda: None)
        # Answer "yes" to the legacy-removal prompt.
        monkeypatch.setattr(gateway_cli, "prompt_yes_no", lambda *a, **k: True)
        self._stub_install_flow(tmp_path, monkeypatch)

        gateway_cli.systemd_install()

        assert remove_called.get("invoked") is True
        # The install flow already prompted, so the helper must not prompt again.
        assert remove_called.get("interactive") is False

    def test_install_declines_legacy_removal_when_user_says_no(
        self, tmp_path, monkeypatch
    ):
        """When legacy units exist and user declines, install still proceeds
        but doesn't touch them."""
        remove_called = {"invoked": False}

        def fake_remove(interactive=True, dry_run=False):
            remove_called["invoked"] = True
            return 0, []

        monkeypatch.setattr(gateway_cli, "has_legacy_hermes_units", lambda: True)
        monkeypatch.setattr(gateway_cli, "remove_legacy_hermes_units", fake_remove)
        monkeypatch.setattr(gateway_cli, "print_legacy_unit_warning", lambda: None)
        # Answer "no" to the legacy-removal prompt.
        monkeypatch.setattr(gateway_cli, "prompt_yes_no", lambda *a, **k: False)
        unit_path = self._stub_install_flow(tmp_path, monkeypatch)

        gateway_cli.systemd_install()

        # Helper must NOT have been called.
        assert remove_called["invoked"] is False
        # New unit should still have been written.
        assert unit_path.exists()
        assert unit_path.read_text() == "unit text\n"

    def test_install_skips_legacy_check_when_none_present(
        self, tmp_path, monkeypatch
    ):
        """No legacy → no prompt, no helper call."""
        prompt_called = {"count": 0}

        def counting_prompt(*a, **k):
            prompt_called["count"] += 1
            return True

        remove_called = {"invoked": False}

        def fake_remove(interactive=True, dry_run=False):
            remove_called["invoked"] = True
            return 0, []

        monkeypatch.setattr(gateway_cli, "has_legacy_hermes_units", lambda: False)
        monkeypatch.setattr(gateway_cli, "remove_legacy_hermes_units", fake_remove)
        monkeypatch.setattr(gateway_cli, "prompt_yes_no", counting_prompt)
        self._stub_install_flow(tmp_path, monkeypatch)

        gateway_cli.systemd_install()

        assert prompt_called["count"] == 0
        assert remove_called["invoked"] is False

View file

@ -178,10 +178,6 @@ class TestGeminiContextLength:
ctx = get_model_context_length("gemma-4-31b-it", provider="gemini")
assert ctx == 256000
def test_gemma_4_26b_context(self):
ctx = get_model_context_length("gemma-4-26b-it", provider="gemini")
assert ctx == 256000
def test_gemini_3_context(self):
ctx = get_model_context_length("gemini-3.1-pro-preview", provider="gemini")
assert ctx == 1048576

View file

@ -403,7 +403,8 @@ class TestValidateFormatChecks:
def test_no_slash_model_rejected_if_not_in_api(self):
result = _validate("gpt-5.4", api_models=["openai/gpt-5.4"])
assert result["accepted"] is True
assert result["accepted"] is False
assert result["persist"] is False
assert "not found" in result["message"]
@ -429,10 +430,10 @@ class TestValidateApiFound:
# -- validate — API not found ------------------------------------------------
class TestValidateApiNotFound:
def test_model_not_in_api_accepted_with_warning(self):
def test_model_not_in_api_rejected_with_guidance(self):
result = _validate("anthropic/claude-nonexistent")
assert result["accepted"] is True
assert result["persist"] is True
assert result["accepted"] is False
assert result["persist"] is False
assert "not found" in result["message"]
def test_warning_includes_suggestions(self):
@ -456,30 +457,29 @@ class TestValidateApiNotFound:
assert "not found" in result["message"]
# -- validate — API unreachable — accept and persist everything ----------------
# -- validate — API unreachable — reject with guidance ----------------
class TestValidateApiFallback:
def test_any_model_accepted_when_api_down(self):
def test_any_model_rejected_when_api_down(self):
result = _validate("anthropic/claude-opus-4.6", api_models=None)
assert result["accepted"] is True
assert result["persist"] is True
assert result["accepted"] is False
assert result["persist"] is False
def test_unknown_model_also_accepted_when_api_down(self):
"""No hardcoded catalog gatekeeping — accept, persist, and warn."""
def test_unknown_model_also_rejected_when_api_down(self):
result = _validate("anthropic/claude-next-gen", api_models=None)
assert result["accepted"] is True
assert result["persist"] is True
assert result["accepted"] is False
assert result["persist"] is False
assert "could not reach" in result["message"].lower()
def test_zai_model_accepted_when_api_down(self):
def test_zai_model_rejected_when_api_down(self):
result = _validate("glm-5", provider="zai", api_models=None)
assert result["accepted"] is True
assert result["persist"] is True
assert result["accepted"] is False
assert result["persist"] is False
def test_unknown_provider_accepted_when_api_down(self):
def test_unknown_provider_rejected_when_api_down(self):
result = _validate("some-model", provider="totally-unknown", api_models=None)
assert result["accepted"] is True
assert result["persist"] is True
assert result["accepted"] is False
assert result["persist"] is False
def test_custom_endpoint_warns_with_probed_url_and_v1_hint(self):
with patch(
@ -499,8 +499,8 @@ class TestValidateApiFallback:
base_url="http://localhost:8000",
)
assert result["accepted"] is True
assert result["persist"] is True
assert result["accepted"] is False
assert result["persist"] is False
assert "http://localhost:8000/v1/models" in result["message"]
assert "http://localhost:8000/v1" in result["message"]

View file

@ -15,7 +15,7 @@ def test_opencode_go_appears_when_api_key_set():
opencode_go = next((p for p in providers if p["slug"] == "opencode-go"), None)
assert opencode_go is not None, "opencode-go should appear when OPENCODE_GO_API_KEY is set"
assert opencode_go["models"] == ["glm-5.1", "glm-5", "kimi-k2.5", "mimo-v2-pro", "mimo-v2-omni", "minimax-m2.7", "minimax-m2.5"]
assert opencode_go["models"] == ["kimi-k2.5", "glm-5.1", "glm-5", "mimo-v2-pro", "mimo-v2-omni", "minimax-m2.7", "minimax-m2.5"]
# opencode-go can appear as "built-in" (from PROVIDER_TO_MODELS_DEV when
# models.dev is reachable) or "hermes" (from HERMES_OVERLAYS fallback when
# the API is unavailable, e.g. in CI).

View file

@ -799,35 +799,30 @@ class TestEdgeCases:
assert default.skill_count == 0
def test_gateway_running_check_with_pid_file(self, profile_env):
"""Verify _check_gateway_running reads pid file and probes os.kill."""
"""Verify _check_gateway_running uses the shared gateway PID validator."""
from hermes_cli.profiles import _check_gateway_running
tmp_path = profile_env
default_home = tmp_path / ".hermes"
# No pid file -> not running
assert _check_gateway_running(default_home) is False
# Write a PID file with a JSON payload
pid_file = default_home / "gateway.pid"
pid_file.write_text(json.dumps({"pid": 99999}))
# os.kill(99999, 0) should raise ProcessLookupError -> not running
assert _check_gateway_running(default_home) is False
# Mock os.kill to simulate a running process
with patch("os.kill", return_value=None):
with patch("gateway.status.get_running_pid", return_value=99999) as mock_get_running_pid:
assert _check_gateway_running(default_home) is True
mock_get_running_pid.assert_called_once_with(
default_home / "gateway.pid",
cleanup_stale=False,
)
def test_gateway_running_check_plain_pid(self, profile_env):
"""Pid file containing just a number (legacy format)."""
"""Shared PID validator returning None means the profile is not running."""
from hermes_cli.profiles import _check_gateway_running
tmp_path = profile_env
default_home = tmp_path / ".hermes"
pid_file = default_home / "gateway.pid"
pid_file.write_text("99999")
with patch("os.kill", return_value=None):
assert _check_gateway_running(default_home) is True
with patch("gateway.status.get_running_pid", return_value=None) as mock_get_running_pid:
assert _check_gateway_running(default_home) is False
mock_get_running_pid.assert_called_once_with(
default_home / "gateway.pid",
cleanup_stale=False,
)
def test_profile_name_boundary_single_char(self):
"""Single alphanumeric character is valid."""

View file

@ -0,0 +1,53 @@
"""_tui_need_npm_install: auto npm when lockfile ahead of node_modules."""
import os
from pathlib import Path
import pytest
@pytest.fixture
def main_mod():
    """Provide the ``hermes_cli.main`` module, imported lazily per test."""
    import hermes_cli.main as cli_entry_module

    return cli_entry_module
def _touch_ink(root: Path) -> None:
ink = root / "node_modules" / "@hermes" / "ink" / "package.json"
ink.parent.mkdir(parents=True, exist_ok=True)
ink.write_text("{}")
def test_need_install_when_ink_missing(tmp_path: Path, main_mod) -> None:
    """With a lockfile but no ``@hermes/ink`` package, an install is required."""
    lockfile = tmp_path / "package-lock.json"
    lockfile.write_text("{}")

    needs_install = main_mod._tui_need_npm_install(tmp_path)

    assert needs_install is True
def test_need_install_when_lock_newer_than_marker(tmp_path: Path, main_mod) -> None:
    """A lockfile with a later mtime than the node_modules marker needs install."""
    _touch_ink(tmp_path)
    lockfile = tmp_path / "package-lock.json"
    marker = tmp_path / "node_modules" / ".package-lock.json"

    # Pin explicit mtimes so the ordering does not depend on write timing.
    for stamped_file, mtime in ((lockfile, 200), (marker, 100)):
        stamped_file.write_text("{}")
        os.utime(stamped_file, (mtime, mtime))

    assert main_mod._tui_need_npm_install(tmp_path) is True
def test_no_install_when_lock_older_than_marker(tmp_path: Path, main_mod) -> None:
    """A lockfile older than the node_modules marker does not trigger install."""
    _touch_ink(tmp_path)
    lockfile = tmp_path / "package-lock.json"
    marker = tmp_path / "node_modules" / ".package-lock.json"

    # Pin explicit mtimes so the ordering does not depend on write timing.
    for stamped_file, mtime in ((lockfile, 100), (marker, 200)):
        stamped_file.write_text("{}")
        os.utime(stamped_file, (mtime, mtime))

    assert main_mod._tui_need_npm_install(tmp_path) is False
def test_need_install_when_marker_missing(tmp_path: Path, main_mod) -> None:
    """Ink and lockfile present but no ``.package-lock.json`` marker: install."""
    _touch_ink(tmp_path)
    lockfile = tmp_path / "package-lock.json"
    lockfile.write_text("{}")

    needs_install = main_mod._tui_need_npm_install(tmp_path)

    assert needs_install is True
def test_no_install_without_lockfile_when_ink_present(tmp_path: Path, main_mod) -> None:
    """With ink installed and no lockfile at all, no install is requested."""
    _touch_ink(tmp_path)

    needs_install = main_mod._tui_need_npm_install(tmp_path)

    assert needs_install is False

View file

@ -0,0 +1,121 @@
from argparse import Namespace
import sys
import types
import pytest
def _args(**overrides):
base = {
"continue_last": None,
"resume": None,
"tui": True,
}
base.update(overrides)
return Namespace(**base)
@pytest.fixture
def main_mod(monkeypatch):
    """Provide ``hermes_cli.main`` with the provider-configured check forced on."""
    import hermes_cli.main as cli_entry_module

    def _provider_always_configured():
        return True

    monkeypatch.setattr(
        cli_entry_module, "_has_any_provider_configured", _provider_always_configured
    )
    return cli_entry_module
def test_cmd_chat_tui_continue_uses_latest_tui_session(monkeypatch, main_mod):
    """``continue_last`` resumes the latest TUI session when one exists.

    Only the "tui" source may be queried, and its session id must be the one
    handed to the launcher.
    """
    queried_sources = []
    launch_kwargs = {}

    def fake_resolve_last(source="cli"):
        queried_sources.append(source)
        if source == "tui":
            return "20260408_235959_a1b2c3"
        return None

    def fake_launch(resume_session_id=None, tui_dev=False):
        launch_kwargs["resume"] = resume_session_id
        # Stop cmd_chat right at the launch point.
        raise SystemExit(0)

    replacements = {
        "_resolve_last_session": fake_resolve_last,
        "_resolve_session_by_name_or_id": lambda val: val,
        "_launch_tui": fake_launch,
    }
    for attr_name, stand_in in replacements.items():
        monkeypatch.setattr(main_mod, attr_name, stand_in)

    with pytest.raises(SystemExit):
        main_mod.cmd_chat(_args(continue_last=True))

    assert queried_sources == ["tui"]
    assert launch_kwargs["resume"] == "20260408_235959_a1b2c3"
def test_cmd_chat_tui_continue_falls_back_to_latest_cli_session(monkeypatch, main_mod):
    """With no TUI session, ``continue_last`` falls back to the latest CLI one.

    The sources must be queried in the order "tui" then "cli".
    """
    queried_sources = []
    launch_kwargs = {}
    # Latest session per source; any source not listed resolves to None.
    latest_by_source = {"tui": None, "cli": "20260408_235959_d4e5f6"}

    def fake_resolve_last(source="cli"):
        queried_sources.append(source)
        return latest_by_source.get(source)

    def fake_launch(resume_session_id=None, tui_dev=False):
        launch_kwargs["resume"] = resume_session_id
        # Stop cmd_chat right at the launch point.
        raise SystemExit(0)

    replacements = {
        "_resolve_last_session": fake_resolve_last,
        "_resolve_session_by_name_or_id": lambda val: val,
        "_launch_tui": fake_launch,
    }
    for attr_name, stand_in in replacements.items():
        monkeypatch.setattr(main_mod, attr_name, stand_in)

    with pytest.raises(SystemExit):
        main_mod.cmd_chat(_args(continue_last=True))

    assert queried_sources == ["tui", "cli"]
    assert launch_kwargs["resume"] == "20260408_235959_d4e5f6"
def test_cmd_chat_tui_resume_resolves_title_before_launch(monkeypatch, main_mod):
    """A ``resume`` value is resolved to a session id before the TUI launches."""
    launch_kwargs = {}

    def fake_resolve(val):
        # Whatever name or id is passed in, resolve to one fixed session id.
        return "20260409_000000_aa11bb"

    def fake_launch(resume_session_id=None, tui_dev=False):
        launch_kwargs["resume"] = resume_session_id
        # Stop cmd_chat right at the launch point.
        raise SystemExit(0)

    monkeypatch.setattr(main_mod, "_resolve_session_by_name_or_id", fake_resolve)
    monkeypatch.setattr(main_mod, "_launch_tui", fake_launch)

    with pytest.raises(SystemExit):
        main_mod.cmd_chat(_args(resume="my t0p session"))

    assert launch_kwargs["resume"] == "20260409_000000_aa11bb"
def test_print_tui_exit_summary_includes_resume_and_token_totals(monkeypatch, capsys):
    """The exit summary prints resume hints and the aggregated token counts."""
    import hermes_cli.main as main_mod

    expected_id = "20260409_000001_abc123"
    session_row = {
        "message_count": 2,
        "input_tokens": 10,
        "output_tokens": 6,
        "cache_read_tokens": 2,
        "cache_write_tokens": 2,
        "reasoning_tokens": 1,
    }

    class _FakeDB:
        """Session store stand-in that serves a single canned session."""

        def get_session(self, session_id):
            assert session_id == expected_id
            return dict(session_row)

        def get_session_title(self, _session_id):
            return "demo title"

        def close(self):
            return None

    def _open_fake_db():
        return _FakeDB()

    # Replace the whole hermes_state module so its SessionDB yields the fake.
    fake_state_module = types.SimpleNamespace(SessionDB=_open_fake_db)
    monkeypatch.setitem(sys.modules, "hermes_state", fake_state_module)

    main_mod._print_tui_exit_summary("20260409_000001_abc123")
    printed = capsys.readouterr().out

    for fragment in (
        "Resume this session with:",
        "hermes --tui --resume 20260409_000001_abc123",
        'hermes --tui -c "demo title"',
        "Tokens: 21 (in 10, out 6, cache 4, reasoning 1)",
    ):
        assert fragment in printed

View file

@ -13,9 +13,29 @@ from unittest.mock import patch, MagicMock
import pytest
import hermes_cli.gateway as gateway_cli
import hermes_cli.main as cli_main
from hermes_cli.main import cmd_update
# ---------------------------------------------------------------------------
# Skip the real-time sleeps inside cmd_update's restart-verification path
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _no_restart_verify_sleep(monkeypatch):
    """Turn ``time.sleep`` into a no-op for every test in this module.

    cmd_update's restart-verification path sleeps after ``systemctl restart``
    to check that the service survived. These tests mock ``subprocess.run``,
    so nothing actually restarts and the wait is dead time.

    The stub is installed on the ``time`` module object itself, so any code
    that calls ``sleep`` through the module (under whatever import alias)
    sees it. ``monkeypatch`` restores the real function when the test ends.
    """
    import time as _time_module

    def _skip_sleep(*_args, **_kwargs):
        return None

    monkeypatch.setattr(_time_module, "sleep", _skip_sleep)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@ -915,3 +935,183 @@ class TestGatewayModeWritesExitCodeEarly:
assert exit_code_existed_at_restart, "systemctl restart was never called"
assert exit_code_existed_at_restart[0] is True, \
".update_exit_code must exist BEFORE systemctl restart (cgroup kill race)"
class TestCmdUpdateLegacyGatewayWarning:
    """Tests for the legacy hermes.service warning printed by `hermes update`.

    Users who installed Hermes before the service rename often have a
    dormant ``hermes.service`` that starts flap-fighting the current
    ``hermes-gateway.service`` after PR #5646. Every ``hermes update``
    should remind them to run ``hermes gateway migrate-legacy`` until
    they do.
    """

    _OUR_UNIT_TEXT = (
        "[Unit]\nDescription=Hermes Gateway\n[Service]\n"
        "ExecStart=/usr/bin/python -m hermes_cli.main gateway run --replace\n"
    )

    @staticmethod
    def _make_scope_dirs(tmp_path):
        """Create and return the fake ``user`` and ``system`` unit directories."""
        user_dir = tmp_path / "user"
        system_dir = tmp_path / "system"
        user_dir.mkdir()
        system_dir.mkdir()
        return user_dir, system_dir

    @staticmethod
    def _patch_search_paths(monkeypatch, user_dir, system_dir):
        """Point the legacy-unit search at the given user and system directories."""
        monkeypatch.setattr(
            gateway_cli,
            "_legacy_unit_search_paths",
            lambda: [(False, user_dir), (True, system_dir)],
        )

    @staticmethod
    def _pretend_systemd_host(monkeypatch):
        """Stub the platform probes: not macOS, not Termux, systemd supported."""
        monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
        monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)
        monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)

    @staticmethod
    def _run_update(mock_run, mock_args, capsys, **side_effect_kwargs):
        """Run ``cmd_update`` with no gateway PIDs found; return captured stdout.

        Extra keyword arguments are forwarded to ``_make_run_side_effect``.
        """
        mock_run.side_effect = _make_run_side_effect(
            commit_count="3", **side_effect_kwargs
        )
        with patch.object(gateway_cli, "find_gateway_pids", return_value=[]):
            cmd_update(mock_args)
        return capsys.readouterr().out

    @patch("shutil.which", return_value=None)
    @patch("subprocess.run")
    def test_update_prints_legacy_warning_when_detected(
        self, mock_run, _mock_which, mock_args, capsys, tmp_path, monkeypatch,
    ):
        """Legacy units present → warning in update output with migrate command."""
        user_dir, system_dir = self._make_scope_dirs(tmp_path)
        legacy_path = user_dir / "hermes.service"
        legacy_path.write_text(self._OUR_UNIT_TEXT, encoding="utf-8")
        self._patch_search_paths(monkeypatch, user_dir, system_dir)
        self._pretend_systemd_host(monkeypatch)

        captured = self._run_update(mock_run, mock_args, capsys)

        assert "Legacy Hermes gateway unit(s) detected" in captured
        assert "hermes.service" in captured
        assert "hermes gateway migrate-legacy" in captured
        assert "(user scope)" in captured

    @patch("shutil.which", return_value=None)
    @patch("subprocess.run")
    def test_update_silent_when_no_legacy_units(
        self, mock_run, _mock_which, mock_args, capsys, tmp_path, monkeypatch,
    ):
        """No legacy units → no warning printed."""
        user_dir, system_dir = self._make_scope_dirs(tmp_path)
        self._patch_search_paths(monkeypatch, user_dir, system_dir)
        self._pretend_systemd_host(monkeypatch)

        captured = self._run_update(mock_run, mock_args, capsys)

        assert "Legacy Hermes gateway" not in captured
        assert "migrate-legacy" not in captured

    @patch("shutil.which", return_value=None)
    @patch("subprocess.run")
    def test_update_does_not_flag_profile_units(
        self, mock_run, _mock_which, mock_args, capsys, tmp_path, monkeypatch,
    ):
        """Profile units (hermes-gateway-coder.service) must not trigger the warning.

        This is the core safety invariant: the legacy allowlist is
        ``hermes.service`` only, no globs.
        """
        user_dir, system_dir = self._make_scope_dirs(tmp_path)
        # Drop a profile unit that an over-eager glob would match, plus the
        # current default unit.
        for unit_name in ("hermes-gateway-coder.service", "hermes-gateway.service"):
            (user_dir / unit_name).write_text(self._OUR_UNIT_TEXT, encoding="utf-8")
        self._patch_search_paths(monkeypatch, user_dir, system_dir)
        self._pretend_systemd_host(monkeypatch)

        captured = self._run_update(mock_run, mock_args, capsys)

        assert "Legacy Hermes gateway" not in captured
        assert "hermes-gateway-coder.service" not in captured  # not flagged

    @patch("shutil.which", return_value=None)
    @patch("subprocess.run")
    def test_update_skips_legacy_check_on_non_systemd_platforms(
        self, mock_run, _mock_which, mock_args, capsys, tmp_path, monkeypatch,
    ):
        """macOS / Windows / Termux — skip check entirely since the rename
        is systemd-specific."""
        user_dir = tmp_path / "user"
        user_dir.mkdir()
        # Put a file that WOULD match if the check ran.
        (user_dir / "hermes.service").write_text(self._OUR_UNIT_TEXT, encoding="utf-8")
        # The system directory is deliberately not created in this scenario.
        self._patch_search_paths(monkeypatch, user_dir, tmp_path / "system")
        monkeypatch.setattr(gateway_cli, "is_macos", lambda: True)
        monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: False)

        captured = self._run_update(
            mock_run, mock_args, capsys, launchctl_loaded=False
        )

        # Must not print the warning on non-systemd platforms.
        assert "Legacy Hermes gateway" not in captured

    @patch("shutil.which", return_value=None)
    @patch("subprocess.run")
    def test_update_lists_system_scope_unit_with_sudo_hint(
        self, mock_run, _mock_which, mock_args, capsys, tmp_path, monkeypatch,
    ):
        """System-scope legacy units need sudo — the warning must point that out."""
        user_dir, system_dir = self._make_scope_dirs(tmp_path)
        (system_dir / "hermes.service").write_text(self._OUR_UNIT_TEXT, encoding="utf-8")
        self._patch_search_paths(monkeypatch, user_dir, system_dir)
        self._pretend_systemd_host(monkeypatch)

        captured = self._run_update(mock_run, mock_args, capsys)

        assert "Legacy Hermes gateway" in captured
        assert "(system scope)" in captured
        assert "sudo" in captured

View file

@ -31,6 +31,31 @@ def _isolate_env(tmp_path, monkeypatch):
monkeypatch.delenv("RETAINDB_PROJECT", raising=False)
@pytest.fixture(autouse=True)
def _cap_retaindb_sleeps(monkeypatch):
    """Limit every sleep made through ``retaindb.time`` to at most 0.05s.

    The retaindb ``_WriteQueue._flush_row`` sleeps for 2 seconds after an
    error, which adds up across the tests that hit the retry path. Those
    tests only care that the retry happens, not how long it waits.

    The plugin module's ``time`` attribute is swapped for a stand-in that
    exposes only ``sleep`` (capped) and ``time`` (the real function). The
    test file's own ``time.sleep`` goes through a different reference and
    stays real. When the plugin cannot be imported the fixture does nothing.
    """
    import types as _types

    try:
        from plugins.memory import retaindb as _retaindb
    except ImportError:
        # Plugin not importable in this environment: nothing to patch.
        return

    module_time = _retaindb.time
    uncapped_sleep = module_time.sleep

    def _capped_sleep(seconds):
        bounded = min(float(seconds), 0.05)
        return uncapped_sleep(bounded)

    stand_in = _types.SimpleNamespace(sleep=_capped_sleep, time=module_time.time)
    monkeypatch.setattr(_retaindb, "time", stand_in)
# We need the repo root on sys.path so the plugin can import agent.memory_provider
import sys
_repo_root = str(Path(__file__).resolve().parents[2])
@ -130,16 +155,18 @@ class TestWriteQueue:
def test_enqueue_creates_row(self, tmp_path):
q, client, db_path = self._make_queue(tmp_path)
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
# Give the writer thread a moment to process
time.sleep(1)
# shutdown() blocks until the writer thread drains the queue — no need
# to pre-sleep (the old 1s sleep was a just-in-case wait, but shutdown
# does the right thing).
q.shutdown()
# If ingest succeeded, the row should be deleted
client.ingest_session.assert_called_once()
def test_enqueue_persists_to_sqlite(self, tmp_path):
client = MagicMock()
# Make ingest hang so the row stays in SQLite
client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(5))
# Make ingest slow so the row is still in SQLite when we peek.
# 0.5s is plenty — the test just needs the flush to still be in-flight.
client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(0.5))
db_path = tmp_path / "test_queue.db"
q = _WriteQueue(client, db_path)
q.enqueue("user1", "sess1", [{"role": "user", "content": "test"}])
@ -154,8 +181,7 @@ class TestWriteQueue:
def test_flush_deletes_row_on_success(self, tmp_path):
q, client, db_path = self._make_queue(tmp_path)
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
time.sleep(1)
q.shutdown()
q.shutdown() # blocks until drain
# Row should be gone
conn = sqlite3.connect(str(db_path))
rows = conn.execute("SELECT COUNT(*) FROM pending").fetchone()[0]
@ -168,14 +194,20 @@ class TestWriteQueue:
db_path = tmp_path / "test_queue.db"
q = _WriteQueue(client, db_path)
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
time.sleep(3) # Allow retry + sleep(2) in _flush_row
# Poll for the error to be recorded (max 2s), instead of a fixed 3s wait.
deadline = time.time() + 2.0
last_error = None
while time.time() < deadline:
conn = sqlite3.connect(str(db_path))
row = conn.execute("SELECT last_error FROM pending").fetchone()
conn.close()
if row and row[0]:
last_error = row[0]
break
time.sleep(0.05)
q.shutdown()
# Row should still exist with error recorded
conn = sqlite3.connect(str(db_path))
row = conn.execute("SELECT last_error FROM pending").fetchone()
conn.close()
assert row is not None
assert "API down" in row[0]
assert last_error is not None
assert "API down" in last_error
def test_thread_local_connection_reuse(self, tmp_path):
q, _, _ = self._make_queue(tmp_path)
@ -193,14 +225,27 @@ class TestWriteQueue:
client1.ingest_session = MagicMock(side_effect=RuntimeError("fail"))
q1 = _WriteQueue(client1, db_path)
q1.enqueue("user1", "sess1", [{"role": "user", "content": "lost turn"}])
time.sleep(3)
# Wait until the error is recorded (poll with short interval).
deadline = time.time() + 2.0
while time.time() < deadline:
conn = sqlite3.connect(str(db_path))
row = conn.execute("SELECT last_error FROM pending").fetchone()
conn.close()
if row and row[0]:
break
time.sleep(0.05)
q1.shutdown()
# Now create a new queue — it should replay the pending rows
client2 = MagicMock()
client2.ingest_session = MagicMock(return_value={"status": "ok"})
q2 = _WriteQueue(client2, db_path)
time.sleep(2)
# Poll for the replay to happen.
deadline = time.time() + 2.0
while time.time() < deadline:
if client2.ingest_session.called:
break
time.sleep(0.05)
q2.shutdown()
# The replayed row should have been ingested via client2

View file

@ -0,0 +1,34 @@
"""Fast-path fixtures shared across tests/run_agent/.
Many tests in this directory exercise the retry/backoff paths in the
agent loop. Production code uses ``jittered_backoff(base_delay=5.0)``
with a ``while time.time() < sleep_end`` loop, so a single retry test
spends 5+ seconds of real wall-clock time on backoff waits.
Mocking ``jittered_backoff`` to return 0.0 collapses the while-loop
to a no-op (``time.time() < time.time() + 0`` is false immediately),
which handles the most common case without touching ``time.sleep``.
We deliberately DO NOT mock ``time.sleep`` here because some tests
(test_interrupt_propagation, test_primary_runtime_restore, etc.) use
the real ``time.sleep`` for threading coordination or assert that it
was called with specific values. Tests that want to additionally
fast-path direct ``time.sleep(N)`` calls in production code should
monkeypatch ``run_agent.time.sleep`` locally (see
``test_anthropic_error_handling.py`` for the pattern).
"""
from __future__ import annotations
import pytest
@pytest.fixture(autouse=True)
def _fast_retry_backoff(monkeypatch):
    """Make ``run_agent.jittered_backoff`` return 0.0 for tests in this directory.

    When ``run_agent`` cannot be imported the fixture is a no-op.
    """

    def _zero_delay(*_args, **_kwargs):
        return 0.0

    try:
        import run_agent
    except ImportError:
        return
    else:
        monkeypatch.setattr(run_agent, "jittered_backoff", _zero_delay)

View file

@ -32,6 +32,7 @@ class TestGeneric400Heuristic:
from run_agent import AIAgent
a = AIAgent(
api_key="test-key-12345",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,

View file

@ -19,6 +19,24 @@ import pytest
from agent.context_compressor import SUMMARY_PREFIX
from run_agent import AIAgent
import run_agent
# ---------------------------------------------------------------------------
# Fast backoff for compression retry tests
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _no_compression_sleep(monkeypatch):
    """Remove real waiting from the compression retry paths.

    Production code calls ``time.sleep(2)`` in several places after a
    413/context compression to smooth rate limits. These tests assert
    behavior, not timing, so ``time.sleep`` becomes a no-op and
    ``run_agent.jittered_backoff`` always returns 0.0.
    """
    import time as _time_module

    def _skip_sleep(*_args, **_kwargs):
        return None

    def _zero_delay(*_args, **_kwargs):
        return 0.0

    monkeypatch.setattr(_time_module, "sleep", _skip_sleep)
    monkeypatch.setattr(run_agent, "jittered_backoff", _zero_delay)
# ---------------------------------------------------------------------------
@ -69,6 +87,7 @@ def agent():
):
a = AIAgent(
api_key="test-key-1234567890",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,

View file

@ -29,6 +29,8 @@ class TestFlushDeduplication:
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
from run_agent import AIAgent
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
session_db=session_db,
@ -271,6 +273,8 @@ class TestFlushIdxInit:
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
from run_agent import AIAgent
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -283,6 +287,8 @@ class TestFlushIdxInit:
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
from run_agent import AIAgent
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,

View file

@ -27,6 +27,39 @@ from gateway.config import Platform
from gateway.session import SessionSource
# ---------------------------------------------------------------------------
# Fast backoff for tests that exercise the retry loop
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _no_backoff_wait(monkeypatch):
    """Short-circuit retry backoff so tests don't block on real wall-clock waits.

    The production code uses jittered_backoff() with a 5s base delay plus a
    tight time.sleep(0.2) loop.  Without this patch, each 429/500/529 retry
    test burns ~10s of real time on CI; across six tests that's ~60s for
    behavior we're not asserting against timing.

    Tests assert retry counts and final results, never wait durations.

    Three things are patched for the duration of each test:

    * ``run_agent.jittered_backoff`` -> always returns ``0.0``
    * ``time.sleep`` -> no-op
    * ``asyncio.sleep`` -> yields to the event loop once, without delay
    """
    import asyncio as _asyncio
    import time as _time

    monkeypatch.setattr(run_agent, "jittered_backoff", lambda *a, **k: 0.0)
    monkeypatch.setattr(_time, "sleep", lambda *_a, **_k: None)

    # Also fast-path asyncio.sleep — the gateway's _run_agent path has
    # several await asyncio.sleep(...) calls that add real wall-clock time.
    # Capture the real coroutine first so the stub can still yield control.
    _real_asyncio_sleep = _asyncio.sleep

    async def _fast_sleep(delay=0, result=None, *args, **kwargs):
        # Yield to the event loop but skip the actual delay.  ``result`` is
        # forwarded so callers relying on ``await asyncio.sleep(d, result=x)``
        # returning ``x`` keep working under the stub.
        return await _real_asyncio_sleep(0, result)

    monkeypatch.setattr(_asyncio, "sleep", _fast_sleep)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

View file

@ -37,6 +37,8 @@ class TestFlushAfterCompression:
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
from run_agent import AIAgent
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
session_db=session_db,

View file

@ -19,6 +19,8 @@ from run_agent import AIAgent
def test_create_openai_client_does_not_mutate_input_kwargs(mock_openai):
mock_openai.return_value = MagicMock()
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,

View file

@ -23,6 +23,8 @@ from run_agent import AIAgent
def _make_agent():
return AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,

View file

@ -13,6 +13,24 @@ from unittest.mock import MagicMock, patch, call
import pytest
@pytest.fixture(autouse=True)
def _mock_runtime_provider(monkeypatch):
    """Replace ``resolve_runtime_provider`` with a canned OpenRouter answer.

    Per the original author's note, ``run_job`` calls
    ``resolve_runtime_provider``, which can attempt real network
    auto-detection (~4s of socket timeouts in hermetic CI).  These tests do
    not care about provider resolution (the agent is mocked too), so the
    resolver is stubbed out.
    """
    import hermes_cli.runtime_provider as runtime_provider_module

    canned_runtime = {
        "provider": "openrouter",
        "api_key": "test-key",
        "base_url": "https://openrouter.ai/api/v1",
        "model": "test/model",
        "api_mode": "chat_completions",
    }

    def _fake_resolve(*args, **kwargs):
        # Hand out a fresh dict on every call so callers cannot share state.
        return dict(canned_runtime)

    monkeypatch.setattr(
        runtime_provider_module, "resolve_runtime_provider", _fake_resolve
    )
class TestCronJobCleanup:
"""cron/scheduler.py — end_session + close in the finally block."""

View file

@ -11,6 +11,16 @@ from unittest.mock import MagicMock, patch
import pytest
from run_agent import AIAgent
import run_agent
@pytest.fixture(autouse=True)
def _no_fallback_wait(monkeypatch):
    """Make fallback/recovery waits instantaneous.

    ``time.sleep`` becomes a no-op (so the ``min(3 + retry_count, 8)`` wait
    before a primary retry does not block) and
    ``run_agent.jittered_backoff`` is stubbed to return ``0.0``.
    """
    import time as time_module

    def _skip_sleep(*_args, **_kwargs):
        return None

    def _zero_backoff(*_args, **_kwargs):
        return 0.0

    monkeypatch.setattr(time_module, "sleep", _skip_sleep)
    monkeypatch.setattr(run_agent, "jittered_backoff", _zero_backoff)
def _make_tool_defs(*names: str) -> list:
@ -36,6 +46,7 @@ def _make_agent(fallback_model=None):
):
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,

View file

@ -45,6 +45,7 @@ def test_plugin_engine_gets_context_length_on_init():
agent = AIAgent(
api_key="test-key-1234567890",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
@ -75,6 +76,7 @@ def test_plugin_engine_update_model_args():
agent = AIAgent(
model="openrouter/auto",
api_key="test-key-1234567890",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,

View file

@ -19,6 +19,7 @@ def _make_agent(fallback_model=None):
):
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,

View file

@ -60,6 +60,9 @@ def _make_agent(monkeypatch, provider, api_mode="chat_completions", base_url="ht
)
if model:
kwargs["model"] = model
base_url="https://openrouter.ai/api/v1",
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
return AIAgent(**kwargs)
@ -248,6 +251,19 @@ class TestBuildApiKwargsChatCompletionsServiceTier:
assert "service_tier" not in kwargs
class TestBuildApiKwargsKimiFixedTemperature:
    """API kwargs built for the Kimi coding endpoint carry a fixed temperature."""

    def test_kimi_for_coding_forces_temperature_on_main_chat_path(self, monkeypatch):
        """``_build_api_kwargs`` yields ``temperature == 0.6`` for kimi-for-coding."""
        kimi_agent = _make_agent(
            monkeypatch,
            "kimi-coding",
            base_url="https://api.kimi.com/coding/v1",
            model="kimi-for-coding",
        )

        api_kwargs = kimi_agent._build_api_kwargs([{"role": "user", "content": "hi"}])

        assert api_kwargs["temperature"] == 0.6
class TestBuildApiKwargsAIGateway:
def test_uses_chat_completions_format(self, monkeypatch):
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1", model="gpt-4o")

View file

@ -55,6 +55,7 @@ def agent():
):
a = AIAgent(
api_key="test-key-1234567890",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
@ -76,6 +77,7 @@ def agent_with_memory_tool():
):
a = AIAgent(
api_key="test-k...7890",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
@ -112,12 +114,14 @@ def test_aiagent_reuses_existing_errors_log_handler():
):
AIAgent(
api_key="test-k...7890",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
)
AIAgent(
api_key="test-k...7890",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
@ -491,6 +495,7 @@ class TestInit:
):
a = AIAgent(
api_key="test-key-1234567890",
base_url="https://openrouter.ai/api/v1",
model="openai/gpt-4o",
quiet_mode=True,
skip_context_files=True,
@ -542,6 +547,7 @@ class TestInit:
):
a = AIAgent(
api_key="test-key-1234567890",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
@ -557,6 +563,7 @@ class TestInit:
):
a = AIAgent(
api_key="test-key-1234567890",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
@ -694,6 +701,7 @@ class TestBuildSystemPrompt:
):
agent = AIAgent(
api_key="test-k...7890",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
@ -726,6 +734,7 @@ class TestToolUseEnforcementConfig:
a = AIAgent(
model=model,
api_key="test-key-1234567890",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
@ -822,6 +831,7 @@ class TestToolUseEnforcementConfig:
):
a = AIAgent(
api_key="test-key-1234567890",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
@ -3433,7 +3443,7 @@ class TestAnthropicBaseUrlPassthrough:
):
mock_build.return_value = MagicMock()
a = AIAgent(
api_key="sk-ant-api03-test1234567890",
api_key="sk-ant...7890",
api_mode="anthropic_messages",
quiet_mode=True,
skip_context_files=True,
@ -3457,6 +3467,7 @@ class TestAnthropicCredentialRefresh:
mock_build.side_effect = [old_client, new_client]
agent = AIAgent(
api_key="sk-ant-oat01-stale-token",
base_url="https://openrouter.ai/api/v1",
api_mode="anthropic_messages",
quiet_mode=True,
skip_context_files=True,
@ -3487,6 +3498,7 @@ class TestAnthropicCredentialRefresh:
):
agent = AIAgent(
api_key="sk-ant-oat01-same-token",
base_url="https://openrouter.ai/api/v1",
api_mode="anthropic_messages",
quiet_mode=True,
skip_context_files=True,
@ -3514,6 +3526,7 @@ class TestAnthropicCredentialRefresh:
):
agent = AIAgent(
api_key="sk-ant-oat01-current-token",
base_url="https://openrouter.ai/api/v1",
api_mode="anthropic_messages",
quiet_mode=True,
skip_context_files=True,

View file

@ -12,6 +12,15 @@ sys.modules.setdefault("fal_client", types.SimpleNamespace())
import run_agent
@pytest.fixture(autouse=True)
def _no_codex_backoff(monkeypatch):
    """Keep Codex retry tests from waiting on real wall-clock time.

    ``run_agent.jittered_backoff`` (5s base delay in production, per the
    original note) is stubbed to return ``0.0`` and ``time.sleep`` is
    replaced with a no-op.
    """
    import time as time_module

    def _zero_backoff(*_args, **_kwargs):
        return 0.0

    def _skip_sleep(*_args, **_kwargs):
        return None

    monkeypatch.setattr(run_agent, "jittered_backoff", _zero_backoff)
    monkeypatch.setattr(time_module, "sleep", _skip_sleep)
def _patch_agent_bootstrap(monkeypatch):
monkeypatch.setattr(
run_agent,

View file

@ -80,6 +80,8 @@ class TestStreamingAccumulator:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -120,6 +122,8 @@ class TestStreamingAccumulator:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -167,6 +171,8 @@ class TestStreamingAccumulator:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -205,6 +211,8 @@ class TestStreamingAccumulator:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -245,6 +253,8 @@ class TestStreamingCallbacks:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -277,6 +287,8 @@ class TestStreamingCallbacks:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -308,6 +320,8 @@ class TestStreamingCallbacks:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -346,6 +360,8 @@ class TestStreamingCallbacks:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -381,6 +397,8 @@ class TestStreamingCallbacks:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -428,6 +446,8 @@ class TestStreamingFallback:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -455,6 +475,8 @@ class TestStreamingFallback:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -477,6 +499,8 @@ class TestStreamingFallback:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -500,6 +524,8 @@ class TestStreamingFallback:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -542,6 +568,8 @@ class TestStreamingFallback:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -577,6 +605,8 @@ class TestStreamingFallback:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -619,6 +649,8 @@ class TestReasoningStreaming:
mock_create.return_value = mock_client
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -646,6 +678,8 @@ class TestHasStreamConsumers:
def test_no_consumers(self):
from run_agent import AIAgent
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -656,6 +690,8 @@ class TestHasStreamConsumers:
def test_delta_callback_set(self):
from run_agent import AIAgent
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -667,6 +703,8 @@ class TestHasStreamConsumers:
def test_stream_callback_set(self):
from run_agent import AIAgent
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -688,6 +726,8 @@ class TestCodexStreamCallbacks:
deltas = []
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -729,6 +769,8 @@ class TestCodexStreamCallbacks:
from run_agent import AIAgent
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -792,6 +834,8 @@ class TestCodexStreamCallbacks:
)
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -810,6 +854,8 @@ class TestCodexStreamCallbacks:
from run_agent import AIAgent
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,
@ -861,6 +907,8 @@ class TestAnthropicStreamCallbacks:
from run_agent import AIAgent
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="test/model",
quiet_mode=True,
skip_context_files=True,

View file

@ -22,6 +22,7 @@ def _make_agent(session_db, *, platform: str):
):
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,

View file

@ -0,0 +1,28 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
def test_run_task_forces_kimi_fixed_temperature():
    """``MiniSWERunner.run_task`` sends ``temperature=0.6`` for kimi-for-coding.

    ``openai.OpenAI`` is patched so the runner talks to a ``MagicMock``
    client whose completion returns a single message with no tool calls.
    """
    canned_reply = SimpleNamespace(
        choices=[SimpleNamespace(message=SimpleNamespace(content="done", tool_calls=[]))]
    )

    with patch("openai.OpenAI") as mock_openai:
        client = MagicMock()
        client.chat.completions.create.return_value = canned_reply
        mock_openai.return_value = client

        # Imported while the patch is active, as in the original test.
        from mini_swe_runner import MiniSWERunner

        runner = MiniSWERunner(
            model="kimi-for-coding",
            base_url="https://api.kimi.com/coding/v1",
            api_key="test-key",
            env_type="local",
            max_iterations=1,
        )
        # Environment setup/teardown is irrelevant here; stub both hooks.
        runner._create_env = MagicMock()
        runner._cleanup_env = MagicMock()

        outcome = runner.run_task("2+2")

        assert outcome["completed"] is True
        sent_kwargs = client.chat.completions.create.call_args.kwargs
        assert sent_kwargs["temperature"] == 0.6

View file

@ -34,3 +34,21 @@ def test_messaging_extra_includes_qrcode_for_weixin_setup():
messaging_extra = optional_dependencies["messaging"]
assert any(dep.startswith("qrcode") for dep in messaging_extra)
def test_dingtalk_extra_includes_qrcode_for_qr_auth():
    """The ``dingtalk`` extra must list a ``qrcode`` dependency.

    Per the original note, DingTalk's QR-code device-flow auth
    (hermes_cli/dingtalk_auth.py) needs the qrcode package.
    """
    extras = _load_optional_dependencies()
    qrcode_deps = [dep for dep in extras["dingtalk"] if dep.startswith("qrcode")]
    assert qrcode_deps
def test_feishu_extra_includes_qrcode_for_qr_login():
    """The ``feishu`` extra must list a ``qrcode`` dependency.

    Per the original note, Feishu's QR login flow
    (gateway/platforms/feishu.py) needs the qrcode package.
    """
    extras = _load_optional_dependencies()
    qrcode_deps = [dep for dep in extras["feishu"] if dep.startswith("qrcode")]
    assert qrcode_deps

View file

@ -159,18 +159,34 @@ class TestCodeExecutionTZ:
return _json.dumps({"error": f"unexpected tool call: {function_name}"})
def test_tz_injected_when_configured(self):
"""When HERMES_TIMEZONE is set, child process sees TZ env var."""
"""When HERMES_TIMEZONE is set, child process sees TZ env var.
Verified alongside leak-prevention + empty-TZ handling in one
subprocess call so we don't pay 3x the subprocess startup cost
(each execute_code spawns a real Python subprocess ~3s).
"""
import json as _json
os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata"
# One subprocess, three things checked:
# 1) TZ is injected as "Asia/Kolkata"
# 2) HERMES_TIMEZONE itself does NOT leak into the child env
probe = (
'import os; '
'print("TZ=" + os.environ.get("TZ", "NOT_SET")); '
'print("HERMES_TIMEZONE=" + os.environ.get("HERMES_TIMEZONE", "NOT_SET"))'
)
with patch("model_tools.handle_function_call", side_effect=self._mock_handle):
result = _json.loads(self._execute_code(
code='import os; print(os.environ.get("TZ", "NOT_SET"))',
task_id="tz-test",
code=probe,
task_id="tz-combined-test",
enabled_tools=[],
))
assert result["status"] == "success"
assert "Asia/Kolkata" in result["output"]
assert "TZ=Asia/Kolkata" in result["output"]
assert "HERMES_TIMEZONE=NOT_SET" in result["output"], (
"HERMES_TIMEZONE should not leak into child env (only TZ)"
)
def test_tz_not_injected_when_empty(self):
"""When HERMES_TIMEZONE is not set, child process has no TZ."""
@ -186,20 +202,6 @@ class TestCodeExecutionTZ:
assert result["status"] == "success"
assert "NOT_SET" in result["output"]
def test_hermes_timezone_not_leaked_to_child(self):
"""HERMES_TIMEZONE itself must NOT appear in child env (only TZ)."""
import json as _json
os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata"
with patch("model_tools.handle_function_call", side_effect=self._mock_handle):
result = _json.loads(self._execute_code(
code='import os; print(os.environ.get("HERMES_TIMEZONE", "NOT_SET"))',
task_id="tz-leak-test",
enabled_tools=[],
))
assert result["status"] == "success"
assert "NOT_SET" in result["output"]
# =========================================================================
# Cron timezone-aware scheduling

View file

@ -31,6 +31,29 @@ def test_import_loads_env_from_hermes_home(tmp_path, monkeypatch):
assert os.getenv("OPENROUTER_API_KEY") == "from-hermes-home"
def test_generate_summary_custom_client_forces_kimi_temperature():
    """``_generate_summary`` overrides the configured temperature for Kimi.

    The config asks for ``temperature=0.3``; the call made on the custom
    client is expected to carry ``temperature=0.6`` instead.
    """
    # Build the compressor without running __init__, then wire in only the
    # attributes this test sets up.
    compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
    compressor.config = CompressionConfig(
        summarization_model="kimi-for-coding",
        temperature=0.3,
        summary_target_tokens=100,
        max_retries=1,
    )
    compressor.logger = MagicMock()
    compressor._use_call_llm = False

    fake_client = MagicMock()
    fake_client.chat.completions.create.return_value = SimpleNamespace(
        choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
    )
    compressor.client = fake_client

    summary = compressor._generate_summary("tool output", TrajectoryMetrics())

    assert summary.startswith("[CONTEXT SUMMARY]:")
    sent_kwargs = compressor.client.chat.completions.create.call_args.kwargs
    assert sent_kwargs["temperature"] == 0.6
# ---------------------------------------------------------------------------
# CompressionConfig
# ---------------------------------------------------------------------------

View file

@ -11,6 +11,7 @@ each asyncio.run() gets a client bound to the current loop.
"""
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
@ -113,3 +114,30 @@ class TestSourceLineVerification:
"""_get_async_client method should exist."""
src = self._read_file()
assert "def _get_async_client(self)" in src
@pytest.mark.asyncio
async def test_generate_summary_async_custom_client_forces_kimi_temperature():
    """``_generate_summary_async`` overrides the configured temperature for Kimi.

    The config asks for ``temperature=0.3``; the call made on the client
    returned by ``_get_async_client`` is expected to carry ``temperature=0.6``.
    """
    from trajectory_compressor import CompressionConfig, TrajectoryCompressor, TrajectoryMetrics

    # Build the compressor without running __init__, then wire in only the
    # attributes this test sets up.
    compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
    compressor.config = CompressionConfig(
        summarization_model="kimi-for-coding",
        temperature=0.3,
        summary_target_tokens=100,
        max_retries=1,
    )
    compressor.logger = MagicMock()
    compressor._use_call_llm = False

    canned_reply = SimpleNamespace(
        choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
    )
    async_client = MagicMock()
    async_client.chat.completions.create = MagicMock(return_value=canned_reply)
    compressor._get_async_client = MagicMock(return_value=async_client)

    summary = await compressor._generate_summary_async("tool output", TrajectoryMetrics())

    assert summary.startswith("[CONTEXT SUMMARY]:")
    sent_kwargs = async_client.chat.completions.create.call_args.kwargs
    assert sent_kwargs["temperature"] == 0.6

View file

@ -0,0 +1,440 @@
import json
import sys
import threading
import time
import types
from pathlib import Path
from unittest.mock import patch
from tui_gateway import server
class _ChunkyStdout:
def __init__(self):
self.parts: list[str] = []
def write(self, text: str) -> int:
for ch in text:
self.parts.append(ch)
time.sleep(0.0001)
return len(text)
def flush(self) -> None:
return None
class _BrokenStdout:
def write(self, text: str) -> int:
raise BrokenPipeError
def flush(self) -> None:
return None
def test_write_json_serializes_concurrent_writes(monkeypatch):
    """Eight threads calling ``write_json`` yield eight intact JSON lines.

    The slow character-by-character sink would expose interleaving: every
    captured line must parse as JSON and each ``seq`` must appear once.
    """
    sink = _ChunkyStdout()
    monkeypatch.setattr(server, "_real_stdout", sink)

    workers = []
    for seq in range(8):
        payload = {"seq": seq, "text": "x" * 24}
        workers.append(threading.Thread(target=server.write_json, args=(payload,)))

    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    emitted = "".join(sink.parts).splitlines()
    assert len(emitted) == 8

    seen_seqs = {json.loads(line)["seq"] for line in emitted}
    assert seen_seqs == set(range(8))
def test_write_json_returns_false_on_broken_pipe(monkeypatch):
    """``write_json`` returns ``False`` when the output stream raises BrokenPipeError."""
    broken_sink = _BrokenStdout()
    monkeypatch.setattr(server, "_real_stdout", broken_sink)

    outcome = server.write_json({"ok": True})

    assert outcome is False
def test_status_callback_emits_kind_and_text():
    """A two-argument status callback emits ``status.update`` with both fields."""
    expected_payload = {"kind": "context_pressure", "text": "85% to compaction"}

    with patch("tui_gateway.server._emit") as emit:
        callbacks = server._agent_cbs("sid")
        status_cb = callbacks["status_callback"]
        status_cb("context_pressure", "85% to compaction")

        emit.assert_called_once_with("status.update", "sid", expected_payload)
def test_status_callback_accepts_single_message_argument():
    """A one-argument status callback is emitted with ``kind == "status"``."""
    expected_payload = {"kind": "status", "text": "thinking..."}

    with patch("tui_gateway.server._emit") as emit:
        callbacks = server._agent_cbs("sid")
        status_cb = callbacks["status_callback"]
        status_cb("thinking...")

        emit.assert_called_once_with("status.update", "sid", expected_payload)
def _session(agent=None, **extra):
return {
"agent": agent if agent is not None else types.SimpleNamespace(),
"session_key": "session-key",
"history": [],
"history_lock": threading.Lock(),
"history_version": 0,
"running": False,
"attached_images": [],
"image_counter": 0,
"cols": 80,
"slash_worker": None,
"show_reasoning": False,
"tool_progress_mode": "all",
**extra,
}
def test_config_set_yolo_toggles_session_scope():
    """``config.set yolo`` flips the per-session yolo flag on, then off again."""
    from tools.approval import clear_session, is_session_yolo_enabled

    def _toggle(request_id):
        # Each call builds a fresh request so nothing is shared between them.
        return server.handle_request(
            {"id": request_id, "method": "config.set", "params": {"session_id": "sid", "key": "yolo"}}
        )

    server._sessions["sid"] = _session()
    try:
        first = _toggle("1")
        assert first["result"]["value"] == "1"
        assert is_session_yolo_enabled("session-key") is True

        second = _toggle("2")
        assert second["result"]["value"] == "0"
        assert is_session_yolo_enabled("session-key") is False
    finally:
        # Always undo approval state and the registered session.
        clear_session("session-key")
        server._sessions.clear()
def test_enable_gateway_prompts_sets_gateway_env(monkeypatch):
    """``_enable_gateway_prompts`` sets the three gateway env vars to ``"1"``.

    The variables are cleared first so the assertions reflect what the
    function itself wrote, and they are restored to their pre-test state at
    teardown so nothing leaks into later tests.
    """
    gateway_vars = ("HERMES_EXEC_ASK", "HERMES_GATEWAY_SESSION", "HERMES_INTERACTIVE")

    for name in gateway_vars:
        # ``delenv(..., raising=False)`` records nothing when the variable is
        # absent, so values written by the code under test would survive the
        # test.  Setting first makes monkeypatch remember the original state
        # (including "unset"); the delete then gives the test a clean slate.
        monkeypatch.setenv(name, "0")
        monkeypatch.delenv(name)

    server._enable_gateway_prompts()

    assert server.os.environ["HERMES_GATEWAY_SESSION"] == "1"
    assert server.os.environ["HERMES_EXEC_ASK"] == "1"
    assert server.os.environ["HERMES_INTERACTIVE"] == "1"
def test_setup_status_reports_provider_config(monkeypatch):
    """``setup.status`` reports ``provider_configured`` from the CLI helper."""

    def _nothing_configured():
        return False

    monkeypatch.setattr("hermes_cli.main._has_any_provider_configured", _nothing_configured)

    request = {"id": "1", "method": "setup.status", "params": {}}
    response = server.handle_request(request)

    assert response["result"]["provider_configured"] is False
def test_config_set_reasoning_updates_live_session_and_agent(tmp_path, monkeypatch):
    """``config.set reasoning`` updates the agent's effort and the session's show flag.

    ``"low"`` must land on ``agent.reasoning_config``; ``"show"`` must set
    ``show_reasoning`` on the session.
    """
    monkeypatch.setattr(server, "_hermes_home", tmp_path)
    agent = types.SimpleNamespace(reasoning_config=None)
    # setitem (not plain assignment) so the session is removed from the
    # module-global registry at teardown instead of leaking into other tests.
    monkeypatch.setitem(server._sessions, "sid", _session(agent=agent))

    resp_effort = server.handle_request(
        {"id": "1", "method": "config.set", "params": {"session_id": "sid", "key": "reasoning", "value": "low"}}
    )
    assert resp_effort["result"]["value"] == "low"
    assert agent.reasoning_config == {"enabled": True, "effort": "low"}

    resp_show = server.handle_request(
        {"id": "2", "method": "config.set", "params": {"session_id": "sid", "key": "reasoning", "value": "show"}}
    )
    assert resp_show["result"]["value"] == "show"
    assert server._sessions["sid"]["show_reasoning"] is True
def test_config_set_verbose_updates_session_mode_and_agent(tmp_path, monkeypatch):
    """``config.set verbose cycle`` moves a session from ``"all"`` to ``"verbose"``.

    Both the session's ``tool_progress_mode`` and the agent's
    ``verbose_logging`` flag must reflect the new mode.
    """
    monkeypatch.setattr(server, "_hermes_home", tmp_path)
    agent = types.SimpleNamespace(verbose_logging=False)
    # setitem (not plain assignment) so the session is removed from the
    # module-global registry at teardown instead of leaking into other tests.
    monkeypatch.setitem(server._sessions, "sid", _session(agent=agent))

    resp = server.handle_request(
        {"id": "1", "method": "config.set", "params": {"session_id": "sid", "key": "verbose", "value": "cycle"}}
    )

    assert resp["result"]["value"] == "verbose"
    assert server._sessions["sid"]["tool_progress_mode"] == "verbose"
    assert agent.verbose_logging is True
def test_config_set_model_uses_live_switch_path(monkeypatch):
    """``config.set model`` delegates to ``_apply_model_switch`` and relays its result.

    The stub records its arguments; the response must expose both the
    ``value`` and the ``warning`` it returned.
    """
    # setitem (not plain assignment) so the session is removed from the
    # module-global registry at teardown instead of leaking into other tests.
    monkeypatch.setitem(server._sessions, "sid", _session())
    seen = {}

    def _fake_apply(sid, session, raw):
        seen["args"] = (sid, session["session_key"], raw)
        return {"value": "new/model", "warning": "catalog unreachable"}

    monkeypatch.setattr(server, "_apply_model_switch", _fake_apply)

    resp = server.handle_request(
        {"id": "1", "method": "config.set", "params": {"session_id": "sid", "key": "model", "value": "new/model"}}
    )

    assert resp["result"]["value"] == "new/model"
    assert resp["result"]["warning"] == "catalog unreachable"
    assert seen["args"] == ("sid", "session-key", "new/model")
def test_config_set_model_global_persists(monkeypatch):
    """A ``--global`` model switch is flagged as global and written to config.

    ``hermes_cli.model_switch.switch_model`` and
    ``hermes_cli.config.save_config`` are replaced with recorders so the test
    can inspect what the handler passed along and what it tried to persist.
    """

    class _Agent:
        provider = "openrouter"
        model = "old/model"
        base_url = ""
        api_key = "sk-old"

        def switch_model(self, **kwargs):
            return None

    # Canned successful outcome returned by the stubbed switch_model.
    result = types.SimpleNamespace(
        success=True,
        new_model="anthropic/claude-sonnet-4.6",
        target_provider="anthropic",
        api_key="sk-new",
        base_url="https://api.anthropic.com",
        api_mode="anthropic_messages",
        warning_message="",
    )
    seen = {}
    saved = {}

    def _switch_model(**kwargs):
        seen.update(kwargs)
        return result

    # setitem (not plain assignment) so the session is removed from the
    # module-global registry at teardown instead of leaking into other tests.
    monkeypatch.setitem(server._sessions, "sid", _session(agent=_Agent()))
    monkeypatch.setattr("hermes_cli.model_switch.switch_model", _switch_model)
    monkeypatch.setattr(server, "_restart_slash_worker", lambda session: None)
    monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None)
    monkeypatch.setattr("hermes_cli.config.save_config", lambda cfg: saved.update(cfg))

    resp = server.handle_request(
        {"id": "1", "method": "config.set", "params": {"session_id": "sid", "key": "model", "value": "anthropic/claude-sonnet-4.6 --global"}}
    )

    assert resp["result"]["value"] == "anthropic/claude-sonnet-4.6"
    assert seen["is_global"] is True
    assert saved["model"]["default"] == "anthropic/claude-sonnet-4.6"
    assert saved["model"]["provider"] == "anthropic"
    assert saved["model"]["base_url"] == "https://api.anthropic.com"
def test_config_set_personality_rejects_unknown_name(monkeypatch):
    """An unlisted personality name produces an ``Unknown personality`` error."""

    def _only_helpful(cfg=None):
        return {"helpful": "You are helpful."}

    monkeypatch.setattr(server, "_available_personalities", _only_helpful)

    request = {"id": "1", "method": "config.set", "params": {"key": "personality", "value": "bogus"}}
    response = server.handle_request(request)

    assert "error" in response
    assert "Unknown personality" in response["error"]["message"]
def test_config_set_personality_resets_history_and_returns_info(monkeypatch):
    """Switching personality clears history, bumps its version, and emits session info.

    Agent construction, session-info rendering, the slash worker restart, and
    config writing are all stubbed; emitted events are captured in ``emits``.
    """
    session = _session(agent=types.SimpleNamespace(), history=[{"role": "user", "text": "hi"}], history_version=4)
    new_agent = types.SimpleNamespace(model="x")
    emits = []

    # setitem (not plain assignment) so the session is removed from the
    # module-global registry at teardown instead of leaking into other tests.
    monkeypatch.setitem(server._sessions, "sid", session)
    monkeypatch.setattr(server, "_available_personalities", lambda cfg=None: {"helpful": "You are helpful."})
    monkeypatch.setattr(server, "_make_agent", lambda sid, key, session_id=None: new_agent)
    monkeypatch.setattr(server, "_session_info", lambda agent: {"model": getattr(agent, "model", "?")})
    monkeypatch.setattr(server, "_restart_slash_worker", lambda session: None)
    monkeypatch.setattr(server, "_emit", lambda *args: emits.append(args))
    monkeypatch.setattr(server, "_write_config_key", lambda path, value: None)

    resp = server.handle_request(
        {"id": "1", "method": "config.set", "params": {"session_id": "sid", "key": "personality", "value": "helpful"}}
    )

    assert resp["result"]["history_reset"] is True
    assert resp["result"]["info"] == {"model": "x"}
    assert session["history"] == []
    assert session["history_version"] == 5
    assert ("session.info", "sid", {"model": "x"}) in emits
def test_session_compress_uses_compress_helper(monkeypatch):
    """``session.compress`` returns the helper's result and emits ``session.info`` once.

    ``_compress_session_history`` is stubbed to report two removed messages
    and a usage dict; the response must relay both.
    """
    agent = types.SimpleNamespace()
    # setitem (not plain assignment) so the session is removed from the
    # module-global registry at teardown instead of leaking into other tests.
    monkeypatch.setitem(server._sessions, "sid", _session(agent=agent))
    monkeypatch.setattr(server, "_compress_session_history", lambda session, focus_topic=None: (2, {"total": 42}))
    monkeypatch.setattr(server, "_session_info", lambda _agent: {"model": "x"})

    with patch("tui_gateway.server._emit") as emit:
        resp = server.handle_request({"id": "1", "method": "session.compress", "params": {"session_id": "sid"}})

        assert resp["result"]["removed"] == 2
        assert resp["result"]["usage"]["total"] == 42
        emit.assert_called_once_with("session.info", "sid", {"model": "x"})
def test_prompt_submit_sets_approval_session_key(monkeypatch):
    """While the agent runs, the approval session key equals the session's key.

    ``threading.Thread`` is replaced with a stand-in that runs the target
    inline, so the agent's ``run_conversation`` executes before
    ``handle_request`` returns and can record the active session key.
    """
    from tools.approval import get_current_session_key

    captured = {}

    class _Agent:
        def run_conversation(self, prompt, conversation_history=None, stream_callback=None):
            captured["session_key"] = get_current_session_key(default="")
            return {"final_response": "ok", "messages": [{"role": "assistant", "content": "ok"}]}

    class _ImmediateThread:
        def __init__(self, target=None, daemon=None):
            self._target = target

        def start(self):
            # Run synchronously instead of spawning a thread.
            self._target()

    # setitem (not plain assignment) so the session is removed from the
    # module-global registry at teardown instead of leaking into other tests.
    monkeypatch.setitem(server._sessions, "sid", _session(agent=_Agent()))
    monkeypatch.setattr(server.threading, "Thread", _ImmediateThread)
    monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None)
    monkeypatch.setattr(server, "make_stream_renderer", lambda cols: None)
    monkeypatch.setattr(server, "render_message", lambda raw, cols: None)

    resp = server.handle_request({"id": "1", "method": "prompt.submit", "params": {"session_id": "sid", "text": "ping"}})

    assert resp["result"]["status"] == "streaming"
    assert captured["session_key"] == "session-key"
def test_prompt_submit_expands_context_refs(monkeypatch):
    """The prompt handed to the agent is the output of context-reference preprocessing.

    Fake ``agent.context_references`` and ``agent.model_metadata`` modules are
    placed in ``sys.modules``; the fake preprocessor always yields
    ``"expanded prompt"``, which is what the agent must receive.
    """
    captured = {}

    class _Agent:
        model = "test/model"
        base_url = ""
        api_key = ""

        def run_conversation(self, prompt, conversation_history=None, stream_callback=None):
            captured["prompt"] = prompt
            return {"final_response": "ok", "messages": [{"role": "assistant", "content": "ok"}]}

    class _ImmediateThread:
        def __init__(self, target=None, daemon=None):
            self._target = target

        def start(self):
            # Run synchronously instead of spawning a thread.
            self._target()

    fake_ctx = types.ModuleType("agent.context_references")
    fake_ctx.preprocess_context_references = lambda message, **kwargs: types.SimpleNamespace(
        blocked=False, message="expanded prompt", warnings=[], references=[], injected_tokens=0
    )
    fake_meta = types.ModuleType("agent.model_metadata")
    fake_meta.get_model_context_length = lambda *args, **kwargs: 100000

    # setitem (not plain assignment) so the session is removed from the
    # module-global registry at teardown instead of leaking into other tests.
    monkeypatch.setitem(server._sessions, "sid", _session(agent=_Agent()))
    monkeypatch.setattr(server.threading, "Thread", _ImmediateThread)
    monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None)
    monkeypatch.setattr(server, "make_stream_renderer", lambda cols: None)
    monkeypatch.setattr(server, "render_message", lambda raw, cols: None)
    monkeypatch.setitem(sys.modules, "agent.context_references", fake_ctx)
    monkeypatch.setitem(sys.modules, "agent.model_metadata", fake_meta)

    server.handle_request({"id": "1", "method": "prompt.submit", "params": {"session_id": "sid", "text": "@diff"}})

    assert captured["prompt"] == "expanded prompt"
def test_image_attach_appends_local_image(monkeypatch):
    """image.attach reports the attachment and records one image on the session."""
    # Minimal stand-in for the ``cli`` module, exposing only the helpers stubbed here.
    cli_stub = types.ModuleType("cli")
    cli_stub._IMAGE_EXTENSIONS = {".png"}
    cli_stub._split_path_input = lambda raw: (raw, "")
    cli_stub._resolve_attachment_path = lambda raw: Path("/tmp/cat.png")
    monkeypatch.setitem(sys.modules, "cli", cli_stub)
    server._sessions["sid"] = _session()

    request = {"id": "1", "method": "image.attach", "params": {"session_id": "sid", "path": "/tmp/cat.png"}}
    result = server.handle_request(request)["result"]

    assert result["attached"] is True
    assert result["name"] == "cat.png"
    assert len(server._sessions["sid"]["attached_images"]) == 1
def test_command_dispatch_exec_nonzero_surfaces_error(monkeypatch):
    """A quick command exiting non-zero yields an error response carrying stderr."""
    quick_cfg = {"quick_commands": {"boom": {"type": "exec", "command": "boom"}}}
    failed_run = types.SimpleNamespace(returncode=1, stdout="", stderr="failed")
    monkeypatch.setattr(server, "_load_cfg", lambda: quick_cfg)
    monkeypatch.setattr(server.subprocess, "run", lambda *args, **kwargs: failed_run)

    resp = server.handle_request({"id": "1", "method": "command.dispatch", "params": {"name": "boom"}})

    assert "error" in resp
    assert "failed" in resp["error"]["message"]
def test_plugins_list_surfaces_loader_error(monkeypatch):
    """plugins.list returns an error response when the plugin manager lookup raises."""
    request = {"id": "1", "method": "plugins.list", "params": {}}
    with patch("hermes_cli.plugins.get_plugin_manager", side_effect=Exception("boom")):
        resp = server.handle_request(request)

    assert "error" in resp
    assert "boom" in resp["error"]["message"]
def test_complete_slash_surfaces_completer_error(monkeypatch):
    """complete.slash returns an error response when the completer cannot be built."""
    request = {"id": "1", "method": "complete.slash", "params": {"text": "/mo"}}
    with patch("hermes_cli.commands.SlashCommandCompleter", side_effect=Exception("no completer")):
        resp = server.handle_request(request)

    assert "error" in resp
    assert "no completer" in resp["error"]["message"]
def test_input_detect_drop_attaches_image(monkeypatch):
    """input.detect_drop on an image path reports a match and the placeholder text."""
    drop_info = {
        "path": Path("/tmp/cat.png"),
        "is_image": True,
        "remainder": "",
    }
    cli_stub = types.ModuleType("cli")
    # A new dict is returned per call, matching the original per-call literal.
    cli_stub._detect_file_drop = lambda raw: dict(drop_info)
    monkeypatch.setitem(sys.modules, "cli", cli_stub)
    server._sessions["sid"] = _session()

    request = {
        "id": "1",
        "method": "input.detect_drop",
        "params": {"session_id": "sid", "text": "/tmp/cat.png"},
    }
    result = server.handle_request(request)["result"]

    assert result["matched"] is True
    assert result["is_image"] is True
    assert result["text"] == "[User attached image: cat.png]"
def test_rollback_restore_resolves_number_and_file_path():
    """rollback.restore maps hash "2" to the second checkpoint and forwards file_path."""
    recorded = {}

    class _Mgr:
        """Checkpoint-manager stub that records the arguments passed to restore()."""

        enabled = True

        def list_checkpoints(self, cwd):
            return [{"hash": "aaa111"}, {"hash": "bbb222"}]

        def restore(self, cwd, target, file_path=None):
            recorded["args"] = (cwd, target, file_path)
            return {"success": True, "message": "done"}

    agent = types.SimpleNamespace(_checkpoint_mgr=_Mgr())
    server._sessions["sid"] = _session(agent=agent, history=[])

    request = {
        "id": "1",
        "method": "rollback.restore",
        "params": {"session_id": "sid", "hash": "2", "file_path": "src/app.tsx"},
    }
    resp = server.handle_request(request)

    assert resp["result"]["success"] is True
    _cwd, target, file_path = recorded["args"]
    assert target == "bbb222"
    assert file_path == "src/app.tsx"

View file

@ -0,0 +1,199 @@
"""Accretion caps for _read_tracker (file_tools) and _completion_consumed
(process_registry).
Both structures are process-lifetime singletons that previously grew
unbounded in long-running CLI / gateway sessions:
    file_tools._read_tracker[task_id]
        read_history (set) — one entry per unique (path, offset, limit)
        dedup (dict) — one entry per unique (path, offset, limit)
        read_timestamps (dict) — one entry per unique resolved path
    process_registry._completion_consumed (set) — one entry per session_id
        ever polled / waited / logged
None of these were ever trimmed. A 10k-read CLI session accumulated
roughly 1.5MB of tracker state; a gateway with high background-process
churn accumulated ~20B per session_id until the process exited.
These tests pin the new caps + prune hooks.
"""
import pytest
class TestReadTrackerCaps:
    """Pin the size caps that _cap_read_tracker_data applies to per-task read state."""

    def setup_method(self):
        from tools import file_tools

        # Start every test from an empty tracker.
        with file_tools._read_tracker_lock:
            file_tools._read_tracker.clear()

    def test_read_history_capped(self, monkeypatch):
        """read_history is trimmed down to _READ_HISTORY_CAP entries."""
        from tools import file_tools as ft

        monkeypatch.setattr(ft, "_READ_HISTORY_CAP", 10)
        state = {
            "last_key": None,
            "consecutive": 0,
            "read_history": {(f"/p{i}", 0, 500) for i in range(50)},
            "dedup": {},
            "read_timestamps": {},
        }

        ft._cap_read_tracker_data(state)

        assert len(state["read_history"]) == 10

    def test_dedup_capped_oldest_first(self, monkeypatch):
        """dedup is trimmed to _DEDUP_CAP, dropping the earliest-inserted keys."""
        from tools import file_tools as ft

        monkeypatch.setattr(ft, "_DEDUP_CAP", 5)
        state = {
            "read_history": set(),
            "dedup": {(f"/p{i}", 0, 500): float(i) for i in range(20)},
            "read_timestamps": {},
        }

        ft._cap_read_tracker_data(state)

        dedup = state["dedup"]
        assert len(dedup) == 5
        # The last five inserted keys (15-19) are kept.
        assert ("/p19", 0, 500) in dedup
        assert ("/p15", 0, 500) in dedup
        # Keys 0-14 are gone.
        assert ("/p0", 0, 500) not in dedup
        assert ("/p14", 0, 500) not in dedup

    def test_read_timestamps_capped_oldest_first(self, monkeypatch):
        """read_timestamps is trimmed to _READ_TIMESTAMPS_CAP, oldest first."""
        from tools import file_tools as ft

        monkeypatch.setattr(ft, "_READ_TIMESTAMPS_CAP", 3)
        state = {
            "read_history": set(),
            "dedup": {},
            "read_timestamps": {f"/path/{i}": float(i) for i in range(10)},
        }

        ft._cap_read_tracker_data(state)

        stamps = state["read_timestamps"]
        assert len(stamps) == 3
        assert "/path/9" in stamps
        assert "/path/7" in stamps
        assert "/path/0" not in stamps

    def test_cap_is_idempotent_under_cap(self, monkeypatch):
        """Containers already under their caps come back unchanged."""
        from tools import file_tools as ft

        for cap_name in ("_READ_HISTORY_CAP", "_DEDUP_CAP", "_READ_TIMESTAMPS_CAP"):
            monkeypatch.setattr(ft, cap_name, 100)
        state = {
            "read_history": {("/a", 0, 500), ("/b", 0, 500)},
            "dedup": {("/a", 0, 500): 1.0},
            "read_timestamps": {"/a": 1.0},
        }
        # Shallow snapshots taken before the call, compared afterwards.
        history_snapshot = state["read_history"].copy()
        dedup_snapshot = state["dedup"].copy()
        stamps_snapshot = state["read_timestamps"].copy()

        ft._cap_read_tracker_data(state)

        assert state["read_history"] == history_snapshot
        assert state["dedup"] == dedup_snapshot
        assert state["read_timestamps"] == stamps_snapshot

    def test_cap_handles_missing_containers(self):
        """Absent or None sub-containers must not raise."""
        from tools import file_tools as ft

        ft._cap_read_tracker_data({})  # no containers at all
        ft._cap_read_tracker_data({"read_history": None})
        ft._cap_read_tracker_data({"dedup": None})

    def test_live_cap_applied_after_read_add(self, tmp_path, monkeypatch):
        """Reads made through read_file_tool keep tracker containers within the caps."""
        from tools import file_tools as ft

        for cap_name in ("_READ_HISTORY_CAP", "_DEDUP_CAP", "_READ_TIMESTAMPS_CAP"):
            monkeypatch.setattr(ft, cap_name, 3)

        # Ten distinct files, each read exactly once under the same task id.
        for idx in range(10):
            target = tmp_path / f"file_{idx}.txt"
            target.write_text(f"content {idx}\n" * 10)
            ft.read_file_tool(path=str(target), task_id="long-session")

        with ft._read_tracker_lock:
            tracked = ft._read_tracker["long-session"]
            assert len(tracked["read_history"]) <= 3
            assert len(tracked["dedup"]) <= 3
            assert len(tracked["read_timestamps"]) <= 3
class TestCompletionConsumedPrune:
    """_prune_if_needed must keep _completion_consumed in step with tracked sessions."""

    def test_prune_drops_completion_entry_with_expired_session(self):
        """Pruning a TTL-expired finished session also drops its
        _completion_consumed entry."""
        import time

        from tools.process_registry import FINISHED_TTL_SECONDS, ProcessRegistry

        class _FakeSess:
            """Finished session whose started_at is older than the TTL."""

            def __init__(self, sid):
                self.id = sid
                self.started_at = time.time() - (FINISHED_TTL_SECONDS + 100)
                self.exited = True

        registry = ProcessRegistry()
        registry._finished["stale-1"] = _FakeSess("stale-1")
        registry._completion_consumed.add("stale-1")

        with registry._lock:
            registry._prune_if_needed()

        assert "stale-1" not in registry._finished
        assert "stale-1" not in registry._completion_consumed

    def test_prune_drops_completion_entry_for_lru_evicted(self):
        """Same contract when sessions are evicted for exceeding MAX_PROCESSES."""
        import time

        from tools import process_registry as pr

        class _FakeSess:
            def __init__(self, sid, started):
                self.id = sid
                self.started_at = started
                self.exited = True

        registry = pr.ProcessRegistry()
        # Overfill with recently finished sessions; sess-0 is the newest.
        base_time = time.time()
        for idx in range(pr.MAX_PROCESSES + 5):
            session_id = f"sess-{idx}"
            registry._finished[session_id] = _FakeSess(session_id, base_time - idx)
            registry._completion_consumed.add(session_id)

        with registry._lock:
            # _prune_if_needed removes one oldest finished per invocation;
            # call it enough times to trim back down.
            for _ in range(10):
                registry._prune_if_needed()

        # No consumed id may outlive its _running / _finished record.
        tracked_ids = registry._running.keys() | registry._finished.keys()
        assert (registry._completion_consumed - tracked_ids) == set()

    def test_prune_clears_dangling_completion_entries(self):
        """Entries with no backing session record are removed as well
        (belt-and-suspenders invariant)."""
        from tools.process_registry import ProcessRegistry

        registry = ProcessRegistry()
        # This id was never placed in _running or _finished.
        registry._completion_consumed.add("dangling-never-tracked")

        with registry._lock:
            registry._prune_if_needed()

        assert "dangling-never-tracked" not in registry._completion_consumed

View file

@ -2,11 +2,13 @@
import ast
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch as mock_patch
import tools.approval as approval_module
from tools.approval import (
_get_approval_mode,
_smart_approve,
approve_session,
detect_dangerous_command,
is_approved,
@ -26,6 +28,21 @@ class TestApprovalModeParsing:
assert _get_approval_mode() == "off"
class TestSmartApproval:
    """_smart_approve routes its decision through agent.auxiliary_client.call_llm."""

    def test_smart_approval_uses_call_llm(self):
        """An "APPROVE" reply maps to "approve"; call_llm gets the expected kwargs."""
        reply_message = SimpleNamespace(content="APPROVE")
        llm_reply = SimpleNamespace(choices=[SimpleNamespace(message=reply_message)])

        with mock_patch("agent.auxiliary_client.call_llm", return_value=llm_reply) as mock_call:
            verdict = _smart_approve("python -c \"print('hello')\"", "script execution via -c flag")

        assert verdict == "approve"
        mock_call.assert_called_once()
        call_kwargs = mock_call.call_args.kwargs
        assert call_kwargs["task"] == "approval"
        assert call_kwargs["temperature"] == 0
        assert call_kwargs["max_tokens"] == 16
class TestDetectDangerousRm:
def test_rm_rf_detected(self):
is_dangerous, key, desc = detect_dangerous_command("rm -rf /home/user")
@ -820,4 +837,3 @@ class TestChmodExecuteCombo:
dangerous, _, _ = detect_dangerous_command(cmd)
assert dangerous is False

View file

@ -77,3 +77,42 @@ class TestResolveCdpOverride:
"https://cdp.browser-use.example/session/json/version",
timeout=10,
)
class TestGetCdpOverride:
    """_get_cdp_override reads the CDP URL from the env var or from config."""

    def test_prefers_env_var_over_config(self, monkeypatch):
        """With BROWSER_CDP_URL set, the env value is resolved, not the config one."""
        import tools.browser_tool as browser_tool

        monkeypatch.setenv("BROWSER_CDP_URL", HTTP_URL)
        monkeypatch.setattr(
            browser_tool,
            "read_raw_config",
            lambda: {"browser": {"cdp_url": "http://config-host:9222"}},
            raising=False,
        )
        # Fake /json/version reply advertising the websocket endpoint.
        version_reply = Mock(
            **{
                "raise_for_status.return_value": None,
                "json.return_value": {"webSocketDebuggerUrl": WS_URL},
            }
        )

        with patch("tools.browser_tool.requests.get", return_value=version_reply) as mock_get:
            resolved = browser_tool._get_cdp_override()

        assert resolved == WS_URL
        mock_get.assert_called_once_with(VERSION_URL, timeout=10)

    def test_uses_config_browser_cdp_url_when_env_missing(self, monkeypatch):
        """Without the env var, browser.cdp_url from config is used."""
        import tools.browser_tool as browser_tool

        monkeypatch.delenv("BROWSER_CDP_URL", raising=False)
        version_reply = Mock(
            **{
                "raise_for_status.return_value": None,
                "json.return_value": {"webSocketDebuggerUrl": WS_URL},
            }
        )
        config = {"browser": {"cdp_url": HTTP_URL}}

        with patch("hermes_cli.config.read_raw_config", return_value=config):
            with patch("tools.browser_tool.requests.get", return_value=version_reply) as mock_get:
                resolved = browser_tool._get_cdp_override()

        assert resolved == WS_URL
        mock_get.assert_called_once_with(VERSION_URL, timeout=10)

View file

@ -28,12 +28,22 @@ def _isolate_sessions():
bt._active_sessions.update(orig)
def _make_socket_dir(tmpdir, session_name, pid=None):
"""Create a fake agent-browser socket directory with optional PID file."""
def _make_socket_dir(tmpdir, session_name, pid=None, owner_pid=None):
"""Create a fake agent-browser socket directory with optional PID files.
Args:
tmpdir: base temp directory
session_name: name like "h_abc1234567" or "cdp_abc1234567"
pid: daemon PID to write to <session>.pid (None = no file)
owner_pid: owning hermes PID to write to <session>.owner_pid
(None = no file; tests the legacy path)
"""
d = tmpdir / f"agent-browser-{session_name}"
d.mkdir()
if pid is not None:
(d / f"{session_name}.pid").write_text(str(pid))
if owner_pid is not None:
(d / f"{session_name}.owner_pid").write_text(str(owner_pid))
return d
@ -62,7 +72,10 @@ class TestReapOrphanedBrowserSessions:
assert not d.exists()
def test_orphaned_alive_daemon_is_killed(self, fake_tmpdir):
"""Alive daemon not tracked by _active_sessions gets SIGTERM."""
"""Alive daemon not tracked by _active_sessions gets SIGTERM (legacy path).
No owner_pid file => falls back to tracked_names check.
"""
from tools.browser_tool import _reap_orphaned_browser_sessions
d = _make_socket_dir(fake_tmpdir, "h_orphan12345", pid=12345)
@ -84,7 +97,7 @@ class TestReapOrphanedBrowserSessions:
assert (12345, signal.SIGTERM) in kill_calls
def test_tracked_session_is_not_reaped(self, fake_tmpdir):
"""Sessions tracked in _active_sessions are left alone."""
"""Sessions tracked in _active_sessions are left alone (legacy path)."""
import tools.browser_tool as bt
from tools.browser_tool import _reap_orphaned_browser_sessions
@ -156,3 +169,240 @@ class TestReapOrphanedBrowserSessions:
_reap_orphaned_browser_sessions()
assert not d.exists()
class TestOwnerPidCrossProcess:
    """Owner-PID based reaping that is safe across concurrent processes.

    The owner_pid file records which hermes process owns a daemon so that
    concurrent hermes processes don't reap each other's active browser
    sessions. Added to fix orphan accumulation from crashed processes.
    """

    def test_alive_owner_is_not_reaped_even_when_untracked(self, fake_tmpdir):
        """A daemon whose owner_pid is alive survives, even if untracked here.

        Core cross-process guarantee: Process B scanning while Process A is
        using a browser must not kill A's daemon.
        """
        from tools.browser_tool import _reap_orphaned_browser_sessions

        # Our own PID is used as the owner because it is guaranteed alive.
        socket_dir = _make_socket_dir(
            fake_tmpdir, "h_alive_owner", pid=12345, owner_pid=os.getpid()
        )
        recorded = []

        def fake_kill(pid, sig):
            # Every probe "succeeds" (owner and daemon both look alive) and
            # nothing is ever really signalled.
            recorded.append((pid, sig))

        with patch("os.kill", side_effect=fake_kill):
            _reap_orphaned_browser_sessions()

        # The daemon was never sent SIGTERM and its directory is untouched.
        assert (12345, signal.SIGTERM) not in recorded
        assert socket_dir.exists()

    def test_dead_owner_triggers_reap(self, fake_tmpdir):
        """A daemon whose owner_pid is dead gets reaped."""
        from tools.browser_tool import _reap_orphaned_browser_sessions

        dead_owner = 999999999  # almost certainly not a live PID
        daemon_pid = 12345
        socket_dir = _make_socket_dir(
            fake_tmpdir, "h_dead_owner1", pid=daemon_pid, owner_pid=dead_owner
        )
        recorded = []

        def fake_kill(pid, sig):
            recorded.append((pid, sig))
            if sig == 0 and pid == dead_owner:
                raise ProcessLookupError  # owner dead
            # Daemon liveness probe and SIGTERM are both no-ops here.

        with patch("os.kill", side_effect=fake_kill):
            _reap_orphaned_browser_sessions()

        # Owner probed (dead), daemon probed (alive), daemon terminated.
        for expected_call in (
            (dead_owner, 0),
            (daemon_pid, 0),
            (daemon_pid, signal.SIGTERM),
        ):
            assert expected_call in recorded
        # Directory cleaned up.
        assert not socket_dir.exists()

    def test_corrupt_owner_pid_falls_back_to_legacy(self, fake_tmpdir):
        """A corrupt owner_pid file falls back to the tracked_names check."""
        import tools.browser_tool as bt
        from tools.browser_tool import _reap_orphaned_browser_sessions

        session_name = "h_corrupt_own"
        socket_dir = _make_socket_dir(fake_tmpdir, session_name, pid=12345)
        # Garbage instead of a PID.
        (socket_dir / f"{session_name}.owner_pid").write_text("not-a-pid")
        # Tracked session, so the legacy fallback should leave it alone.
        bt._active_sessions["task"] = {"session_name": session_name}
        recorded = []

        def fake_kill(pid, sig):
            recorded.append((pid, sig))

        with patch("os.kill", side_effect=fake_kill):
            _reap_orphaned_browser_sessions()

        # Legacy path took over: tracked, so not reaped.
        assert (12345, signal.SIGTERM) not in recorded
        assert socket_dir.exists()

    def test_owner_pid_permission_error_treated_as_alive(self, fake_tmpdir):
        """PermissionError from os.kill(owner, 0) means the owner is alive.

        PermissionError means the PID exists but is owned by a different
        user, so we must not assume the owner is dead (that could kill
        someone else's daemon).
        """
        from tools.browser_tool import _reap_orphaned_browser_sessions

        socket_dir = _make_socket_dir(
            fake_tmpdir, "h_perm_owner1", pid=12345, owner_pid=22222
        )
        recorded = []

        def fake_kill(pid, sig):
            recorded.append((pid, sig))
            if (pid, sig) == (22222, 0):
                raise PermissionError("not our user")

        with patch("os.kill", side_effect=fake_kill):
            _reap_orphaned_browser_sessions()

        # The daemon must not have been signalled.
        assert (12345, signal.SIGTERM) not in recorded
        assert socket_dir.exists()

    def test_write_owner_pid_creates_file_with_current_pid(
        self, fake_tmpdir, monkeypatch
    ):
        """_write_owner_pid(dir, session) writes <session>.owner_pid with os.getpid()."""
        import tools.browser_tool as bt

        session_name = "h_ownertest01"
        socket_dir = fake_tmpdir / f"agent-browser-{session_name}"
        socket_dir.mkdir()

        bt._write_owner_pid(str(socket_dir), session_name)

        marker = socket_dir / f"{session_name}.owner_pid"
        assert marker.exists()
        assert marker.read_text().strip() == str(os.getpid())

    def test_write_owner_pid_is_idempotent(self, fake_tmpdir):
        """Writing twice still leaves exactly one owner_pid file."""
        import tools.browser_tool as bt

        session_name = "h_idempot1234"
        socket_dir = fake_tmpdir / f"agent-browser-{session_name}"
        socket_dir.mkdir()

        for _ in range(2):
            bt._write_owner_pid(str(socket_dir), session_name)

        markers = list(socket_dir.glob("*.owner_pid"))
        assert len(markers) == 1
        assert markers[0].read_text().strip() == str(os.getpid())

    def test_write_owner_pid_swallows_oserror(self, fake_tmpdir, monkeypatch):
        """OSError (e.g. permission denied) doesn't propagate — the reaper
        falls back to the legacy tracked_names heuristic in that case.
        """
        import tools.browser_tool as bt

        def failing_open(*a, **kw):
            raise OSError("permission denied")

        monkeypatch.setattr("builtins.open", failing_open)

        # Must not raise.
        bt._write_owner_pid(str(fake_tmpdir), "h_readonly123")

    def test_run_browser_command_calls_write_owner_pid(
        self, fake_tmpdir, monkeypatch
    ):
        """_run_browser_command wires _write_owner_pid after mkdir."""
        import tools.browser_tool as bt

        session_name = "h_wiringtest1"

        class _FakePopen:
            """Aborts as soon as a process would be spawned, i.e. after the owner_pid write."""

            def __init__(self, *a, **kw):
                raise RuntimeError("short-circuit after owner_pid")

        spy_calls = []
        real_write = bt._write_owner_pid

        def _spy(*a, **kw):
            spy_calls.append(a)
            real_write(*a, **kw)

        monkeypatch.setattr(bt.subprocess, "Popen", _FakePopen)
        monkeypatch.setattr(bt, "_find_agent_browser", lambda: "/bin/true")
        monkeypatch.setattr(
            bt, "_requires_real_termux_browser_install", lambda *a: False
        )
        monkeypatch.setattr(
            bt, "_get_session_info",
            lambda task_id: {"session_name": session_name},
        )
        monkeypatch.setattr(bt, "_write_owner_pid", _spy)

        with patch("tools.browser_tool._socket_safe_tmpdir", return_value=str(fake_tmpdir)):
            try:
                bt._run_browser_command(task_id="test_task", command="goto", args=[])
            except Exception:
                # The short-circuit (or any later failure) is irrelevant here;
                # only the recorded spy calls matter.
                pass

        assert spy_calls, "_run_browser_command must call _write_owner_pid"
        # Positional args are (socket_dir, session_name).
        first_call = spy_calls[0]
        socket_dir_arg, session_name_arg = first_call[0], first_call[1]
        assert session_name_arg == session_name
        assert session_name in socket_dir_arg
class TestEmergencyCleanupRunsReaper:
    """The exit-time cleanup sweeps orphans even when no session is active."""

    def test_emergency_cleanup_calls_reaper(self, fake_tmpdir, monkeypatch):
        """_emergency_cleanup_all_sessions must call _reap_orphaned_browser_sessions."""
        import tools.browser_tool as bt

        # Clear the done-flag so the cleanup body actually executes.
        monkeypatch.setattr(bt, "_cleanup_done", False)

        invocations = []
        real_reaper = bt._reap_orphaned_browser_sessions

        def _spy_reaper():
            invocations.append(True)
            real_reaper()

        monkeypatch.setattr(bt, "_reap_orphaned_browser_sessions", _spy_reaper)

        # No active sessions are registered; the reaper should run regardless.
        bt._emergency_cleanup_all_sessions()

        assert invocations, (
            "Reaper must run on exit even with no active sessions"
        )

View file

@ -250,6 +250,15 @@ class TestWslHasImage:
mock_run.return_value = MagicMock(stdout="False\n", returncode=0)
assert _wsl_has_image() is False
def test_falls_back_to_get_clipboard_image(self):
    """A "False" first probe is followed by a second probe whose "True" wins."""
    probe_results = [
        MagicMock(stdout="False\n", returncode=0),
        MagicMock(stdout="True\n", returncode=0),
    ]
    with patch("hermes_cli.clipboard.subprocess.run") as mock_run:
        mock_run.side_effect = probe_results
        assert _wsl_has_image() is True
        assert mock_run.call_count == 2
def test_powershell_not_found(self):
with patch("hermes_cli.clipboard.subprocess.run", side_effect=FileNotFoundError):
assert _wsl_has_image() is False
@ -269,6 +278,18 @@ class TestWslSave:
assert _wsl_save(dest) is True
assert dest.read_bytes() == FAKE_PNG
def test_falls_back_to_get_clipboard_extraction(self, tmp_path):
    """When the first extraction fails, the second one's base64 output is saved."""
    dest = tmp_path / "out.png"
    encoded_png = base64.b64encode(FAKE_PNG).decode()
    attempts = [
        MagicMock(stdout="", returncode=1),
        MagicMock(stdout=encoded_png + "\n", returncode=0),
    ]
    with patch("hermes_cli.clipboard.subprocess.run") as mock_run:
        mock_run.side_effect = attempts
        assert _wsl_save(dest) is True
        assert mock_run.call_count == 2
    assert dest.read_bytes() == FAKE_PNG
def test_no_image_returns_false(self, tmp_path):
dest = tmp_path / "out.png"
with patch("hermes_cli.clipboard.subprocess.run") as mock_run:
@ -528,6 +549,16 @@ class TestWindowsHasImage:
mock_run.return_value = MagicMock(stdout="False\n", returncode=0)
assert _windows_has_image() is False
def test_falls_back_to_get_clipboard_image(self):
    """A "False" first probe is followed by a second probe whose "True" wins."""
    probe_results = [
        MagicMock(stdout="False\n", returncode=0),
        MagicMock(stdout="True\n", returncode=0),
    ]
    with patch("hermes_cli.clipboard._get_ps_exe", return_value="powershell"), \
            patch("hermes_cli.clipboard.subprocess.run") as mock_run:
        mock_run.side_effect = probe_results
        assert _windows_has_image() is True
        assert mock_run.call_count == 2
def test_no_powershell_available(self):
with patch("hermes_cli.clipboard._get_ps_exe", return_value=None):
assert _windows_has_image() is False
@ -559,6 +590,20 @@ class TestWindowsSave:
assert _windows_save(dest) is True
assert dest.read_bytes() == FAKE_PNG
def test_falls_back_to_filedrop_image(self, tmp_path):
    """After two failed extractions, the third attempt's base64 output is saved."""
    dest = tmp_path / "out.png"
    encoded_png = base64.b64encode(FAKE_PNG).decode()
    attempts = [
        MagicMock(stdout="", returncode=1),
        MagicMock(stdout="", returncode=1),
        MagicMock(stdout=encoded_png + "\n", returncode=0),
    ]
    with patch("hermes_cli.clipboard._get_ps_exe", return_value="powershell"), \
            patch("hermes_cli.clipboard.subprocess.run") as mock_run:
        mock_run.side_effect = attempts
        assert _windows_save(dest) is True
        assert mock_run.call_count == 3
    assert dest.read_bytes() == FAKE_PNG
def test_no_image_returns_false(self, tmp_path):
dest = tmp_path / "out.png"
with patch("hermes_cli.clipboard._get_ps_exe", return_value="powershell"):
@ -734,6 +779,18 @@ class TestHasClipboardImage:
assert has_clipboard_image() is True
m.assert_called_once()
def test_wsl_falls_through_to_wayland_when_windows_path_empty(self):
    """WSLg often bridges images to wl-paste even when powershell.exe check fails."""
    with patch("hermes_cli.clipboard.sys") as mock_sys, \
            patch("hermes_cli.clipboard._is_wsl", return_value=True), \
            patch("hermes_cli.clipboard._wsl_has_image", return_value=False) as wsl, \
            patch.dict(os.environ, {"WAYLAND_DISPLAY": "wayland-0"}), \
            patch("hermes_cli.clipboard._wayland_has_image", return_value=True) as wl:
        # Present the platform as Linux before dispatching.
        mock_sys.platform = "linux"
        assert has_clipboard_image() is True
    # Both probes ran exactly once: WSL first (empty), then Wayland.
    wsl.assert_called_once()
    wl.assert_called_once()
def test_linux_wayland_dispatch(self):
with patch("hermes_cli.clipboard.sys") as mock_sys:
mock_sys.platform = "linux"

View file

@ -0,0 +1,62 @@
"""Tests for feishu_doc_tool and feishu_drive_tool — registration and schema validation."""
import importlib
import unittest
from tools.registry import registry
# Trigger tool discovery so feishu tools get registered
importlib.import_module("tools.feishu_doc_tool")
importlib.import_module("tools.feishu_drive_tool")
class TestFeishuToolRegistration(unittest.TestCase):
    """Check that the feishu tools are in the registry with well-formed schemas."""

    # tool name -> toolset it is expected to be registered under
    EXPECTED_TOOLS = {
        "feishu_doc_read": "feishu_doc",
        "feishu_drive_list_comments": "feishu_drive",
        "feishu_drive_list_comment_replies": "feishu_drive",
        "feishu_drive_reply_comment": "feishu_drive",
        "feishu_drive_add_comment": "feishu_drive",
    }

    def test_all_tools_registered(self):
        """Every expected tool has a registry entry under the right toolset."""
        for tool_name, expected_toolset in self.EXPECTED_TOOLS.items():
            entry = registry.get_entry(tool_name)
            self.assertIsNotNone(entry, f"{tool_name} not registered")
            self.assertEqual(entry.toolset, expected_toolset)

    def test_schemas_have_required_fields(self):
        """Each schema carries name, description and object-typed parameters."""
        for tool_name in self.EXPECTED_TOOLS:
            schema = registry.get_entry(tool_name).schema
            self.assertIn("name", schema)
            self.assertEqual(schema["name"], tool_name)
            self.assertIn("description", schema)
            self.assertIn("parameters", schema)
            parameters = schema["parameters"]
            self.assertIn("type", parameters)
            self.assertEqual(parameters["type"], "object")

    def test_handlers_are_callable(self):
        """Each registered handler is callable."""
        for tool_name in self.EXPECTED_TOOLS:
            handler = registry.get_entry(tool_name).handler
            self.assertTrue(callable(handler))

    def test_doc_read_schema_params(self):
        """feishu_doc_read declares a doc_token property."""
        schema = registry.get_entry("feishu_doc_read").schema
        properties = schema["parameters"].get("properties", {})
        self.assertIn("doc_token", properties)

    def test_drive_tools_require_file_token(self):
        """Every tool except feishu_doc_read declares file_token and file_type."""
        drive_tools = [name for name in self.EXPECTED_TOOLS if name != "feishu_doc_read"]
        for tool_name in drive_tools:
            schema = registry.get_entry(tool_name).schema
            properties = schema["parameters"].get("properties", {})
            self.assertIn("file_token", properties, f"{tool_name} missing file_token param")
            self.assertIn("file_type", properties, f"{tool_name} missing file_type param")
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,178 @@
"""Regression tests for cwd-staleness in ShellFileOperations.
The bug: ShellFileOperations captured the terminal env's cwd at __init__
time and used that stale value for every subsequent _exec() call. When
a user ran ``cd`` via the terminal tool, ``env.cwd`` updated but
``ops.cwd`` did not. Relative paths passed to patch/read/write/search
then targeted the wrong directory — typically the session's start dir
instead of the current working directory.
Observed symptom: patch_replace() returned ``success=True`` with a
plausible diff, but the user's ``git diff`` showed no change (because
the patch landed in a different directory's copy of the same file).
Fix: _exec() now prefers the LIVE ``env.cwd`` over the init-time
``self.cwd``. Explicit ``cwd`` arg to _exec still wins over both.
"""
from __future__ import annotations
import os
import tempfile
import pytest
from tools.file_operations import ShellFileOperations
class _FakeEnv:
"""Minimal terminal env that tracks cwd across execute() calls.
Matches the real ``BaseEnvironment`` contract: ``cwd`` attribute plus
an ``execute(command, cwd=...)`` method whose return dict carries
``output`` and ``returncode``. Commands are executed in a real
subdirectory so file system effects match production.
"""
def __init__(self, start_cwd: str):
self.cwd = start_cwd
self.calls: list[dict] = []
def execute(self, command: str, cwd: str = None, **kwargs) -> dict:
import subprocess
self.calls.append({"command": command, "cwd": cwd})
# Simulate cd by updating self.cwd (the real env does the same
# via _extract_cwd_from_output after a successful command)
if command.strip().startswith("cd "):
new = command.strip()[3:].strip()
self.cwd = new
return {"output": "", "returncode": 0}
# Actually run the command — handle stdin via subprocess
stdin_data = kwargs.get("stdin_data")
proc = subprocess.run(
["bash", "-c", command],
cwd=cwd or self.cwd,
input=stdin_data,
capture_output=True,
text=True,
)
return {
"output": proc.stdout + proc.stderr,
"returncode": proc.returncode,
}
class TestShellFileOpsCwdTracking:
"""_exec() must use live env.cwd, not the init-time cached cwd."""
def test_exec_follows_env_cwd_after_cd(self, tmp_path):
    """After a cd in the env, a relative read via _exec hits the new directory."""
    dir_a = tmp_path / "a"
    dir_b = tmp_path / "b"
    for directory, content in ((dir_a, "content-a\n"), (dir_b, "content-b\n")):
        directory.mkdir()
        (directory / "target.txt").write_text(content)

    env = _FakeEnv(start_cwd=str(dir_a))
    ops = ShellFileOperations(env, cwd=str(dir_a))
    assert ops.cwd == str(dir_a)  # init-time

    # The user runs `cd b` in the terminal: env.cwd moves, ops.cwd does not.
    env.execute(f"cd {dir_b}")
    assert env.cwd == str(dir_b)
    assert ops.cwd == str(dir_a), "ops.cwd is still init-time (fallback only)"

    # A relative path must now resolve under dir_b, not dir_a.
    result = ops._exec("cat target.txt")
    assert result.exit_code == 0
    assert "content-b" in result.stdout, (
        f"Expected dir_b content, got {result.stdout!r}. "
        "Stale ops.cwd leaked through — _exec must prefer env.cwd."
    )
def test_patch_replace_targets_live_cwd_not_init_cwd(self, tmp_path):
    """The reported bug: after cd, a relative patch must land in the live cwd."""
    main_dir = tmp_path / "main"
    worktree_dir = tmp_path / "worktree"
    # Both directories hold an identical copy of the file.
    for directory in (main_dir, worktree_dir):
        directory.mkdir()
        (directory / "t.txt").write_text("shared text\n")

    env = _FakeEnv(start_cwd=str(main_dir))
    ops = ShellFileOperations(env, cwd=str(main_dir))

    # The user cd's into the worktree.
    env.execute(f"cd {worktree_dir}")
    assert env.cwd == str(worktree_dir)

    # RELATIVE path: only the worktree copy may change.
    result = ops.patch_replace("t.txt", "shared text\n", "PATCHED\n")

    assert result.success is True
    assert (worktree_dir / "t.txt").read_text() == "PATCHED\n", (
        "patch must land in the live-cwd dir (worktree)"
    )
    assert (main_dir / "t.txt").read_text() == "shared text\n", (
        "patch must NOT land in the init-time dir (main)"
    )
def test_explicit_cwd_arg_still_wins(self, tmp_path):
"""An explicit cwd= arg to _exec must override both env.cwd and self.cwd."""
dir_a = tmp_path / "a"
dir_b = tmp_path / "b"
dir_c = tmp_path / "c"
for d in (dir_a, dir_b, dir_c):
d.mkdir()
(dir_a / "target.txt").write_text("from-a\n")
(dir_b / "target.txt").write_text("from-b\n")
(dir_c / "target.txt").write_text("from-c\n")
env = _FakeEnv(start_cwd=str(dir_a))
ops = ShellFileOperations(env, cwd=str(dir_a))
env.execute(f"cd {dir_b}")
# Explicit cwd=dir_c should win over env.cwd (dir_b) and self.cwd (dir_a)
result = ops._exec("cat target.txt", cwd=str(dir_c))
assert "from-c" in result.stdout
def test_env_without_cwd_attribute_falls_back_to_self_cwd(self, tmp_path):
"""Backends without a cwd attribute still work via init-time cwd."""
dir_a = tmp_path / "fixed"
dir_a.mkdir()
(dir_a / "target.txt").write_text("fixed-content\n")
class _NoCwdEnv:
def execute(self, command, cwd=None, **kwargs):
import subprocess
proc = subprocess.run(["bash", "-c", command], cwd=cwd,
capture_output=True, text=True)
return {"output": proc.stdout, "returncode": proc.returncode}
env = _NoCwdEnv()
ops = ShellFileOperations(env, cwd=str(dir_a))
result = ops._exec("cat target.txt")
assert result.exit_code == 0
assert "fixed-content" in result.stdout
def test_patch_returns_success_only_when_file_actually_written(self, tmp_path):
"""Safety rail: patch_replace success must reflect the real file state.
This test doesn't trigger the bug directly (it would require manual
corruption of the write), but it pins the invariant: when
patch_replace returns success=True, the file on disk matches the
intended content. If a future write_file change ever regresses,
this test catches it.
"""
target = tmp_path / "file.txt"
target.write_text("old content\n")
env = _FakeEnv(start_cwd=str(tmp_path))
ops = ShellFileOperations(env, cwd=str(tmp_path))
result = ops.patch_replace(str(target), "old content\n", "new content\n")
assert result.success is True
assert result.error is None
assert target.read_text() == "new content\n", (
"patch_replace claimed success but file wasn't written correctly"
)

View file

@ -86,6 +86,7 @@ class TestProviderEnvBlocklist:
"MINIMAX_API_KEY": "mm-key",
"MINIMAX_CN_API_KEY": "mmcn-key",
"DEEPSEEK_API_KEY": "deepseek-key",
"NVIDIA_API_KEY": "nvidia-key",
}
result_env = _run_with_env(extra_os_env=registry_vars)

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
from gateway.config import Platform
from tools.send_message_tool import (
_derive_forum_thread_name,
_parse_target_ref,
_send_discord,
_send_matrix_via_adapter,
@ -1234,3 +1235,419 @@ class TestSendMatrixUrlEncoding:
put_url = mock_session.put.call_args[0][0]
assert "%21HLOQwxYGgFPMPJUSNR%3Amatrix.org" in put_url
assert "!HLOQwxYGgFPMPJUSNR:matrix.org" not in put_url
# ---------------------------------------------------------------------------
# Tests for _derive_forum_thread_name
# ---------------------------------------------------------------------------
class TestDeriveForumThreadName:
    """Expected outputs of ``_derive_forum_thread_name`` for representative messages."""

    def test_single_line_message(self):
        name = _derive_forum_thread_name("Hello world")
        assert name == "Hello world"

    def test_multi_line_uses_first_line(self):
        name = _derive_forum_thread_name("First line\nSecond line")
        assert name == "First line"

    def test_strips_markdown_heading(self):
        name = _derive_forum_thread_name("## My Heading")
        assert name == "My Heading"

    def test_strips_multiple_hash_levels(self):
        name = _derive_forum_thread_name("### Deep heading")
        assert name == "Deep heading"

    def test_empty_message_falls_back_to_default(self):
        name = _derive_forum_thread_name("")
        assert name == "New Post"

    def test_whitespace_only_falls_back(self):
        name = _derive_forum_thread_name(" \n ")
        assert name == "New Post"

    def test_hash_only_falls_back(self):
        name = _derive_forum_thread_name("###")
        assert name == "New Post"

    def test_truncates_to_100_chars(self):
        name = _derive_forum_thread_name("A" * 200)
        assert len(name) == 100

    def test_strips_whitespace_around_first_line(self):
        name = _derive_forum_thread_name(" Title \nBody")
        assert name == "Title"
# ---------------------------------------------------------------------------
# Tests for _send_discord with forum channel support
# ---------------------------------------------------------------------------
class TestSendDiscordForum:
    """Forum-channel handling in ``_send_discord``: directory lookup, API probe, thread creation."""

    @staticmethod
    def _build_mock(response_status, response_data=None, response_text="error body"):
        """Return ``(session, response)`` mocks usable as async context managers.

        Both ``session.post`` and ``session.get`` hand back the same response.
        """
        response = MagicMock()
        response.__aenter__ = AsyncMock(return_value=response)
        response.__aexit__ = AsyncMock(return_value=None)
        response.status = response_status
        response.json = AsyncMock(return_value=response_data or {})
        response.text = AsyncMock(return_value=response_text)

        session = MagicMock()
        session.__aenter__ = AsyncMock(return_value=session)
        session.__aexit__ = AsyncMock(return_value=None)
        session.post = MagicMock(return_value=response)
        session.get = MagicMock(return_value=response)
        return session, response

    def test_directory_forum_creates_thread(self):
        """A 'forum' directory entry results in a thread post."""
        payload = {"id": "t123", "message": {"id": "m456"}}
        mock_session, _ = self._build_mock(200, response_data=payload)

        session_patch = patch("aiohttp.ClientSession", return_value=mock_session)
        lookup_patch = patch("gateway.channel_directory.lookup_channel_type", return_value="forum")
        with session_patch, lookup_patch:
            result = asyncio.run(_send_discord("tok", "forum_ch", "Hello forum"))

        assert result["success"] is True
        assert result["thread_id"] == "t123"
        assert result["message_id"] == "m456"

        # The POST must hit the threads endpoint rather than messages.
        posted_url = mock_session.post.call_args.args[0]
        assert "/threads" in posted_url
        assert "/messages" not in posted_url

    def test_directory_forum_skips_probe(self):
        """No GET probe is issued once the directory already says 'forum'."""
        payload = {"id": "t123", "message": {"id": "m456"}}
        mock_session, _ = self._build_mock(200, response_data=payload)

        session_patch = patch("aiohttp.ClientSession", return_value=mock_session)
        lookup_patch = patch("gateway.channel_directory.lookup_channel_type", return_value="forum")
        with session_patch, lookup_patch:
            asyncio.run(_send_discord("tok", "forum_ch", "Hello"))

        # The directory settled the channel type, so get() stays untouched.
        mock_session.get.assert_not_called()

    def test_directory_channel_skips_forum(self):
        """A 'channel' directory entry goes through the regular messages endpoint."""
        mock_session, _ = self._build_mock(200, response_data={"id": "msg1"})

        session_patch = patch("aiohttp.ClientSession", return_value=mock_session)
        lookup_patch = patch("gateway.channel_directory.lookup_channel_type", return_value="channel")
        with session_patch, lookup_patch:
            result = asyncio.run(_send_discord("tok", "ch1", "Hello"))

        assert result["success"] is True
        posted_url = mock_session.post.call_args.args[0]
        assert "/messages" in posted_url
        assert "/threads" not in posted_url

    def test_directory_none_probes_and_detects_forum(self):
        """With no directory entry, the GET probe reports type 15 and a thread is created."""

        def _async_cm(obj):
            # Make a mock usable in `async with`, yielding itself.
            obj.__aenter__ = AsyncMock(return_value=obj)
            obj.__aexit__ = AsyncMock(return_value=None)
            return obj

        probe_resp = MagicMock()
        probe_resp.status = 200
        probe_resp.json = AsyncMock(return_value={"type": 15})
        _async_cm(probe_resp)

        thread_resp = MagicMock()
        thread_resp.status = 200
        thread_resp.json = AsyncMock(return_value={"id": "t999", "message": {"id": "m888"}})
        thread_resp.text = AsyncMock(return_value="")
        _async_cm(thread_resp)

        probe_session = _async_cm(MagicMock())
        probe_session.get = MagicMock(return_value=probe_resp)

        thread_session = _async_cm(MagicMock())
        thread_session.post = MagicMock(return_value=thread_resp)

        # Sessions are handed out in order: first the probe, then the thread creation.
        pending_sessions = iter([probe_session, thread_session])
        session_patch = patch(
            "aiohttp.ClientSession", side_effect=lambda **kw: next(pending_sessions)
        )
        lookup_patch = patch("gateway.channel_directory.lookup_channel_type", return_value=None)
        with session_patch, lookup_patch:
            result = asyncio.run(_send_discord("tok", "forum_ch", "Hello probe"))

        assert result["success"] is True
        assert result["thread_id"] == "t999"

    def test_directory_lookup_exception_falls_through_to_probe(self):
        """An exception from ``lookup_channel_type`` leads to the API probe."""
        mock_session, _ = self._build_mock(200, response_data={"id": "msg1"})

        session_patch = patch("aiohttp.ClientSession", return_value=mock_session)
        lookup_patch = patch(
            "gateway.channel_directory.lookup_channel_type",
            side_effect=Exception("io error"),
        )
        with session_patch, lookup_patch:
            result = asyncio.run(_send_discord("tok", "ch1", "Hello"))

        assert result["success"] is True
        # The probe is the GET request.
        mock_session.get.assert_called_once()

    def test_forum_thread_creation_error(self):
        """A 403 from thread creation is reported through an ``error`` entry."""
        mock_session, _ = self._build_mock(403, response_text="Forbidden")

        session_patch = patch("aiohttp.ClientSession", return_value=mock_session)
        lookup_patch = patch("gateway.channel_directory.lookup_channel_type", return_value="forum")
        with session_patch, lookup_patch:
            result = asyncio.run(_send_discord("tok", "forum_ch", "Hello"))

        assert "error" in result
        assert "403" in result["error"]
class TestSendToPlatformDiscordForum:
    """``_send_to_platform`` hands Discord sends to ``_send_discord`` unchanged."""

    def test_send_to_platform_discord_delegates_to_send_discord(self):
        """A Discord send is forwarded to ``_send_discord`` with default media/thread arguments."""
        send_mock = AsyncMock(return_value={"success": True, "message_id": "1"})
        platform_cfg = SimpleNamespace(enabled=True, token="tok", extra={})

        with patch("tools.send_message_tool._send_discord", send_mock):
            coro = _send_to_platform(Platform.DISCORD, platform_cfg, "forum_ch", "Hello forum")
            result = asyncio.run(coro)

        assert result["success"] is True
        send_mock.assert_awaited_once_with(
            "tok",
            "forum_ch",
            "Hello forum",
            media_files=[],
            thread_id=None,
        )

    def test_send_to_platform_discord_with_thread_id(self):
        """An explicit thread ID reaches ``_send_discord`` as a keyword argument."""
        send_mock = AsyncMock(return_value={"success": True, "message_id": "1"})
        platform_cfg = SimpleNamespace(enabled=True, token="tok", extra={})

        with patch("tools.send_message_tool._send_discord", send_mock):
            coro = _send_to_platform(
                Platform.DISCORD,
                platform_cfg,
                "ch1",
                "Hello thread",
                thread_id="17585",
            )
            result = asyncio.run(coro)

        assert result["success"] is True
        forwarded_kwargs = send_mock.await_args.kwargs
        assert forwarded_kwargs["thread_id"] == "17585"
# ---------------------------------------------------------------------------
# Tests for _send_discord forum + media multipart upload
# ---------------------------------------------------------------------------
class TestSendDiscordForumMedia:
    """_send_discord uploads media as part of the starter message when the target is a forum."""

    @staticmethod
    def _build_thread_resp(thread_id="th_999", msg_id="msg_500"):
        """Return a mocked 201 thread-creation response usable in ``async with``."""
        resp = MagicMock()
        resp.status = 201
        resp.json = AsyncMock(return_value={"id": thread_id, "message": {"id": msg_id}})
        resp.text = AsyncMock(return_value="")
        resp.__aenter__ = AsyncMock(return_value=resp)
        resp.__aexit__ = AsyncMock(return_value=None)
        return resp

    def test_forum_with_media_uses_multipart(self, tmp_path, monkeypatch):
        """Forum + media → single multipart POST to /threads carrying the starter + files."""
        from tools import send_message_tool as smt

        img = tmp_path / "photo.png"
        img.write_bytes(b"\x89PNGbytes")

        # Patch both the tool module attribute (raising=False: it may not exist there)
        # and the directory module so either lookup path reports a forum.
        monkeypatch.setattr(smt, "lookup_channel_type", lambda p, cid: "forum", raising=False)
        monkeypatch.setattr(
            "gateway.channel_directory.lookup_channel_type", lambda p, cid: "forum"
        )

        thread_resp = self._build_thread_resp()
        session = MagicMock()
        session.__aenter__ = AsyncMock(return_value=session)
        session.__aexit__ = AsyncMock(return_value=None)

        # Record every POST so the URL and body kind can be inspected afterwards.
        post_calls = []

        def track_post(url, **kwargs):
            post_calls.append({"url": url, "kwargs": kwargs})
            return thread_resp

        session.post = MagicMock(side_effect=track_post)

        with patch("aiohttp.ClientSession", return_value=session):
            result = asyncio.run(
                _send_discord("tok", "forum_ch", "Thread title\nbody", media_files=[(str(img), False)])
            )

        assert result["success"] is True
        assert result["thread_id"] == "th_999"
        assert result["message_id"] == "msg_500"
        # Exactly one POST — the combined thread-creation + attachments call
        assert len(post_calls) == 1
        assert post_calls[0]["url"].endswith("/threads")
        # Multipart form, not JSON
        assert post_calls[0]["kwargs"].get("data") is not None
        assert post_calls[0]["kwargs"].get("json") is None

    def test_forum_without_media_still_json_only(self, tmp_path, monkeypatch):
        """Forum + no media → JSON POST (no multipart overhead)."""
        monkeypatch.setattr(
            "gateway.channel_directory.lookup_channel_type", lambda p, cid: "forum"
        )

        thread_resp = self._build_thread_resp("t1", "m1")
        session = MagicMock()
        session.__aenter__ = AsyncMock(return_value=session)
        session.__aexit__ = AsyncMock(return_value=None)

        post_calls = []

        def track_post(url, **kwargs):
            post_calls.append({"url": url, "kwargs": kwargs})
            return thread_resp

        session.post = MagicMock(side_effect=track_post)

        with patch("aiohttp.ClientSession", return_value=session):
            result = asyncio.run(_send_discord("tok", "forum_ch", "Hello forum"))

        assert result["success"] is True
        assert len(post_calls) == 1
        # JSON path, no multipart
        assert post_calls[0]["kwargs"].get("json") is not None
        assert post_calls[0]["kwargs"].get("data") is None

    def test_forum_missing_media_file_collected_as_warning(self, tmp_path, monkeypatch):
        """Missing media files produce warnings but the thread is still created."""
        monkeypatch.setattr(
            "gateway.channel_directory.lookup_channel_type", lambda p, cid: "forum"
        )

        thread_resp = self._build_thread_resp()
        session = MagicMock()
        session.__aenter__ = AsyncMock(return_value=session)
        session.__aexit__ = AsyncMock(return_value=None)
        session.post = MagicMock(return_value=thread_resp)

        with patch("aiohttp.ClientSession", return_value=session):
            result = asyncio.run(
                _send_discord(
                    "tok", "forum_ch", "hi",
                    media_files=[("/nonexistent/does-not-exist.png", False)],
                )
            )

        assert result["success"] is True
        assert "warnings" in result
        assert any("not found" in w for w in result["warnings"])
# ---------------------------------------------------------------------------
# Tests for the process-local forum-probe cache
# ---------------------------------------------------------------------------
class TestForumProbeCache:
    """_DISCORD_CHANNEL_TYPE_PROBE_CACHE memoizes forum detection results."""

    def setup_method(self):
        """Empty the module-level probe cache so each test starts from a clean state."""
        from tools import send_message_tool as smt

        smt._DISCORD_CHANNEL_TYPE_PROBE_CACHE.clear()

    def test_cache_round_trip(self):
        """Unknown channels read back as None; remembered values can be overwritten."""
        from tools.send_message_tool import (
            _probe_is_forum_cached,
            _remember_channel_is_forum,
        )

        assert _probe_is_forum_cached("xyz") is None
        _remember_channel_is_forum("xyz", True)
        assert _probe_is_forum_cached("xyz") is True
        _remember_channel_is_forum("xyz", False)
        assert _probe_is_forum_cached("xyz") is False

    def test_probe_result_is_memoized(self, monkeypatch):
        """An API-probed channel type is cached so subsequent sends skip the probe."""
        from tools import send_message_tool as smt

        monkeypatch.setattr(
            "gateway.channel_directory.lookup_channel_type", lambda p, cid: None
        )

        # Probe response: type=15 (forum)
        probe_resp = MagicMock()
        probe_resp.status = 200
        probe_resp.json = AsyncMock(return_value={"type": 15})
        probe_resp.__aenter__ = AsyncMock(return_value=probe_resp)
        probe_resp.__aexit__ = AsyncMock(return_value=None)

        thread_resp = MagicMock()
        thread_resp.status = 201
        thread_resp.json = AsyncMock(return_value={"id": "t1", "message": {"id": "m1"}})
        thread_resp.__aenter__ = AsyncMock(return_value=thread_resp)
        thread_resp.__aexit__ = AsyncMock(return_value=None)

        probe_session = MagicMock()
        probe_session.__aenter__ = AsyncMock(return_value=probe_session)
        probe_session.__aexit__ = AsyncMock(return_value=None)
        probe_session.get = MagicMock(return_value=probe_resp)

        thread_session = MagicMock()
        thread_session.__aenter__ = AsyncMock(return_value=thread_session)
        thread_session.__aexit__ = AsyncMock(return_value=None)
        thread_session.post = MagicMock(return_value=thread_resp)

        sessions_created = []

        def session_factory(**kwargs):
            # The first send opens a probe session and then a thread session,
            # so alternate between the two mocks in that order.
            idx = len(sessions_created)
            sessions_created.append(idx)
            if idx % 2 == 0:
                return probe_session
            return thread_session

        # First send: probe + thread creation; the probe result must be cached.
        with patch("aiohttp.ClientSession", side_effect=session_factory):
            result1 = asyncio.run(_send_discord("tok", "ch1", "first"))
        assert result1["success"] is True
        assert smt._probe_is_forum_cached("ch1") is True
        probe_session.get.assert_called_once()

        # Second send: the cache answers, so only the thread session is handed out.
        with patch("aiohttp.ClientSession", return_value=thread_session):
            result2 = asyncio.run(_send_discord("tok", "ch1", "second"))
        assert result2["success"] is True

        # Assert the skip explicitly: no additional GET probe on either session.
        probe_session.get.assert_called_once()
        thread_session.get.assert_not_called()

View file

View file

@ -0,0 +1,233 @@
"""Tests for tui_gateway JSON-RPC protocol plumbing."""
import io
import json
import sys
import threading
from unittest.mock import MagicMock, patch
import pytest
# Snapshot of sys.stdout taken at import time; the autouse _restore_stdout fixture
# reassigns sys.stdout to this value after every test.
_original_stdout = sys.stdout
@pytest.fixture(autouse=True)
def _restore_stdout():
    """Autouse teardown: point ``sys.stdout`` back at the stream saved at import time."""
    yield
    # Runs after the test body has finished.
    sys.stdout = _original_stdout
@pytest.fixture()
def server():
    """Yield ``tui_gateway.server`` imported while several hermes modules are stubbed out.

    After the test, the module's session/pending/answer/method registries are
    emptied and the module is reloaded.
    """
    import importlib

    stubbed_modules = {
        "hermes_constants": MagicMock(get_hermes_home=MagicMock(return_value="/tmp/hermes_test")),
        "hermes_cli.env_loader": MagicMock(),
        "hermes_cli.banner": MagicMock(),
        "hermes_state": MagicMock(),
    }
    with patch.dict("sys.modules", stubbed_modules):
        mod = importlib.import_module("tui_gateway.server")
        yield mod
        # Teardown happens while the stubs are still installed.
        for registry in (mod._sessions, mod._pending, mod._answers, mod._methods):
            registry.clear()
        importlib.reload(mod)
@pytest.fixture()
def capture(server):
    """Swap the server's ``_real_stdout`` for an in-memory buffer; return ``(server, buffer)``."""
    sink = io.StringIO()
    server._real_stdout = sink
    return server, sink
# ── JSON-RPC envelope ────────────────────────────────────────────────
def test_unknown_method(server):
    """A request naming an unregistered method gets error code -32601."""
    request = {"id": "1", "method": "bogus"}
    response = server.handle_request(request)
    assert response["error"]["code"] == -32601
def test_ok_envelope(server):
    """``_ok`` wraps a result in a JSON-RPC 2.0 success envelope."""
    expected = {"jsonrpc": "2.0", "id": "r1", "result": {"x": 1}}
    envelope = server._ok("r1", {"x": 1})
    assert envelope == expected
def test_err_envelope(server):
    """``_err`` wraps a code and message in a JSON-RPC 2.0 error envelope."""
    expected = {
        "jsonrpc": "2.0",
        "id": "r2",
        "error": {"code": 4001, "message": "nope"},
    }
    envelope = server._err("r2", 4001, "nope")
    assert envelope == expected
# ── write_json ───────────────────────────────────────────────────────
def test_write_json(capture):
    """``write_json`` returns a truthy value and emits parseable JSON to the captured stream."""
    server, buf = capture
    wrote = server.write_json({"test": True})
    assert wrote
    emitted = json.loads(buf.getvalue())
    assert emitted == {"test": True}
def test_write_json_broken_pipe(server):
    """``write_json`` returns False when the output stream raises BrokenPipeError."""

    class _Broken:
        # Stream stand-in whose write and flush both fail.
        def write(self, _):
            raise BrokenPipeError

        def flush(self):
            raise BrokenPipeError

    server._real_stdout = _Broken()
    outcome = server.write_json({"x": 1})
    assert outcome is False
# ── _emit ────────────────────────────────────────────────────────────
def test_emit_with_payload(capture):
    """``_emit`` writes an ``event`` message carrying type, session id and payload."""
    server, buf = capture
    server._emit("test.event", "s1", {"key": "val"})

    message = json.loads(buf.getvalue())
    assert message["method"] == "event"

    params = message["params"]
    assert params["type"] == "test.event"
    assert params["session_id"] == "s1"
    assert params["payload"]["key"] == "val"
def test_emit_without_payload(capture):
    """With no payload argument, the emitted params carry no ``payload`` key."""
    server, buf = capture
    server._emit("ping", "s2")
    params = json.loads(buf.getvalue())["params"]
    assert "payload" not in params
# ── Blocking prompt round-trip ───────────────────────────────────────
def test_block_and_respond(capture):
    """``_block`` returns the answer stored for its request once the pending event is set.

    A worker thread calls ``_block``; the main thread waits for the request to
    appear in ``_pending``, stores an answer, sets the event, and then joins the
    worker so the final assertion cannot race it.
    """
    server, _ = capture
    result = [None]

    def _ask():
        result[0] = server._block("test.prompt", "s1", {"q": "?"}, timeout=5)

    # daemon=True so a stuck _block cannot keep the interpreter alive after a failure.
    worker = threading.Thread(target=_ask, daemon=True)
    worker.start()

    # Poll for up to ~1s until the worker has registered its pending request.
    for _ in range(100):
        if server._pending:
            break
        threading.Event().wait(0.01)
    assert server._pending, "_block never registered a pending request"

    rid = next(iter(server._pending))
    server._answers[rid] = "my_answer"
    server._pending[rid].set()

    # Join instead of sleeping a fixed interval: deterministic on slow machines.
    worker.join(timeout=5)
    assert not worker.is_alive(), "_block did not return after being answered"
    assert result[0] == "my_answer"
def test_clear_pending(server):
    """``_clear_pending`` sets each pending event and records an empty-string answer."""
    waiter = threading.Event()
    server._pending["r1"] = waiter

    server._clear_pending()

    assert waiter.is_set()
    assert server._answers["r1"] == ""
# ── Session lookup ───────────────────────────────────────────────────
def test_sess_missing(server):
    """Looking up an unknown session id yields an error with code 4001."""
    lookup = server._sess({"session_id": "nope"}, "r1")
    _, err = lookup
    assert err["error"]["code"] == 4001
def test_sess_found(server):
    """Looking up a registered session id returns the session and no error."""
    server._sessions["abc"] = {"agent": MagicMock()}
    session, err = server._sess({"session_id": "abc"}, "r1")
    assert session is not None
    assert err is None
# ── session.resume payload ────────────────────────────────────────────
def test_session_resume_returns_hydrated_messages(server, monkeypatch):
    """``session.resume`` returns three messages out of the six stored rows.

    The stored rows include whitespace-only, ``None`` and unknown-role entries,
    none of which appear in the expected result.
    """
    stored_rows = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "yo"},
        {"role": "tool", "content": "searched"},
        {"role": "assistant", "content": " "},
        {"role": "assistant", "content": None},
        {"role": "narrator", "content": "skip"},
    ]

    class _DB:
        # Minimal stand-in exposing only the DB methods stubbed for this test.
        def get_session(self, _sid):
            return {"id": "20260409_010101_abc123"}

        def get_session_by_title(self, _title):
            return None

        def reopen_session(self, _sid):
            return None

        def get_messages_as_conversation(self, _sid):
            return stored_rows

    monkeypatch.setattr(server, "_get_db", lambda: _DB())
    monkeypatch.setattr(server, "_make_agent", lambda sid, key, session_id=None: object())
    monkeypatch.setattr(server, "_init_session", lambda sid, key, agent, history, cols=80: None)
    monkeypatch.setattr(server, "_session_info", lambda _agent: {"model": "test/model"})

    request = {
        "id": "r1",
        "method": "session.resume",
        "params": {"session_id": "20260409_010101_abc123", "cols": 100},
    }
    resp = server.handle_request(request)

    expected_messages = [
        {"role": "user", "text": "hello"},
        {"role": "assistant", "text": "yo"},
        {"role": "tool", "name": "tool", "context": ""},
    ]
    assert "error" not in resp
    payload = resp["result"]
    assert payload["message_count"] == 3
    assert payload["messages"] == expected_messages
# ── Config I/O ───────────────────────────────────────────────────────
def test_config_load_missing(server, tmp_path):
    """``_load_cfg`` returns an empty dict when the home directory is an empty temp dir."""
    server._hermes_home = tmp_path
    loaded = server._load_cfg()
    assert loaded == {}
def test_config_roundtrip(server, tmp_path):
    """A value written by ``_save_cfg`` is returned by a subsequent ``_load_cfg``."""
    server._hermes_home = tmp_path
    server._save_cfg({"model": "test/model"})
    reloaded = server._load_cfg()
    assert reloaded["model"] == "test/model"
# ── _cli_exec_blocked ────────────────────────────────────────────────
@pytest.mark.parametrize(
    "argv",
    [
        [],
        ["setup"],
        ["gateway"],
        ["sessions", "browse"],
        ["config", "edit"],
    ],
)
def test_cli_exec_blocked(server, argv):
    """Each listed argv makes ``_cli_exec_blocked`` return a non-None value."""
    verdict = server._cli_exec_blocked(argv)
    assert verdict is not None
@pytest.mark.parametrize(
    "argv",
    [
        ["version"],
        ["sessions", "list"],
    ],
)
def test_cli_exec_allowed(server, argv):
    """Each listed argv makes ``_cli_exec_blocked`` return None."""
    verdict = server._cli_exec_blocked(argv)
    assert verdict is None

View file

@ -0,0 +1,67 @@
"""Tests for tui_gateway.render — rendering bridge fallback behavior."""
from unittest.mock import MagicMock, patch
from tui_gateway.render import make_stream_renderer, render_diff, render_message
def _stub_rich(mock_mod):
    """Return a ``patch.dict`` that installs *mock_mod* as ``agent.rich_output`` in ``sys.modules``."""
    replacement = {"agent.rich_output": mock_mod}
    return patch.dict("sys.modules", replacement)
def _no_rich():
    """Return a ``patch.dict`` that maps ``agent.rich_output`` to None in ``sys.modules``."""
    replacement = {"agent.rich_output": None}
    return patch.dict("sys.modules", replacement)
# ── render_message ───────────────────────────────────────────────────
def test_render_message_none_without_module():
    """``render_message`` returns None while ``agent.rich_output`` is mapped to None."""
    with _no_rich():
        rendered = render_message("hello")
        assert rendered is None
def test_render_message_formatted():
    """``render_message`` returns whatever the stubbed ``format_response`` produces."""
    rich_stub = MagicMock(**{"format_response.return_value": "<b>hi</b>"})
    with _stub_rich(rich_stub):
        rendered = render_message("hi", 100)
        assert rendered == "<b>hi</b>"
def test_render_message_type_error_fallback():
    """When the first ``format_response`` call raises TypeError, the second call's value is returned."""
    rich_stub = MagicMock(**{"format_response.side_effect": [TypeError, "fallback"]})
    with _stub_rich(rich_stub):
        rendered = render_message("hi")
        assert rendered == "fallback"
def test_render_message_exception_returns_none():
    """A RuntimeError from ``format_response`` makes ``render_message`` return None."""
    rich_stub = MagicMock(**{"format_response.side_effect": RuntimeError})
    with _stub_rich(rich_stub):
        rendered = render_message("hi")
        assert rendered is None
# ── render_diff / make_stream_renderer ───────────────────────────────
def test_render_diff_none_without_module():
    """``render_diff`` returns None while ``agent.rich_output`` is mapped to None."""
    with _no_rich():
        rendered = render_diff("+line")
        assert rendered is None
def test_stream_renderer_none_without_module():
    """``make_stream_renderer`` returns None while ``agent.rich_output`` is mapped to None."""
    with _no_rich():
        renderer = make_stream_renderer()
        assert renderer is None
def test_stream_renderer_returns_instance():
    """``make_stream_renderer`` returns the object produced by the stubbed ``StreamingRenderer``."""
    expected_renderer = MagicMock()
    rich_stub = MagicMock(**{"StreamingRenderer.return_value": expected_renderer})
    with _stub_rich(rich_stub):
        produced = make_stream_renderer(120)
        assert produced is expected_renderer