mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
Merge origin/main into pr-27248 (resolving run_agent.py = ours)
run_agent.py taken from HEAD (the extracted forwarder structure). The 25 run_agent.py fixes that landed on main during the PR's life need to be ported into the agent/* extracted modules in follow-up commits.
This commit is contained in:
commit
152d42d1a7
355 changed files with 32716 additions and 4195 deletions
|
|
@ -1,6 +1,9 @@
|
|||
"""Tests for acp_adapter.entry startup wiring."""
|
||||
|
||||
import sys
|
||||
|
||||
import acp
|
||||
import pytest
|
||||
|
||||
from acp_adapter import entry
|
||||
|
||||
|
|
@ -42,12 +45,152 @@ def test_main_setup_runs_model_configuration(monkeypatch):
|
|||
calls = {}
|
||||
|
||||
def fake_hermes_main():
|
||||
import sys
|
||||
|
||||
calls["argv"] = sys.argv[:]
|
||||
|
||||
monkeypatch.setattr("hermes_cli.main.main", fake_hermes_main)
|
||||
# Pretend stdin is not a TTY so the follow-up browser prompt is skipped.
|
||||
# That keeps this test focused on the model-setup wiring; the
|
||||
# browser-prompt path has its own test below.
|
||||
monkeypatch.setattr("sys.stdin.isatty", lambda: False)
|
||||
|
||||
entry.main(["--setup"])
|
||||
|
||||
assert calls["argv"][1:] == ["model"]
|
||||
|
||||
|
||||
def test_main_setup_offers_browser_install_when_tty(monkeypatch):
|
||||
"""When stdin is a TTY and the user answers yes, model setup is followed
|
||||
by a browser-tools bootstrap call."""
|
||||
monkeypatch.setattr("hermes_cli.main.main", lambda: None)
|
||||
monkeypatch.setattr("sys.stdin.isatty", lambda: True)
|
||||
monkeypatch.setattr("builtins.input", lambda *_args, **_kwargs: "y")
|
||||
|
||||
bootstrap_calls = []
|
||||
monkeypatch.setattr(
|
||||
entry,
|
||||
"_run_setup_browser",
|
||||
lambda assume_yes=False: bootstrap_calls.append(assume_yes) or 0,
|
||||
)
|
||||
|
||||
entry.main(["--setup"])
|
||||
|
||||
assert bootstrap_calls == [False]
|
||||
|
||||
|
||||
def test_main_setup_skips_browser_prompt_on_no(monkeypatch):
|
||||
monkeypatch.setattr("hermes_cli.main.main", lambda: None)
|
||||
monkeypatch.setattr("sys.stdin.isatty", lambda: True)
|
||||
monkeypatch.setattr("builtins.input", lambda *_args, **_kwargs: "")
|
||||
|
||||
called = []
|
||||
monkeypatch.setattr(
|
||||
entry,
|
||||
"_run_setup_browser",
|
||||
lambda assume_yes=False: called.append(assume_yes) or 0,
|
||||
)
|
||||
|
||||
entry.main(["--setup"])
|
||||
|
||||
assert called == []
|
||||
|
||||
|
||||
def test_main_setup_browser_invokes_bundled_script(monkeypatch):
|
||||
"""`hermes-acp --setup-browser` must shell out to the bundled bootstrap
|
||||
script — never reimplement the install logic inline."""
|
||||
monkeypatch.setattr("platform.system", lambda: "Linux")
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_run(cmd, check=False):
|
||||
captured["cmd"] = cmd
|
||||
|
||||
class _R:
|
||||
returncode = 0
|
||||
|
||||
return _R()
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
entry.main(["--setup-browser"])
|
||||
|
||||
assert captured["cmd"][0] == "bash"
|
||||
assert captured["cmd"][1].endswith("bootstrap_browser_tools.sh")
|
||||
# --yes is NOT passed when the flag is absent.
|
||||
assert "--yes" not in captured["cmd"]
|
||||
|
||||
|
||||
def test_main_setup_browser_forwards_yes_flag(monkeypatch):
|
||||
monkeypatch.setattr("platform.system", lambda: "Linux")
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_run(cmd, check=False):
|
||||
captured["cmd"] = cmd
|
||||
|
||||
class _R:
|
||||
returncode = 0
|
||||
|
||||
return _R()
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
entry.main(["--setup-browser", "--yes"])
|
||||
|
||||
assert "--yes" in captured["cmd"]
|
||||
|
||||
|
||||
def test_main_setup_browser_uses_powershell_on_windows(monkeypatch):
|
||||
monkeypatch.setattr("platform.system", lambda: "Windows")
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_run(cmd, check=False):
|
||||
captured["cmd"] = cmd
|
||||
|
||||
class _R:
|
||||
returncode = 0
|
||||
|
||||
return _R()
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
entry.main(["--setup-browser", "--yes"])
|
||||
|
||||
assert captured["cmd"][0] == "powershell.exe"
|
||||
assert any(part.endswith("bootstrap_browser_tools.ps1") for part in captured["cmd"])
|
||||
assert "-Yes" in captured["cmd"]
|
||||
|
||||
|
||||
def test_main_setup_browser_propagates_failure(monkeypatch):
|
||||
monkeypatch.setattr("platform.system", lambda: "Linux")
|
||||
|
||||
class _R:
|
||||
returncode = 7
|
||||
|
||||
monkeypatch.setattr("subprocess.run", lambda cmd, check=False: _R())
|
||||
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
entry.main(["--setup-browser"])
|
||||
assert excinfo.value.code == 7
|
||||
|
||||
|
||||
def test_bootstrap_scripts_ship_with_package():
|
||||
"""The package-data wiring (pyproject.toml) must include the bootstrap
|
||||
scripts — otherwise `--setup-browser` 404s at runtime."""
|
||||
from pathlib import Path
|
||||
|
||||
bootstrap_dir = Path(entry.__file__).resolve().parent / "bootstrap"
|
||||
sh = bootstrap_dir / "bootstrap_browser_tools.sh"
|
||||
ps1 = bootstrap_dir / "bootstrap_browser_tools.ps1"
|
||||
|
||||
assert sh.is_file(), f"missing bundled script: {sh}"
|
||||
assert ps1.is_file(), f"missing bundled script: {ps1}"
|
||||
|
||||
sh_text = sh.read_text(encoding="utf-8")
|
||||
ps1_text = ps1.read_text(encoding="utf-8")
|
||||
|
||||
# Sanity: scripts know how to find the Hermes-managed Node prefix.
|
||||
assert "HERMES_HOME" in sh_text
|
||||
assert "agent-browser" in sh_text
|
||||
assert "HermesHome" in ps1_text
|
||||
assert "agent-browser" in ps1_text
|
||||
|
|
|
|||
|
|
@ -1,15 +1,19 @@
|
|||
"""Tests for acp_adapter.events — callback factories for ACP notifications."""
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import warnings
|
||||
from concurrent.futures import Future
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import acp
|
||||
from acp.schema import ToolCallStart, ToolCallProgress, AgentThoughtChunk, AgentMessageChunk
|
||||
from acp.schema import AgentPlanUpdate, ToolCallStart, ToolCallProgress, AgentThoughtChunk, AgentMessageChunk
|
||||
|
||||
from acp_adapter.events import (
|
||||
_build_plan_update_from_todo_result,
|
||||
_send_update,
|
||||
make_message_cb,
|
||||
make_step_cb,
|
||||
make_thinking_cb,
|
||||
|
|
@ -293,6 +297,54 @@ class TestStepCallback:
|
|||
}
|
||||
mock_send.assert_called_once()
|
||||
|
||||
def test_todo_completion_emits_native_plan_update_after_tool_completion(self, mock_conn, event_loop_fixture):
|
||||
from collections import deque
|
||||
|
||||
tool_call_ids = {"todo": deque(["tc-todo"])}
|
||||
loop = event_loop_fixture
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
todo_result = (
|
||||
'{"todos":['
|
||||
'{"id":"inspect","content":"Inspect ACP","status":"completed"},'
|
||||
'{"id":"patch","content":"Patch renderer","status":"in_progress"},'
|
||||
'{"id":"old","content":"Drop stale task","status":"cancelled"}'
|
||||
'],"summary":{"total":3}}'
|
||||
)
|
||||
|
||||
with patch("acp_adapter.events._send_update") as mock_send:
|
||||
cb(1, [{"name": "todo", "result": todo_result}])
|
||||
|
||||
updates = [call.args[3] for call in mock_send.call_args_list]
|
||||
assert [getattr(update, "session_update", None) for update in updates] == [
|
||||
"tool_call_update",
|
||||
"plan",
|
||||
]
|
||||
plan = updates[1]
|
||||
assert isinstance(plan, AgentPlanUpdate)
|
||||
assert [entry.content for entry in plan.entries] == [
|
||||
"Inspect ACP",
|
||||
"Patch renderer",
|
||||
"[cancelled] Drop stale task",
|
||||
]
|
||||
assert [entry.status for entry in plan.entries] == ["completed", "in_progress", "completed"]
|
||||
assert [entry.priority for entry in plan.entries] == ["medium", "medium", "medium"]
|
||||
|
||||
def test_todo_plan_update_parses_json_with_trailing_hint(self):
|
||||
result = '{"todos":[{"id":"ship","content":"Ship ACP plan","status":"pending"}]}\n\n[Hint: persisted]'
|
||||
|
||||
update = _build_plan_update_from_todo_result(result)
|
||||
|
||||
assert isinstance(update, AgentPlanUpdate)
|
||||
assert [entry.content for entry in update.entries] == ["Ship ACP plan"]
|
||||
assert [entry.status for entry in update.entries] == ["pending"]
|
||||
|
||||
def test_todo_plan_update_with_empty_todos_clears_plan(self):
|
||||
update = _build_plan_update_from_todo_result('{"todos":[],"summary":{"total":0}}')
|
||||
|
||||
assert isinstance(update, AgentPlanUpdate)
|
||||
assert update.session_update == "plan"
|
||||
assert update.entries == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message callback
|
||||
|
|
@ -325,3 +377,46 @@ class TestMessageCallback:
|
|||
cb("")
|
||||
|
||||
mock_rcts.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scheduler-failure regression
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendUpdate:
|
||||
def test_scheduler_failure_closes_update_coroutine(self, event_loop_fixture):
|
||||
"""If run_coroutine_threadsafe raises, _send_update must close the coro."""
|
||||
created = {"coro": None}
|
||||
|
||||
async def _session_update(session_id, update):
|
||||
return None
|
||||
|
||||
conn = MagicMock()
|
||||
|
||||
def _capture_update(session_id, update):
|
||||
created["coro"] = _session_update(session_id, update)
|
||||
return created["coro"]
|
||||
|
||||
conn.session_update = _capture_update
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
with patch(
|
||||
"agent.async_utils.asyncio.run_coroutine_threadsafe",
|
||||
side_effect=RuntimeError("scheduler down"),
|
||||
):
|
||||
_send_update(conn, "session-1", event_loop_fixture, {"type": "noop"})
|
||||
gc.collect()
|
||||
|
||||
assert created["coro"] is not None
|
||||
assert created["coro"].cr_frame is None
|
||||
# Only count warnings about THIS test's coroutine; other tests in the
|
||||
# same xdist worker (or stdlib mock internals) may emit unrelated
|
||||
# "coroutine was never awaited" warnings that bleed through.
|
||||
runtime_warnings = [
|
||||
w for w in caught
|
||||
if issubclass(w.category, RuntimeWarning)
|
||||
and "was never awaited" in str(w.message)
|
||||
and "_session_update" in str(w.message)
|
||||
]
|
||||
assert runtime_warnings == []
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ def _invoke_callback(
|
|||
scheduled["loop"] = passed_loop
|
||||
return future
|
||||
|
||||
with patch("acp_adapter.permissions.asyncio.run_coroutine_threadsafe", side_effect=_schedule):
|
||||
with patch("agent.async_utils.asyncio.run_coroutine_threadsafe", side_effect=_schedule):
|
||||
cb = make_approval_callback(request_permission, loop, session_id="s1", timeout=timeout)
|
||||
if use_prompt_path:
|
||||
result = prompt_dangerous_approval(
|
||||
|
|
@ -135,7 +135,7 @@ class TestApprovalBridge:
|
|||
scheduled["loop"] = passed_loop
|
||||
return future
|
||||
|
||||
with patch("acp_adapter.permissions.asyncio.run_coroutine_threadsafe", side_effect=_schedule):
|
||||
with patch("agent.async_utils.asyncio.run_coroutine_threadsafe", side_effect=_schedule):
|
||||
cb = make_approval_callback(request_permission, loop, session_id="s1", timeout=0.01)
|
||||
result = cb("rm -rf /", "dangerous command")
|
||||
|
||||
|
|
@ -159,10 +159,53 @@ class TestApprovalBridge:
|
|||
scheduled["loop"] = passed_loop
|
||||
return future
|
||||
|
||||
with patch("acp_adapter.permissions.asyncio.run_coroutine_threadsafe", side_effect=_schedule):
|
||||
with patch("agent.async_utils.asyncio.run_coroutine_threadsafe", side_effect=_schedule):
|
||||
cb = make_approval_callback(request_permission, loop, session_id="s1", timeout=1.0)
|
||||
result = cb("echo hi", "demo")
|
||||
|
||||
scheduled["coro"].close()
|
||||
|
||||
assert result == "deny"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scheduler-failure regression
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import gc # noqa: E402
|
||||
import warnings # noqa: E402
|
||||
|
||||
|
||||
class TestSchedulerFailure:
|
||||
def test_scheduler_failure_closes_permission_coroutine(self):
|
||||
"""If run_coroutine_threadsafe raises, the coro is closed and we return 'deny'."""
|
||||
loop = MagicMock(spec=asyncio.AbstractEventLoop)
|
||||
created = {"coro": None}
|
||||
|
||||
async def _response_coro(**kwargs):
|
||||
return _make_response(AllowedOutcome(option_id="allow_once", outcome="selected"))
|
||||
|
||||
def _request_permission(**kwargs):
|
||||
created["coro"] = _response_coro(**kwargs)
|
||||
return created["coro"]
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
with patch(
|
||||
"agent.async_utils.asyncio.run_coroutine_threadsafe",
|
||||
side_effect=RuntimeError("scheduler down"),
|
||||
):
|
||||
cb = make_approval_callback(_request_permission, loop, session_id="s1", timeout=0.01)
|
||||
result = cb("rm -rf /", "dangerous")
|
||||
gc.collect()
|
||||
|
||||
assert result == "deny"
|
||||
assert created["coro"] is not None
|
||||
assert created["coro"].cr_frame is None
|
||||
runtime_warnings = [
|
||||
w for w in caught
|
||||
if issubclass(w.category, RuntimeWarning)
|
||||
and "was never awaited" in str(w.message)
|
||||
and "_response_coro" in str(w.message)
|
||||
]
|
||||
assert runtime_warnings == []
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ from acp.agent.router import build_agent_router
|
|||
from acp.schema import (
|
||||
AgentCapabilities,
|
||||
AgentMessageChunk,
|
||||
AgentPlanUpdate,
|
||||
AgentThoughtChunk,
|
||||
AuthenticateResponse,
|
||||
AvailableCommandsUpdate,
|
||||
Implementation,
|
||||
|
|
@ -391,6 +393,57 @@ class TestSessionOps:
|
|||
assert "Search results" in tool_updates[1].content[0].content.text
|
||||
assert "cli.py:42" in tool_updates[1].content[0].content.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_replays_native_plan_for_persisted_todo_tool(self, agent):
|
||||
"""Persisted todo tool results should rebuild Zed's native plan panel."""
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.history = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_todo_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "todo",
|
||||
"arguments": '{"todos":[{"id":"ship","content":"Ship it","status":"in_progress"}]}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_todo_1",
|
||||
"content": '{"todos":[{"id":"ship","content":"Ship it","status":"in_progress"}]}',
|
||||
},
|
||||
]
|
||||
|
||||
mock_conn.session_update.reset_mock()
|
||||
resp = await agent.load_session(cwd="/tmp", session_id=new_resp.session_id)
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert isinstance(resp, LoadSessionResponse)
|
||||
relevant_updates = [
|
||||
update for update in (call.kwargs["update"] for call in mock_conn.session_update.await_args_list)
|
||||
if getattr(update, "session_update", None) in {"tool_call", "tool_call_update", "plan"}
|
||||
]
|
||||
assert [getattr(update, "session_update", None) for update in relevant_updates] == [
|
||||
"tool_call",
|
||||
"tool_call_update",
|
||||
"plan",
|
||||
]
|
||||
plan = relevant_updates[2]
|
||||
assert isinstance(plan, AgentPlanUpdate)
|
||||
assert [entry.content for entry in plan.entries] == ["Ship it"]
|
||||
assert [entry.status for entry in plan.entries] == ["in_progress"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_session_replays_persisted_history_to_client(self, agent):
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
|
|
@ -415,25 +468,296 @@ class TestSessionOps:
|
|||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_schedules_history_replay_after_response(self, agent):
|
||||
"""Zed only attaches replayed updates after session/load has completed."""
|
||||
async def test_load_session_replays_reasoning_thought_before_message(self, agent):
|
||||
"""Thinking-model thoughts must be replayed via ``agent_thought_chunk``.
|
||||
|
||||
Regression for #12285 — when a session is loaded, persisted assistant
|
||||
``reasoning_content`` / ``reasoning`` fields must surface as ACP
|
||||
``AgentThoughtChunk`` notifications in the same relative position they
|
||||
had live (thought streams before the assistant message text), so Zed's
|
||||
collapsed Thinking pane rebuilds instead of vanishing on reconnect.
|
||||
"""
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.history = [
|
||||
{"role": "user", "content": "Walk me through it."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning_content": "Let me think step by step about the request.",
|
||||
"content": "Here is the plan.",
|
||||
},
|
||||
{"role": "user", "content": "And the legacy case?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
# No reasoning_content — exercise the legacy "reasoning" fallback
|
||||
# path so sessions persisted before #16892 still replay thoughts.
|
||||
"reasoning": "Older sessions stored the trace under the internal key.",
|
||||
"content": "Same idea, older field name.",
|
||||
},
|
||||
]
|
||||
|
||||
mock_conn.session_update.reset_mock()
|
||||
resp = await agent.load_session(cwd="/tmp", session_id=new_resp.session_id)
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert isinstance(resp, LoadSessionResponse)
|
||||
|
||||
replay_kinds = [
|
||||
getattr(call.kwargs.get("update"), "session_update", None)
|
||||
for call in mock_conn.session_update.await_args_list
|
||||
if getattr(call.kwargs.get("update"), "session_update", None)
|
||||
in {"user_message_chunk", "agent_message_chunk", "agent_thought_chunk"}
|
||||
]
|
||||
assert replay_kinds == [
|
||||
"user_message_chunk",
|
||||
"agent_thought_chunk",
|
||||
"agent_message_chunk",
|
||||
"user_message_chunk",
|
||||
"agent_thought_chunk",
|
||||
"agent_message_chunk",
|
||||
]
|
||||
|
||||
thought_updates = [
|
||||
call.kwargs["update"]
|
||||
for call in mock_conn.session_update.await_args_list
|
||||
if isinstance(call.kwargs.get("update"), AgentThoughtChunk)
|
||||
]
|
||||
assert len(thought_updates) == 2
|
||||
assert thought_updates[0].content.text == "Let me think step by step about the request."
|
||||
assert thought_updates[1].content.text == "Older sessions stored the trace under the internal key."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_replays_reasoning_only_turn(self, agent):
|
||||
"""Assistant turns with reasoning but no content should still emit a thought.
|
||||
|
||||
Pure reasoning-only assistant entries (e.g. a thinking step before a
|
||||
tool-call turn) commonly carry ``reasoning_content`` with empty
|
||||
``content``. The replay must still surface the thought so the editor's
|
||||
Thinking pane rebuilds, even when there is no message text to follow.
|
||||
"""
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.history = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning_content": "I should call the search tool next.",
|
||||
"content": "",
|
||||
},
|
||||
]
|
||||
|
||||
mock_conn.session_update.reset_mock()
|
||||
await agent.load_session(cwd="/tmp", session_id=new_resp.session_id)
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
thought_updates = [
|
||||
call.kwargs["update"]
|
||||
for call in mock_conn.session_update.await_args_list
|
||||
if isinstance(call.kwargs.get("update"), AgentThoughtChunk)
|
||||
]
|
||||
message_updates = [
|
||||
call.kwargs["update"]
|
||||
for call in mock_conn.session_update.await_args_list
|
||||
if isinstance(call.kwargs.get("update"), AgentMessageChunk)
|
||||
]
|
||||
assert len(thought_updates) == 1
|
||||
assert thought_updates[0].content.text == "I should call the search tool next."
|
||||
assert message_updates == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_skips_empty_reasoning_fields(self, agent):
|
||||
"""Empty/whitespace reasoning fields must not produce notifications."""
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.history = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning_content": "",
|
||||
"reasoning": " \n\t",
|
||||
"content": "Just a regular answer.",
|
||||
},
|
||||
]
|
||||
|
||||
mock_conn.session_update.reset_mock()
|
||||
await agent.load_session(cwd="/tmp", session_id=new_resp.session_id)
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
thought_updates = [
|
||||
call.kwargs["update"]
|
||||
for call in mock_conn.session_update.await_args_list
|
||||
if isinstance(call.kwargs.get("update"), AgentThoughtChunk)
|
||||
]
|
||||
assert thought_updates == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_replays_thought_then_tool_call_without_message(self, agent):
|
||||
"""Canonical thinking-model shape: reasoning + tool_call + no body text.
|
||||
|
||||
Thinking models commonly emit a pre-tool thought followed by a
|
||||
tool_calls turn with empty ``content``. Replay must emit:
|
||||
``agent_thought_chunk`` then ``tool_call`` then ``tool_call_update``
|
||||
for the matching tool result — and crucially, NO ``agent_message_chunk``
|
||||
for the empty-text assistant body. Regression for the canonical
|
||||
thinking-then-tool flow on #12285.
|
||||
"""
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.history = [
|
||||
{"role": "user", "content": "Find the bug."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning_content": "I should grep for the function name first.",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_grep_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_files",
|
||||
"arguments": '{"pattern":"foo","path":"."}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_grep_1",
|
||||
"content": '{"total_count":1,"matches":[{"path":"x.py","line":1,"content":"foo"}]}',
|
||||
},
|
||||
]
|
||||
|
||||
mock_conn.session_update.reset_mock()
|
||||
await agent.load_session(cwd="/tmp", session_id=new_resp.session_id)
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
kinds = [
|
||||
getattr(call.kwargs.get("update"), "session_update", None)
|
||||
for call in mock_conn.session_update.await_args_list
|
||||
if getattr(call.kwargs.get("update"), "session_update", None)
|
||||
in {
|
||||
"user_message_chunk",
|
||||
"agent_thought_chunk",
|
||||
"agent_message_chunk",
|
||||
"tool_call",
|
||||
"tool_call_update",
|
||||
}
|
||||
]
|
||||
# No agent_message_chunk for the empty-content assistant turn.
|
||||
assert "agent_message_chunk" not in kinds
|
||||
# Thought must precede the tool_call_start within the assistant turn,
|
||||
# and the tool result follows.
|
||||
assert kinds == [
|
||||
"user_message_chunk",
|
||||
"agent_thought_chunk",
|
||||
"tool_call",
|
||||
"tool_call_update",
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_replays_history_before_returning_response(self, agent):
|
||||
"""Per ACP spec, replay must complete BEFORE load_session returns.
|
||||
|
||||
Spec-compliant ACP clients (Codex, Claude Code, OpenCode, Pi, Zed)
|
||||
attach their ``session/update`` listeners before awaiting the
|
||||
``loadSession`` RPC and rely on receiving the full transcript within
|
||||
the request's lifetime. Deferring replay via ``loop.call_soon`` (the
|
||||
prior behavior in May 2026) broke clients that read notification
|
||||
counts synchronously against the load response — see #12285 follow-up.
|
||||
"""
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.history = [{"role": "user", "content": "hello from history"}]
|
||||
events = []
|
||||
events: list[str] = []
|
||||
|
||||
async def replay_after_response(_state):
|
||||
async def replay_records(_state):
|
||||
events.append("replay")
|
||||
|
||||
with patch.object(agent, "_replay_session_history", side_effect=replay_after_response):
|
||||
with patch.object(agent, "_replay_session_history", side_effect=replay_records):
|
||||
resp = await agent.load_session(cwd="/tmp", session_id=new_resp.session_id)
|
||||
events.append("returned")
|
||||
|
||||
assert isinstance(resp, LoadSessionResponse)
|
||||
assert events == ["returned"]
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
assert events == ["returned", "replay"]
|
||||
# Replay must have happened BEFORE the response was constructed —
|
||||
# i.e. before the `events.append("returned")` after the await resolves.
|
||||
assert events == ["replay", "returned"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_session_replays_history_before_returning_response(self, agent):
|
||||
"""Same spec rationale as ``load_session`` — replay before responding."""
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.history = [{"role": "user", "content": "hello from history"}]
|
||||
events: list[str] = []
|
||||
|
||||
async def replay_records(_state):
|
||||
events.append("replay")
|
||||
|
||||
with patch.object(agent, "_replay_session_history", side_effect=replay_records):
|
||||
resp = await agent.resume_session(cwd="/tmp", session_id=new_resp.session_id)
|
||||
events.append("returned")
|
||||
|
||||
assert isinstance(resp, ResumeSessionResponse)
|
||||
assert events == ["replay", "returned"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_survives_replay_helper_exception(self, agent, caplog):
|
||||
"""A replay helper raising must not turn load_session into an error.
|
||||
|
||||
With awaited replay, an exception in ``_replay_session_history`` now
|
||||
propagates into the ``load_session`` handler. The defensive try/except
|
||||
guard at the call site must catch and log it so the JSON-RPC client
|
||||
still receives a ``LoadSessionResponse`` — partial transcripts are
|
||||
acceptable, total load failure is not.
|
||||
"""
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.history = [{"role": "user", "content": "hi"}]
|
||||
|
||||
async def boom(_state):
|
||||
raise RuntimeError("simulated replay helper crash")
|
||||
|
||||
with caplog.at_level("WARNING", logger="acp_adapter.server"):
|
||||
with patch.object(agent, "_replay_session_history", side_effect=boom):
|
||||
resp = await agent.load_session(cwd="/tmp", session_id=new_resp.session_id)
|
||||
|
||||
assert isinstance(resp, LoadSessionResponse)
|
||||
assert "history replay raised during session/load" in caplog.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_session_survives_replay_helper_exception(self, agent, caplog):
|
||||
"""Same guarantee as ``load_session`` for the resume path."""
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.history = [{"role": "user", "content": "hi"}]
|
||||
|
||||
async def boom(_state):
|
||||
raise RuntimeError("simulated replay helper crash")
|
||||
|
||||
with caplog.at_level("WARNING", logger="acp_adapter.server"):
|
||||
with patch.object(agent, "_replay_session_history", side_effect=boom):
|
||||
resp = await agent.resume_session(cwd="/tmp", session_id=new_resp.session_id)
|
||||
|
||||
assert isinstance(resp, ResumeSessionResponse)
|
||||
assert "history replay raised during session/resume" in caplog.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_session_creates_new_if_missing(self, agent):
|
||||
|
|
|
|||
|
|
@ -157,6 +157,13 @@ class TestBuildAnthropicClient:
|
|||
|
||||
|
||||
class TestReadClaudeCodeCredentials:
|
||||
@pytest.fixture(autouse=True)
|
||||
def no_keychain(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter._read_claude_code_credentials_from_keychain",
|
||||
lambda: None,
|
||||
)
|
||||
|
||||
def test_reads_valid_credentials(self, tmp_path, monkeypatch):
|
||||
cred_file = tmp_path / ".claude" / ".credentials.json"
|
||||
cred_file.parent.mkdir(parents=True)
|
||||
|
|
|
|||
170
tests/agent/test_anthropic_oauth_pkce.py
Normal file
170
tests/agent/test_anthropic_oauth_pkce.py
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
"""Regression tests for the Anthropic OAuth PKCE flow.
|
||||
|
||||
Guards against re-introducing the bug where the PKCE ``code_verifier`` was
|
||||
reused as the OAuth ``state`` parameter, leaking the verifier via the
|
||||
authorization URL (browser history, Referer headers, auth-server logs) and
|
||||
removing CSRF protection on the callback path.
|
||||
|
||||
History:
|
||||
- PR #1775 first fixed this on ``run_hermes_oauth_login()``.
|
||||
- PR #2647 (b17e5c10) added ``run_hermes_oauth_login_pure()`` and silently
|
||||
copy-pasted the pre-#1775 vulnerable pattern.
|
||||
- PR #3107 removed the old function, leaving only the regressed copy.
|
||||
- PR #10699 (issue #10693) fixed the regression on the surviving function.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
|
||||
def _patch_oauth_flow(
|
||||
monkeypatch,
|
||||
*,
|
||||
callback_code: str,
|
||||
token_response: Dict[str, Any] | None = None,
|
||||
capture_token_request: Dict[str, Any] | None = None,
|
||||
capture_auth_url: Dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Wire up monkeypatches that let ``run_hermes_oauth_login_pure()`` run
|
||||
end-to-end without touching a real browser, stdin, or HTTP endpoint.
|
||||
|
||||
``callback_code`` is the literal string the user would paste back into the
|
||||
terminal (``"<code>#<state>"`` format).
|
||||
``capture_token_request`` and ``capture_auth_url`` are out-dict captures
|
||||
so the test can introspect what was sent to the auth URL and the token
|
||||
endpoint, respectively.
|
||||
"""
|
||||
import urllib.request
|
||||
|
||||
if token_response is None:
|
||||
token_response = {
|
||||
"access_token": "sk-ant-test-access",
|
||||
"refresh_token": "sk-ant-test-refresh",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
def fake_open(url):
|
||||
if capture_auth_url is not None:
|
||||
capture_auth_url["url"] = url
|
||||
return True
|
||||
|
||||
monkeypatch.setattr("webbrowser.open", fake_open)
|
||||
monkeypatch.setattr("builtins.input", lambda *_a, **_kw: callback_code)
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, body: bytes) -> None:
|
||||
self._body = body
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_exc):
|
||||
return False
|
||||
|
||||
def read(self):
|
||||
return self._body
|
||||
|
||||
def fake_urlopen(req, *_a, **_kw):
|
||||
if capture_token_request is not None:
|
||||
capture_token_request["url"] = req.full_url
|
||||
capture_token_request["data"] = json.loads(req.data.decode())
|
||||
capture_token_request["headers"] = dict(req.headers)
|
||||
return _FakeResponse(json.dumps(token_response).encode())
|
||||
|
||||
monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)
|
||||
|
||||
|
||||
def test_authorization_url_state_is_not_pkce_verifier(monkeypatch, tmp_path):
|
||||
"""The ``state`` parameter in the authorization URL must NOT equal the
|
||||
PKCE ``code_verifier``.
|
||||
|
||||
Reusing the verifier as state leaks the verifier into browser history,
|
||||
Referer headers, and auth-server access logs — defeating RFC 7636.
|
||||
"""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
captured_url: Dict[str, str] = {}
|
||||
captured_token: Dict[str, Any] = {}
|
||||
_patch_oauth_flow(
|
||||
monkeypatch,
|
||||
# state echoed back unchanged so the CSRF guard passes
|
||||
callback_code="auth-code-from-anthropic#PLACEHOLDER",
|
||||
capture_auth_url=captured_url,
|
||||
capture_token_request=captured_token,
|
||||
)
|
||||
|
||||
# Stub the callback parse: we need the state echoed back to match. To do
|
||||
# that without hardcoding the state value, override input() AFTER seeing
|
||||
# the auth URL.
|
||||
import builtins
|
||||
|
||||
real_input_calls = {"count": 0}
|
||||
|
||||
def fake_input(*_a, **_kw):
|
||||
real_input_calls["count"] += 1
|
||||
# First (and only) call is the "Authorization code:" prompt.
|
||||
url = captured_url.get("url", "")
|
||||
qs = parse_qs(urlparse(url).query)
|
||||
state = qs.get("state", [""])[0]
|
||||
return f"auth-code-from-anthropic#{state}"
|
||||
|
||||
monkeypatch.setattr(builtins, "input", fake_input)
|
||||
|
||||
from agent.anthropic_adapter import run_hermes_oauth_login_pure
|
||||
|
||||
result = run_hermes_oauth_login_pure()
|
||||
assert result is not None, "OAuth flow should succeed with matching state"
|
||||
|
||||
url = captured_url["url"]
|
||||
qs = parse_qs(urlparse(url).query)
|
||||
|
||||
assert "state" in qs and qs["state"][0], "authorization URL must include state"
|
||||
assert "code_challenge" in qs, "authorization URL must include code_challenge"
|
||||
|
||||
state_in_url = qs["state"][0]
|
||||
verifier_sent = captured_token["data"]["code_verifier"]
|
||||
|
||||
# The whole point: state and verifier must be independent values.
|
||||
assert state_in_url != verifier_sent, (
|
||||
"PKCE code_verifier was reused as OAuth state — regression of #10693 / "
|
||||
"#1775. The verifier is supposed to be a secret known only to the "
|
||||
"client; placing it in the authorization URL leaks it via browser "
|
||||
"history, Referer headers, and auth-server logs."
|
||||
)
|
||||
|
||||
# And the verifier MUST NOT appear anywhere in the URL.
|
||||
assert verifier_sent not in url, (
|
||||
"PKCE verifier leaked into authorization URL — regression of #10693"
|
||||
)
|
||||
|
||||
|
||||
def test_callback_state_mismatch_aborts(monkeypatch, tmp_path, caplog):
|
||||
"""If the state returned in the callback does not match the one we sent
|
||||
in the authorization URL, the flow must abort before exchanging the code.
|
||||
|
||||
Without this check, an attacker who tricks the user into pasting a
|
||||
crafted ``<code>#<state>`` string can complete the token exchange — the
|
||||
CSRF protection that ``state`` is supposed to provide (RFC 6749 §10.12)
|
||||
would be absent.
|
||||
"""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
captured_token: Dict[str, Any] = {}
|
||||
_patch_oauth_flow(
|
||||
monkeypatch,
|
||||
callback_code="attacker-code#attacker-state-does-not-match",
|
||||
capture_token_request=captured_token,
|
||||
)
|
||||
|
||||
from agent.anthropic_adapter import run_hermes_oauth_login_pure
|
||||
|
||||
result = run_hermes_oauth_login_pure()
|
||||
|
||||
assert result is None, "mismatched state must abort the flow"
|
||||
assert "url" not in captured_token, (
|
||||
"token exchange must NOT happen when state mismatches"
|
||||
)
|
||||
157
tests/agent/test_async_utils.py
Normal file
157
tests/agent/test_async_utils.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
"""Tests for agent.async_utils.safe_schedule_threadsafe."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import warnings
|
||||
from concurrent.futures import Future
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.async_utils import safe_schedule_threadsafe
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _no_unawaited_warnings(caught, *, coro_name: str = "") -> bool:
|
||||
"""Return True if no "X was never awaited" warning slipped through.
|
||||
|
||||
When *coro_name* is provided, only warnings naming that coroutine are
|
||||
counted — xdist workers may emit unrelated unawaited-coroutine warnings
|
||||
(e.g. ``AsyncMockMixin._execute_mock_call``) from concurrent tests.
|
||||
"""
|
||||
bad = [
|
||||
w for w in caught
|
||||
if issubclass(w.category, RuntimeWarning)
|
||||
and "was never awaited" in str(w.message)
|
||||
and (not coro_name or coro_name in str(w.message))
|
||||
]
|
||||
return not bad
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSafeScheduleThreadsafe:
|
||||
def test_returns_future_on_success(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
import threading
|
||||
ready = threading.Event()
|
||||
stop = threading.Event()
|
||||
|
||||
def _runner():
|
||||
asyncio.set_event_loop(loop)
|
||||
ready.set()
|
||||
loop.run_until_complete(_wait_for_stop(stop))
|
||||
|
||||
async def _wait_for_stop(ev):
|
||||
while not ev.is_set():
|
||||
await asyncio.sleep(0.005)
|
||||
|
||||
t = threading.Thread(target=_runner, daemon=True)
|
||||
t.start()
|
||||
ready.wait(timeout=2)
|
||||
|
||||
async def _sample():
|
||||
return 42
|
||||
|
||||
fut = safe_schedule_threadsafe(_sample(), loop)
|
||||
assert isinstance(fut, Future)
|
||||
assert fut.result(timeout=2) == 42
|
||||
|
||||
stop.set()
|
||||
t.join(timeout=2)
|
||||
finally:
|
||||
if loop.is_running():
|
||||
loop.call_soon_threadsafe(loop.stop)
|
||||
loop.close()
|
||||
|
||||
def test_closed_loop_returns_none_and_closes_coroutine(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.close()
|
||||
|
||||
async def _sample():
|
||||
return "ok"
|
||||
|
||||
coro = _sample()
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
result = safe_schedule_threadsafe(coro, loop)
|
||||
del coro
|
||||
gc.collect()
|
||||
|
||||
assert result is None
|
||||
assert _no_unawaited_warnings(caught, coro_name='_sample')
|
||||
|
||||
def test_none_loop_returns_none_and_closes_coroutine(self):
|
||||
async def _sample():
|
||||
return "ok"
|
||||
|
||||
coro = _sample()
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
result = safe_schedule_threadsafe(coro, None)
|
||||
del coro
|
||||
gc.collect()
|
||||
|
||||
assert result is None
|
||||
assert _no_unawaited_warnings(caught, coro_name='_sample')
|
||||
|
||||
def test_scheduling_exception_closes_coroutine(self):
|
||||
"""If run_coroutine_threadsafe raises, close the coroutine and return None."""
|
||||
# A loop that *looks* open but raises on submission
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
async def _sample():
|
||||
return "ok"
|
||||
|
||||
coro = _sample()
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
with patch(
|
||||
"agent.async_utils.asyncio.run_coroutine_threadsafe",
|
||||
side_effect=RuntimeError("scheduler down"),
|
||||
):
|
||||
result = safe_schedule_threadsafe(coro, loop)
|
||||
del coro
|
||||
gc.collect()
|
||||
|
||||
assert result is None
|
||||
assert _no_unawaited_warnings(caught, coro_name='_sample')
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def test_logs_at_specified_level(self, caplog):
|
||||
import logging
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.close()
|
||||
|
||||
async def _sample():
|
||||
return None
|
||||
|
||||
custom = logging.getLogger("test_async_utils")
|
||||
with caplog.at_level(logging.WARNING, logger="test_async_utils"):
|
||||
result = safe_schedule_threadsafe(
|
||||
_sample(), loop,
|
||||
logger=custom,
|
||||
log_message="custom-msg",
|
||||
log_level=logging.WARNING,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
assert any("custom-msg" in rec.message for rec in caplog.records)
|
||||
|
||||
def test_non_coroutine_arg_does_not_crash(self):
|
||||
"""Defensive: even if the caller hands us something weird, don't blow up."""
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.close()
|
||||
|
||||
# Pass a non-coroutine sentinel
|
||||
result = safe_schedule_threadsafe("not-a-coroutine", loop) # type: ignore[arg-type]
|
||||
assert result is None
|
||||
|
|
@ -26,6 +26,7 @@ from agent.auxiliary_client import (
|
|||
_normalize_aux_provider,
|
||||
_try_payment_fallback,
|
||||
_resolve_auto,
|
||||
_resolve_xai_oauth_for_aux,
|
||||
_CodexCompletionsAdapter,
|
||||
)
|
||||
|
||||
|
|
@ -221,6 +222,77 @@ class TestReadCodexAccessToken:
|
|||
assert result == "plain-token-no-jwt"
|
||||
|
||||
|
||||
class TestResolveXaiOAuthForAux:
|
||||
def test_uses_pool_backed_credentials_without_singleton(self, tmp_path, monkeypatch):
|
||||
"""Auxiliary xAI OAuth must see pool-only credentials.
|
||||
|
||||
``hermes auth status`` already reports these as logged in; compression
|
||||
should not fall through to "no auxiliary provider configured" just
|
||||
because the singleton auth-store entry is absent.
|
||||
"""
|
||||
from agent.credential_pool import AUTH_TYPE_OAUTH, PooledCredential, load_pool
|
||||
from hermes_cli.auth import DEFAULT_XAI_OAUTH_BASE_URL
|
||||
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "auth.json").write_text(json.dumps({
|
||||
"version": 1,
|
||||
"providers": {},
|
||||
}))
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.delenv("HERMES_XAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("XAI_BASE_URL", raising=False)
|
||||
|
||||
pool = load_pool("xai-oauth")
|
||||
pool.add_entry(PooledCredential(
|
||||
provider="xai-oauth",
|
||||
id="xai123",
|
||||
label="pool-only",
|
||||
auth_type=AUTH_TYPE_OAUTH,
|
||||
priority=0,
|
||||
source="manual:xai_pkce",
|
||||
access_token="pool-access-token",
|
||||
refresh_token="pool-refresh-token",
|
||||
base_url=DEFAULT_XAI_OAUTH_BASE_URL,
|
||||
))
|
||||
|
||||
assert _resolve_xai_oauth_for_aux() == (
|
||||
"pool-access-token",
|
||||
DEFAULT_XAI_OAUTH_BASE_URL,
|
||||
)
|
||||
|
||||
def test_pool_backed_credentials_honor_base_url_env_override(self, tmp_path, monkeypatch):
|
||||
from agent.credential_pool import AUTH_TYPE_OAUTH, PooledCredential, load_pool
|
||||
from hermes_cli.auth import DEFAULT_XAI_OAUTH_BASE_URL
|
||||
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "auth.json").write_text(json.dumps({
|
||||
"version": 1,
|
||||
"providers": {},
|
||||
}))
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setenv("HERMES_XAI_BASE_URL", "https://example.x.ai/v1/")
|
||||
|
||||
pool = load_pool("xai-oauth")
|
||||
pool.add_entry(PooledCredential(
|
||||
provider="xai-oauth",
|
||||
id="xai456",
|
||||
label="pool-only",
|
||||
auth_type=AUTH_TYPE_OAUTH,
|
||||
priority=0,
|
||||
source="manual:xai_pkce",
|
||||
access_token="pool-access-token",
|
||||
refresh_token="pool-refresh-token",
|
||||
base_url=DEFAULT_XAI_OAUTH_BASE_URL,
|
||||
))
|
||||
|
||||
assert _resolve_xai_oauth_for_aux() == (
|
||||
"pool-access-token",
|
||||
"https://example.x.ai/v1",
|
||||
)
|
||||
|
||||
|
||||
class TestAnthropicOAuthFlag:
|
||||
"""Test that OAuth tokens get is_oauth=True in auxiliary Anthropic client."""
|
||||
|
||||
|
|
@ -2415,10 +2487,51 @@ def _clean_env(monkeypatch):
|
|||
"""Strip provider env vars so each test starts clean."""
|
||||
for key in (
|
||||
"OPENROUTER_API_KEY", "OPENAI_BASE_URL", "OPENAI_API_KEY",
|
||||
"NVIDIA_API_KEY", "NVIDIA_BASE_URL",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
|
||||
class TestNvidiaBillingHeaders:
|
||||
"""NVIDIA NIM billing-origin headers are scoped to NVIDIA cloud."""
|
||||
|
||||
def test_resolve_provider_client_cloud_adds_billing_origin_header(self, monkeypatch):
|
||||
monkeypatch.setenv("NVIDIA_API_KEY", "nvidia-key")
|
||||
monkeypatch.delenv("NVIDIA_BASE_URL", raising=False)
|
||||
mock_openai = MagicMock()
|
||||
mock_openai.return_value = MagicMock(name="nvidia-client")
|
||||
|
||||
with patch("agent.auxiliary_client.OpenAI", mock_openai):
|
||||
client, model = resolve_provider_client(
|
||||
provider="nvidia",
|
||||
model="nvidia/test-model",
|
||||
)
|
||||
|
||||
assert client is not None
|
||||
assert model == "nvidia/test-model"
|
||||
call_kwargs = mock_openai.call_args[1]
|
||||
headers = call_kwargs["default_headers"]
|
||||
assert headers["X-BILLING-INVOKE-ORIGIN"] == "HermesAgent"
|
||||
|
||||
def test_resolve_provider_client_local_nim_skips_billing_origin_header(self, monkeypatch):
|
||||
monkeypatch.setenv("NVIDIA_API_KEY", "nvidia-key")
|
||||
monkeypatch.setenv("NVIDIA_BASE_URL", "http://localhost:8000/v1")
|
||||
mock_openai = MagicMock()
|
||||
mock_openai.return_value = MagicMock(name="nvidia-local-client")
|
||||
|
||||
with patch("agent.auxiliary_client.OpenAI", mock_openai):
|
||||
client, model = resolve_provider_client(
|
||||
provider="nvidia",
|
||||
model="nvidia/test-model",
|
||||
)
|
||||
|
||||
assert client is not None
|
||||
assert model == "nvidia/test-model"
|
||||
call_kwargs = mock_openai.call_args[1]
|
||||
headers = call_kwargs.get("default_headers", {})
|
||||
assert "X-BILLING-INVOKE-ORIGIN" not in headers
|
||||
|
||||
|
||||
class TestOpenRouterExplicitApiKey:
|
||||
"""Test that explicit_api_key is correctly propagated to _try_openrouter()."""
|
||||
|
||||
|
|
|
|||
266
tests/agent/test_compressor_historical_media.py
Normal file
266
tests/agent/test_compressor_historical_media.py
Normal file
|
|
@ -0,0 +1,266 @@
|
|||
"""Tests for post-compression historical-media stripping.
|
||||
|
||||
Port of Kilo-Org/kilocode#9434 (adapted for OpenAI-style message lists).
|
||||
Without this pass, tail messages keep their original multi-MB base-64 image
|
||||
payloads after context compression, and every subsequent request re-ships
|
||||
them — sometimes breaching provider body-size limits and wedging the
|
||||
session.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.context_compressor import (
|
||||
ContextCompressor,
|
||||
_content_has_images,
|
||||
_is_image_part,
|
||||
_strip_historical_media,
|
||||
_strip_images_from_content,
|
||||
)
|
||||
|
||||
|
||||
IMG_URL = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/png;base64," + ("A" * 1024)},
|
||||
}
|
||||
INPUT_IMG = {
|
||||
"type": "input_image",
|
||||
"image_url": "data:image/png;base64," + ("B" * 1024),
|
||||
}
|
||||
ANTHROPIC_IMG = {
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/png", "data": "C" * 1024},
|
||||
}
|
||||
TEXT = {"type": "text", "text": "hi"}
|
||||
INPUT_TEXT = {"type": "input_text", "text": "hi"}
|
||||
|
||||
|
||||
class TestIsImagePart:
|
||||
def test_openai_chat_shape(self):
|
||||
assert _is_image_part(IMG_URL) is True
|
||||
|
||||
def test_openai_responses_shape(self):
|
||||
assert _is_image_part(INPUT_IMG) is True
|
||||
|
||||
def test_anthropic_native_shape(self):
|
||||
assert _is_image_part(ANTHROPIC_IMG) is True
|
||||
|
||||
def test_text_part_is_not_image(self):
|
||||
assert _is_image_part(TEXT) is False
|
||||
assert _is_image_part(INPUT_TEXT) is False
|
||||
|
||||
def test_non_dict_rejected(self):
|
||||
assert _is_image_part("image") is False
|
||||
assert _is_image_part(None) is False
|
||||
assert _is_image_part(42) is False
|
||||
|
||||
|
||||
class TestContentHasImages:
|
||||
def test_string_content(self):
|
||||
assert _content_has_images("a string") is False
|
||||
|
||||
def test_empty_list(self):
|
||||
assert _content_has_images([]) is False
|
||||
|
||||
def test_text_only_list(self):
|
||||
assert _content_has_images([TEXT, TEXT]) is False
|
||||
|
||||
def test_list_with_image(self):
|
||||
assert _content_has_images([TEXT, IMG_URL]) is True
|
||||
|
||||
def test_none(self):
|
||||
assert _content_has_images(None) is False
|
||||
|
||||
|
||||
class TestStripImagesFromContent:
|
||||
def test_string_passthrough(self):
|
||||
assert _strip_images_from_content("hello") == "hello"
|
||||
|
||||
def test_none_passthrough(self):
|
||||
assert _strip_images_from_content(None) is None
|
||||
|
||||
def test_text_only_passthrough(self):
|
||||
parts = [TEXT, {"type": "text", "text": "world"}]
|
||||
assert _strip_images_from_content(parts) == parts
|
||||
|
||||
def test_replaces_image_with_placeholder(self):
|
||||
parts = [TEXT, IMG_URL]
|
||||
out = _strip_images_from_content(parts)
|
||||
assert len(out) == 2
|
||||
assert out[0] == TEXT
|
||||
assert out[1] == {
|
||||
"type": "text",
|
||||
"text": "[Attached image — stripped after compression]",
|
||||
}
|
||||
|
||||
def test_does_not_mutate_input(self):
|
||||
parts = [IMG_URL, TEXT]
|
||||
_ = _strip_images_from_content(parts)
|
||||
assert parts[0] is IMG_URL # original list untouched
|
||||
assert parts[1] is TEXT
|
||||
|
||||
def test_handles_all_three_shapes(self):
|
||||
parts = [IMG_URL, INPUT_IMG, ANTHROPIC_IMG, TEXT]
|
||||
out = _strip_images_from_content(parts)
|
||||
assert sum(1 for p in out if p.get("type") == "text") == 4
|
||||
assert not any(_is_image_part(p) for p in out)
|
||||
|
||||
|
||||
class TestStripHistoricalMedia:
|
||||
def test_empty_passthrough(self):
|
||||
assert _strip_historical_media([]) == []
|
||||
|
||||
def test_no_images_anywhere(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "hey"},
|
||||
{"role": "user", "content": "bye"},
|
||||
]
|
||||
assert _strip_historical_media(msgs) is msgs # identity — no copy
|
||||
|
||||
def test_single_image_user_only_first_message(self):
|
||||
# Only image-bearing user is the first message — nothing before it.
|
||||
msgs = [
|
||||
{"role": "user", "content": [TEXT, IMG_URL]},
|
||||
{"role": "assistant", "content": "ok"},
|
||||
]
|
||||
out = _strip_historical_media(msgs)
|
||||
assert out is msgs # no-op
|
||||
# Image still there.
|
||||
assert _content_has_images(out[0]["content"])
|
||||
|
||||
def test_strips_older_user_image_keeps_newest(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": [TEXT, IMG_URL]}, # old — strip
|
||||
{"role": "assistant", "content": "looked at it"},
|
||||
{"role": "user", "content": [TEXT, INPUT_IMG]}, # newest — keep
|
||||
]
|
||||
out = _strip_historical_media(msgs)
|
||||
assert out is not msgs # new list
|
||||
# First message's image was replaced
|
||||
assert not _content_has_images(out[0]["content"])
|
||||
# Newest user still has its image
|
||||
assert _content_has_images(out[2]["content"])
|
||||
|
||||
def test_strips_assistant_and_tool_images_before_anchor(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": [TEXT, IMG_URL]}, # old user
|
||||
{"role": "assistant", "content": [TEXT, IMG_URL]}, # old assistant
|
||||
{"role": "tool", "content": [TEXT, IMG_URL], "tool_call_id": "t1"},
|
||||
{"role": "user", "content": [TEXT, IMG_URL]}, # newest user — keep
|
||||
]
|
||||
out = _strip_historical_media(msgs)
|
||||
for i in range(3):
|
||||
assert not _content_has_images(out[i]["content"]), f"msg {i} still has image"
|
||||
assert _content_has_images(out[3]["content"])
|
||||
|
||||
def test_text_only_newest_user_still_strips_older_images(self):
|
||||
# The anchor is "newest user WITH images". If the newest user is
|
||||
# text-only, we fall back to the previous image-bearing user turn.
|
||||
msgs = [
|
||||
{"role": "user", "content": [TEXT, IMG_URL]},
|
||||
{"role": "assistant", "content": "ok"},
|
||||
{"role": "user", "content": [TEXT, IMG_URL]}, # anchor
|
||||
{"role": "assistant", "content": "done"},
|
||||
{"role": "user", "content": "follow-up text only"},
|
||||
]
|
||||
out = _strip_historical_media(msgs)
|
||||
# First image-bearing user (index 0) was stripped — it was before the
|
||||
# newest image-bearing user (index 2).
|
||||
assert not _content_has_images(out[0]["content"])
|
||||
# Anchor (index 2) keeps its image.
|
||||
assert _content_has_images(out[2]["content"])
|
||||
|
||||
def test_no_image_bearing_user_is_noop(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "first"},
|
||||
{"role": "assistant", "content": [TEXT, IMG_URL]}, # assistant image only
|
||||
{"role": "user", "content": "second"},
|
||||
]
|
||||
out = _strip_historical_media(msgs)
|
||||
# No image-bearing user anchor → no stripping.
|
||||
assert out is msgs
|
||||
assert _content_has_images(out[1]["content"])
|
||||
|
||||
def test_does_not_mutate_input_messages(self):
|
||||
msg0 = {"role": "user", "content": [TEXT, IMG_URL]}
|
||||
msg1 = {"role": "user", "content": [TEXT, IMG_URL]}
|
||||
msgs = [msg0, msg1]
|
||||
_ = _strip_historical_media(msgs)
|
||||
# Originals untouched
|
||||
assert _content_has_images(msg0["content"])
|
||||
assert _content_has_images(msg1["content"])
|
||||
|
||||
def test_idempotent(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": [TEXT, IMG_URL]},
|
||||
{"role": "assistant", "content": "k"},
|
||||
{"role": "user", "content": [TEXT, IMG_URL]},
|
||||
]
|
||||
first = _strip_historical_media(msgs)
|
||||
second = _strip_historical_media(first)
|
||||
# Second pass is a no-op — no images left before the anchor.
|
||||
assert second is first
|
||||
|
||||
def test_non_dict_messages_pass_through(self):
|
||||
msgs = [
|
||||
"not-a-dict", # shouldn't crash
|
||||
{"role": "user", "content": [TEXT, IMG_URL]},
|
||||
{"role": "assistant", "content": "ok"},
|
||||
{"role": "user", "content": [TEXT, IMG_URL]},
|
||||
]
|
||||
out = _strip_historical_media(msgs)
|
||||
assert out[0] == "not-a-dict"
|
||||
# Image-bearing user at index 1 is before the anchor (index 3) → stripped.
|
||||
assert not _content_has_images(out[1]["content"])
|
||||
|
||||
|
||||
class TestCompressIntegration:
|
||||
"""Verify the stripping runs inside ContextCompressor.compress()."""
|
||||
|
||||
@pytest.fixture
|
||||
def compressor(self):
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100_000):
|
||||
c = ContextCompressor(
|
||||
model="test/model",
|
||||
threshold_percent=0.50,
|
||||
protect_first_n=1,
|
||||
protect_last_n=2,
|
||||
quiet_mode=True,
|
||||
)
|
||||
return c
|
||||
|
||||
def test_compress_strips_historical_images(self, compressor):
|
||||
# Enough messages to trigger the summarize path. protect_first_n=1 +
|
||||
# protect_last_n=2 + a middle window of at least 3 with a summary.
|
||||
msgs = [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": [TEXT, IMG_URL]}, # old image-bearing user
|
||||
{"role": "assistant", "content": "looked at it"},
|
||||
{"role": "user", "content": "follow-up"},
|
||||
{"role": "assistant", "content": "ack"},
|
||||
{"role": "user", "content": "more"},
|
||||
{"role": "assistant", "content": "ok"},
|
||||
{"role": "user", "content": [TEXT, IMG_URL]}, # newest image-bearing user (tail)
|
||||
{"role": "assistant", "content": "done"},
|
||||
]
|
||||
# Bypass the real LLM summary — return a stub so compress() proceeds.
|
||||
with patch.object(compressor, "_generate_summary", return_value="SUMMARY TEXT"):
|
||||
out = compressor.compress(msgs, current_tokens=60_000)
|
||||
|
||||
# Newest user turn with image should still have it (it's in the tail).
|
||||
user_imgs = [m for m in out if m.get("role") == "user" and _content_has_images(m.get("content"))]
|
||||
assert len(user_imgs) == 1, (
|
||||
"Expected exactly one user message with images after compression "
|
||||
f"(the newest one); got {len(user_imgs)}"
|
||||
)
|
||||
# No assistant or tool messages should carry images either.
|
||||
for m in out:
|
||||
if m is user_imgs[0]:
|
||||
continue
|
||||
assert not _content_has_images(m.get("content")), (
|
||||
f"Stale image in {m.get('role')!r} message after compression"
|
||||
)
|
||||
77
tests/agent/test_copilot_acp_deprecation.py
Normal file
77
tests/agent/test_copilot_acp_deprecation.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""Tests for gh-copilot CLI deprecation detection and GitHub Models Azure URL mapping."""
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.copilot_acp_client import _is_gh_copilot_deprecation_message
|
||||
|
||||
|
||||
class TestDeprecationPatternDetection:
|
||||
"""Verify that stderr from the deprecated `gh copilot` extension is caught
|
||||
without false-positiving on the new `@github/copilot` CLI."""
|
||||
|
||||
_REAL_DEPRECATION_STDERR = (
|
||||
"The gh-copilot extension has been deprecated in favor of the newer "
|
||||
"GitHub Copilot CLI.\nFor more information, visit:\n"
|
||||
"- Copilot CLI: https://github.com/github/copilot-cli\n"
|
||||
"- Deprecation announcement: https://github.blog/changelog/"
|
||||
"2025-09-25-upcoming-deprecation-of-gh-copilot-cli-extension\n"
|
||||
"No commands will be executed."
|
||||
)
|
||||
|
||||
def test_real_deprecation_message_matches(self):
|
||||
assert _is_gh_copilot_deprecation_message(self._REAL_DEPRECATION_STDERR)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"stderr_text",
|
||||
[
|
||||
# The deprecation banner uses both halves of the fingerprint.
|
||||
"The gh-copilot extension has been deprecated.",
|
||||
"gh-copilot: no commands will be executed.",
|
||||
# Mixed casing — match is case-insensitive.
|
||||
"The GH-Copilot Extension HAS BEEN DEPRECATED.",
|
||||
],
|
||||
)
|
||||
def test_genuine_deprecation_variants_match(self, stderr_text: str):
|
||||
assert _is_gh_copilot_deprecation_message(stderr_text)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"stderr_text",
|
||||
[
|
||||
# Generic errors — no fingerprint at all.
|
||||
"Error: connection refused",
|
||||
"",
|
||||
# The NEW @github/copilot CLI's repo is github.com/github/copilot-cli.
|
||||
# Its stderr can legitimately mention "copilot-cli" or "deprecation"
|
||||
# in unrelated contexts; neither alone should trip the detector.
|
||||
"copilot-cli: failed to authenticate with the API",
|
||||
"warning: the --foo flag is scheduled for deprecation in v3",
|
||||
"See https://github.com/github/copilot-cli/issues for support",
|
||||
# Half the fingerprint without the other half.
|
||||
"gh-copilot: command not found",
|
||||
"extension has been deprecated (some other extension)",
|
||||
],
|
||||
)
|
||||
def test_does_not_false_positive(self, stderr_text: str):
|
||||
assert not _is_gh_copilot_deprecation_message(stderr_text)
|
||||
|
||||
|
||||
class TestGitHubModelsAzureUrl:
|
||||
"""Verify that the Azure GitHub Models URL is recognised."""
|
||||
|
||||
def test_url_to_provider_contains_azure_models(self):
|
||||
from agent.model_metadata import _URL_TO_PROVIDER
|
||||
|
||||
# Maps to the canonical "copilot" provider (same convention as the
|
||||
# other GitHub-family entries) — not the "github-models" alias.
|
||||
assert _URL_TO_PROVIDER.get("models.inference.ai.azure.com") == "copilot"
|
||||
|
||||
def test_is_github_models_base_url_recognises_azure(self):
|
||||
from hermes_cli.models import _is_github_models_base_url
|
||||
|
||||
assert _is_github_models_base_url("https://models.inference.ai.azure.com")
|
||||
assert _is_github_models_base_url("https://models.inference.ai.azure.com/v1/chat")
|
||||
|
||||
def test_is_github_models_base_url_still_recognises_github_ai(self):
|
||||
from hermes_cli.models import _is_github_models_base_url
|
||||
|
||||
assert _is_github_models_base_url("https://models.github.ai/inference")
|
||||
|
|
@ -6,6 +6,11 @@ the JSON Schema ecosystem accepts:
|
|||
1. Properties without ``type`` — Moonshot requires ``type`` on every node.
|
||||
2. ``type`` at the parent of ``anyOf`` — Moonshot requires it only inside
|
||||
``anyOf`` children.
|
||||
3. ``$ref`` with sibling keywords — Moonshot expands the ref first and then
|
||||
rejects ``description``/``type`` siblings on the same node.
|
||||
(Ported from anomalyco/opencode#24730.)
|
||||
4. Tuple-style ``items`` arrays — Moonshot requires a single item schema,
|
||||
not positional ones. (Ported from anomalyco/opencode#24730.)
|
||||
|
||||
These tests cover the repairs applied by ``agent/moonshot_schema.py``.
|
||||
"""
|
||||
|
|
@ -180,6 +185,164 @@ class TestAnyOfParentType:
|
|||
assert db_type["enum"] == ["mysql", "postgresql"] # "" stripped by enum cleanup
|
||||
|
||||
|
||||
class TestRefSiblingStripping:
|
||||
"""Rule 4: ``$ref`` nodes may not carry sibling keywords on Moonshot.
|
||||
|
||||
Ported from anomalyco/opencode#24730. The real-world failure was MCP tools
|
||||
whose generated schemas put a ``description`` on a ``$ref`` property so the
|
||||
model would see the field's human-readable hint. The reference stays — the
|
||||
referenced definition still owns the description (on the target node itself)
|
||||
and still serves the model's context.
|
||||
"""
|
||||
|
||||
def test_description_sibling_stripped_from_ref(self):
|
||||
params = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"variantOptions": {
|
||||
"$ref": "#/$defs/VariantOptions",
|
||||
"description": "Required. The variant options for generation.",
|
||||
},
|
||||
},
|
||||
"$defs": {
|
||||
"VariantOptions": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"description": "Configuration options.",
|
||||
},
|
||||
},
|
||||
}
|
||||
out = sanitize_moonshot_tool_parameters(params)
|
||||
# Sibling stripped.
|
||||
assert out["properties"]["variantOptions"] == {"$ref": "#/$defs/VariantOptions"}
|
||||
# The target definition's own description is preserved — we only strip
|
||||
# siblings ON the $ref node, not on the thing it points at.
|
||||
assert out["$defs"]["VariantOptions"]["description"] == "Configuration options."
|
||||
|
||||
def test_multiple_siblings_all_stripped(self):
|
||||
params = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"p": {
|
||||
"$ref": "#/$defs/T",
|
||||
"type": "object",
|
||||
"description": "x",
|
||||
"default": {},
|
||||
"title": "P",
|
||||
},
|
||||
},
|
||||
"$defs": {"T": {"type": "object"}},
|
||||
}
|
||||
out = sanitize_moonshot_tool_parameters(params)
|
||||
assert out["properties"]["p"] == {"$ref": "#/$defs/T"}
|
||||
|
||||
def test_ref_without_siblings_unchanged(self):
|
||||
params = {
|
||||
"type": "object",
|
||||
"properties": {"p": {"$ref": "#/$defs/T"}},
|
||||
"$defs": {"T": {"type": "object"}},
|
||||
}
|
||||
out = sanitize_moonshot_tool_parameters(params)
|
||||
assert out["properties"]["p"] == {"$ref": "#/$defs/T"}
|
||||
|
||||
def test_ref_inside_anyof_children(self):
|
||||
params = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"v": {
|
||||
"anyOf": [
|
||||
{"$ref": "#/$defs/A", "description": "variant A"},
|
||||
{"type": "null"},
|
||||
],
|
||||
},
|
||||
},
|
||||
"$defs": {"A": {"type": "object"}},
|
||||
}
|
||||
out = sanitize_moonshot_tool_parameters(params)
|
||||
# Main's existing Rule 2 collapses anyOf-with-null down to the
|
||||
# single non-null branch (Moonshot rejects null branches in anyOf
|
||||
# outright). That branch was originally `{"$ref": ..., "description": ...}`;
|
||||
# Rule 4 then strips the sibling, leaving exactly `{"$ref": "..."}`.
|
||||
# The test name still applies — Rule 4 ran on the $ref branch — it
|
||||
# just happens after the anyOf collapse on this input.
|
||||
assert out["properties"]["v"] == {"$ref": "#/$defs/A"}
|
||||
|
||||
|
||||
class TestTupleItems:
|
||||
"""Rule 5: tuple-style ``items`` arrays collapse to a single schema.
|
||||
|
||||
Ported from anomalyco/opencode#24730. Moonshot's schema engine requires
|
||||
``items`` to be ONE schema object applied to every array element; tuple-
|
||||
style positional item schemas are rejected. We collapse to the first
|
||||
element's schema (which is the "closest" interpretation of positional →
|
||||
single) and drop the rest.
|
||||
"""
|
||||
|
||||
def test_tuple_items_collapsed_to_first(self):
|
||||
params = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"renderedSize": {
|
||||
"type": "array",
|
||||
"items": [{"type": "number"}, {"type": "number"}],
|
||||
"minItems": 2,
|
||||
"maxItems": 2,
|
||||
},
|
||||
},
|
||||
}
|
||||
out = sanitize_moonshot_tool_parameters(params)
|
||||
assert out["properties"]["renderedSize"]["items"] == {"type": "number"}
|
||||
# Sibling constraints are preserved — only the tuple shape is repaired.
|
||||
assert out["properties"]["renderedSize"]["minItems"] == 2
|
||||
|
||||
def test_empty_tuple_items_becomes_empty_schema(self):
|
||||
# Empty tuple collapses to ``{}``; the generic repair then fills a
|
||||
# synthetic ``type`` because Moonshot requires ``type`` on every
|
||||
# schema node. Either ``{}`` or ``{"type": "string"}`` is a valid
|
||||
# final shape for Moonshot — both accept any string element — but we
|
||||
# always go through ``_fill_missing_type`` so the result is fully
|
||||
# well-formed without needing the consumer to patch it later.
|
||||
params = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"things": {"type": "array", "items": []},
|
||||
},
|
||||
}
|
||||
out = sanitize_moonshot_tool_parameters(params)
|
||||
items = out["properties"]["things"]["items"]
|
||||
# Must be a dict and must carry a ``type`` (the whole point of Rule 1).
|
||||
assert isinstance(items, dict)
|
||||
assert items.get("type")
|
||||
|
||||
def test_tuple_items_first_element_is_repaired(self):
|
||||
# The first element itself has a missing type — it should be filled.
|
||||
params = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pair": {
|
||||
"type": "array",
|
||||
"items": [{"description": "first"}, {"description": "second"}],
|
||||
},
|
||||
},
|
||||
}
|
||||
out = sanitize_moonshot_tool_parameters(params)
|
||||
# Repaired to a single schema with a synthetic type.
|
||||
assert out["properties"]["pair"]["items"] == {
|
||||
"description": "first",
|
||||
"type": "string",
|
||||
}
|
||||
|
||||
def test_single_schema_items_unchanged(self):
|
||||
params = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tags": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
out = sanitize_moonshot_tool_parameters(params)
|
||||
assert out["properties"]["tags"]["items"] == {"type": "string"}
|
||||
|
||||
|
||||
class TestTopLevelGuarantees:
|
||||
"""The returned top-level schema is always a well-formed object."""
|
||||
|
||||
|
|
|
|||
|
|
@ -466,6 +466,14 @@ Generate some audio.
|
|||
msg = build_skill_invocation_message("/nonexistent")
|
||||
assert msg is None
|
||||
|
||||
def test_returns_none_when_skill_load_fails(self, tmp_path):
|
||||
with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
|
||||
_make_skill(tmp_path, "broken-skill")
|
||||
scan_skill_commands()
|
||||
with patch("agent.skill_commands._load_skill_payload", return_value=None):
|
||||
msg = build_skill_invocation_message("/broken-skill", "do stuff")
|
||||
assert msg is None
|
||||
|
||||
def test_uses_shared_skill_loader_for_secure_setup(self, tmp_path, monkeypatch):
|
||||
monkeypatch.delenv("TENOR_API_KEY", raising=False)
|
||||
calls = []
|
||||
|
|
|
|||
|
|
@ -100,6 +100,44 @@ class TestCodexBuildKwargs:
|
|||
)
|
||||
assert "prompt_cache_key" not in kw
|
||||
|
||||
def test_xai_responses_sends_cache_key_via_extra_body(self, transport):
|
||||
"""xAI's Responses API documents ``prompt_cache_key`` as the
|
||||
body-level cache-routing key (the ``x-grok-conv-id`` header is
|
||||
Chat-Completions-only). Passing it via ``extra_body`` is robust
|
||||
against openai SDK builds whose ``Responses.stream()`` kwarg
|
||||
signature ever drops the field — the body field still serializes
|
||||
and reaches xAI either way. The ``x-grok-conv-id`` header is kept
|
||||
as a belt-and-braces fallback so cache routing survives even
|
||||
when the body field would be stripped by an intermediate proxy.
|
||||
Ref: https://docs.x.ai/developers/advanced-api-usage/prompt-caching/maximizing-cache-hits
|
||||
"""
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(
|
||||
model="grok-4.3", messages=messages, tools=[],
|
||||
session_id="conv-xai-1",
|
||||
is_xai_responses=True,
|
||||
)
|
||||
assert "prompt_cache_key" not in kw
|
||||
assert kw.get("extra_body", {}).get("prompt_cache_key") == "conv-xai-1"
|
||||
assert kw.get("extra_headers", {}).get("x-grok-conv-id") == "conv-xai-1"
|
||||
|
||||
def test_xai_responses_extra_body_preserves_caller_fields(self, transport):
|
||||
"""When the caller already supplies ``extra_body`` (e.g. via
|
||||
request_overrides), the xAI cache-key injection must merge into
|
||||
the existing dict instead of overwriting it. Caller-supplied
|
||||
``prompt_cache_key`` wins (setdefault semantics) so user overrides
|
||||
aren't silently clobbered by the transport."""
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(
|
||||
model="grok-4.3", messages=messages, tools=[],
|
||||
session_id="conv-xai-1",
|
||||
is_xai_responses=True,
|
||||
request_overrides={"extra_body": {"prompt_cache_key": "caller-override", "other_field": 42}},
|
||||
)
|
||||
eb = kw.get("extra_body", {})
|
||||
assert eb.get("prompt_cache_key") == "caller-override"
|
||||
assert eb.get("other_field") == 42
|
||||
|
||||
def test_max_tokens(self, transport):
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(
|
||||
|
|
@ -156,9 +194,16 @@ class TestCodexBuildKwargs:
|
|||
is_xai_responses=True,
|
||||
reasoning_config={"effort": "high"},
|
||||
)
|
||||
# xAI Responses must receive both encrypted reasoning content and the effort
|
||||
# xAI Responses receives reasoning.effort on the allowlisted models.
|
||||
assert kw.get("reasoning") == {"effort": "high"}
|
||||
assert "reasoning.encrypted_content" in kw.get("include", [])
|
||||
# As of May 2026 we deliberately do NOT request
|
||||
# reasoning.encrypted_content back from xAI — the OAuth/SuperGrok
|
||||
# surface rejects replayed encrypted reasoning items on turn 2+
|
||||
# (the multi-turn "Expected to have received response.created
|
||||
# before error" failure). Grok still reasons natively each turn;
|
||||
# we just don't try to thread the prior turn's encrypted blob back
|
||||
# in. See tests/run_agent/test_codex_xai_oauth_recovery.py.
|
||||
assert "reasoning.encrypted_content" not in kw.get("include", [])
|
||||
|
||||
def test_xai_reasoning_disabled_no_reasoning_key(self, transport):
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
|
|
@ -184,8 +229,9 @@ class TestCodexBuildKwargs:
|
|||
# api.x.ai 400s with "Model X does not support parameter reasoningEffort"
|
||||
# on grok-4 / grok-4-fast / grok-3 / grok-code-fast / grok-4.20-0309-*.
|
||||
# Those models reason natively but don't expose the dial. The transport
|
||||
# must omit the `reasoning` key for them while keeping the encrypted
|
||||
# reasoning content include so we can capture native reasoning tokens.
|
||||
# must omit the `reasoning` key for them. As of May 2026 we also no
|
||||
# longer request ``reasoning.encrypted_content`` back from xAI on ANY
|
||||
# model — see test_xai_reasoning_effort_passed for the rationale.
|
||||
|
||||
def test_xai_grok_4_omits_reasoning_effort(self, transport):
|
||||
"""grok-4 / grok-4-0709 reject reasoning.effort with HTTP 400."""
|
||||
|
|
@ -199,8 +245,9 @@ class TestCodexBuildKwargs:
|
|||
assert "reasoning" not in kw, (
|
||||
f"{model} must not receive a reasoning key (xAI rejects it)"
|
||||
)
|
||||
# Still capture native reasoning tokens
|
||||
assert "reasoning.encrypted_content" in kw.get("include", [])
|
||||
# We no longer ask xAI for encrypted_content back (see comment
|
||||
# above) — verify the include list is empty.
|
||||
assert "reasoning.encrypted_content" not in kw.get("include", [])
|
||||
|
||||
def test_xai_grok_4_fast_omits_reasoning_effort(self, transport):
|
||||
"""grok-4-fast and grok-4-1-fast variants reject reasoning.effort."""
|
||||
|
|
|
|||
104
tests/cli/test_cli_background_status_indicator.py
Normal file
104
tests/cli/test_cli_background_status_indicator.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""Tests for the /background indicator in the CLI status bar.
|
||||
|
||||
The classic prompt_toolkit status bar shows `▶ N` when N tasks launched via
|
||||
`/background` are still running. Source of truth is `self._background_tasks`
|
||||
(a Dict[str, threading.Thread]); entries are removed in the task thread's
|
||||
finally block, so len() reflects truly-running tasks.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from datetime import datetime
|
||||
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
def _stub_thread() -> threading.Thread:
|
||||
"""Return a Thread instance that's never started — pure dict-value stand-in."""
|
||||
return threading.Thread(target=lambda: None)
|
||||
|
||||
|
||||
def _make_cli():
|
||||
"""Bare-metal HermesCLI for snapshot/build tests (no __init__ side effects)."""
|
||||
cli_obj = HermesCLI.__new__(HermesCLI)
|
||||
cli_obj.model = "anthropic/claude-opus-4.6"
|
||||
cli_obj.agent = None
|
||||
cli_obj._background_tasks = {}
|
||||
# The snapshot reads session_start to compute duration; supply a stub.
|
||||
cli_obj.session_start = datetime.now()
|
||||
return cli_obj
|
||||
|
||||
|
||||
def test_snapshot_reports_zero_when_no_background_tasks():
|
||||
cli_obj = _make_cli()
|
||||
snap = cli_obj._get_status_bar_snapshot()
|
||||
assert snap["active_background_tasks"] == 0
|
||||
|
||||
|
||||
def test_snapshot_counts_live_background_tasks():
|
||||
cli_obj = _make_cli()
|
||||
cli_obj._background_tasks = {"bg_a": _stub_thread(), "bg_b": _stub_thread()}
|
||||
snap = cli_obj._get_status_bar_snapshot()
|
||||
assert snap["active_background_tasks"] == 2
|
||||
|
||||
|
||||
def test_snapshot_safe_when_background_tasks_attr_missing():
|
||||
"""Older HermesCLI instances (tests with __new__, etc.) may lack the attr."""
|
||||
cli_obj = HermesCLI.__new__(HermesCLI)
|
||||
cli_obj.model = "x"
|
||||
cli_obj.agent = None
|
||||
cli_obj.session_start = datetime.now()
|
||||
# No _background_tasks at all — must not raise.
|
||||
snap = cli_obj._get_status_bar_snapshot()
|
||||
assert snap["active_background_tasks"] == 0
|
||||
|
||||
|
||||
def test_plain_text_status_omits_indicator_when_idle():
|
||||
cli_obj = _make_cli()
|
||||
text = cli_obj._build_status_bar_text(width=80)
|
||||
assert "▶" not in text
|
||||
|
||||
|
||||
def test_plain_text_status_shows_indicator_when_active():
|
||||
cli_obj = _make_cli()
|
||||
cli_obj._background_tasks = {"bg_a": _stub_thread()}
|
||||
text = cli_obj._build_status_bar_text(width=80)
|
||||
assert "▶ 1" in text
|
||||
|
||||
|
||||
def test_plain_text_status_shows_higher_count():
|
||||
cli_obj = _make_cli()
|
||||
cli_obj._background_tasks = {
|
||||
"a": _stub_thread(),
|
||||
"b": _stub_thread(),
|
||||
"c": _stub_thread(),
|
||||
}
|
||||
text = cli_obj._build_status_bar_text(width=80)
|
||||
assert "▶ 3" in text
|
||||
|
||||
|
||||
def test_narrow_width_omits_bg_indicator():
|
||||
"""The narrow tier (<52) is already cramped — bg is secondary, drop it."""
|
||||
cli_obj = _make_cli()
|
||||
cli_obj._background_tasks = {"bg_a": _stub_thread()}
|
||||
text = cli_obj._build_status_bar_text(width=40)
|
||||
assert "▶" not in text
|
||||
|
||||
|
||||
def test_fragments_include_bg_segment_when_active():
|
||||
cli_obj = _make_cli()
|
||||
cli_obj._background_tasks = {"a": _stub_thread(), "b": _stub_thread()}
|
||||
cli_obj._status_bar_visible = True
|
||||
# _get_status_bar_fragments asks _get_tui_terminal_width(); stub it wide.
|
||||
cli_obj._get_tui_terminal_width = lambda: 120 # type: ignore[method-assign]
|
||||
frags = cli_obj._get_status_bar_fragments()
|
||||
rendered = "".join(text for _style, text in frags)
|
||||
assert "▶ 2" in rendered
|
||||
|
||||
|
||||
def test_fragments_omit_bg_segment_when_idle():
|
||||
cli_obj = _make_cli()
|
||||
cli_obj._status_bar_visible = True
|
||||
cli_obj._get_tui_terminal_width = lambda: 120 # type: ignore[method-assign]
|
||||
frags = cli_obj._get_status_bar_fragments()
|
||||
rendered = "".join(text for _style, text in frags)
|
||||
assert "▶" not in rendered
|
||||
119
tests/cli/test_exit_delete_session.py
Normal file
119
tests/cli/test_exit_delete_session.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
"""Tests for `/exit --delete` and `/quit --delete` session deletion.
|
||||
|
||||
Ports the behavior from google-gemini/gemini-cli#19332: running `/exit` or
|
||||
`/quit` with the `--delete` flag arms a one-shot `_delete_session_on_exit`
|
||||
flag that the CLI shutdown path uses to remove the current session from
|
||||
SQLite + on-disk transcripts before exit.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def _make_cli():
|
||||
"""Bare HermesCLI suitable for process_command() tests.
|
||||
|
||||
Uses ``__new__`` to skip the heavy __init__; only sets the attributes
|
||||
the /exit branch touches.
|
||||
"""
|
||||
from cli import HermesCLI
|
||||
cli = HermesCLI.__new__(HermesCLI)
|
||||
cli.config = {}
|
||||
cli.console = MagicMock()
|
||||
cli.agent = None
|
||||
cli.conversation_history = []
|
||||
cli.session_id = "test-session"
|
||||
cli._delete_session_on_exit = False
|
||||
return cli
|
||||
|
||||
|
||||
class TestExitDeleteFlag:
|
||||
def test_plain_exit_does_not_arm_delete(self):
|
||||
cli = _make_cli()
|
||||
result = cli.process_command("/exit")
|
||||
assert result is False
|
||||
assert cli._delete_session_on_exit is False
|
||||
|
||||
def test_plain_quit_does_not_arm_delete(self):
|
||||
cli = _make_cli()
|
||||
result = cli.process_command("/quit")
|
||||
assert result is False
|
||||
assert cli._delete_session_on_exit is False
|
||||
|
||||
def test_exit_delete_arms_flag(self):
|
||||
cli = _make_cli()
|
||||
result = cli.process_command("/exit --delete")
|
||||
assert result is False
|
||||
assert cli._delete_session_on_exit is True
|
||||
|
||||
def test_quit_delete_arms_flag(self):
|
||||
cli = _make_cli()
|
||||
result = cli.process_command("/quit --delete")
|
||||
assert result is False
|
||||
assert cli._delete_session_on_exit is True
|
||||
|
||||
def test_exit_delete_short_form(self):
|
||||
"""`-d` is a convenience alias for `--delete`."""
|
||||
cli = _make_cli()
|
||||
result = cli.process_command("/exit -d")
|
||||
assert result is False
|
||||
assert cli._delete_session_on_exit is True
|
||||
|
||||
def test_quit_alias_q_is_not_quit(self):
|
||||
"""`/q` is the alias for `/queue`, not `/quit`. This test documents
|
||||
that /q --delete does NOT arm session deletion — it would dispatch
|
||||
to /queue instead."""
|
||||
cli = _make_cli()
|
||||
cli._pending_input = __import__("queue").Queue()
|
||||
# /q with no args shows a usage error and keeps the CLI running.
|
||||
result = cli.process_command("/q")
|
||||
assert result is not False # queue command doesn't exit
|
||||
assert cli._delete_session_on_exit is False
|
||||
|
||||
def test_delete_flag_is_case_insensitive(self):
|
||||
cli = _make_cli()
|
||||
result = cli.process_command("/exit --DELETE")
|
||||
assert result is False
|
||||
assert cli._delete_session_on_exit is True
|
||||
|
||||
def test_delete_flag_trims_whitespace(self):
|
||||
cli = _make_cli()
|
||||
result = cli.process_command("/exit --delete ")
|
||||
assert result is False
|
||||
assert cli._delete_session_on_exit is True
|
||||
|
||||
def test_unknown_exit_argument_does_not_exit(self):
|
||||
"""Unrecognised args should NOT exit the CLI — they surface an
|
||||
error message and stay in the session. This prevents accidental
|
||||
session destruction from typos like `/exit -delete`."""
|
||||
cli = _make_cli()
|
||||
result = cli.process_command("/exit --delte")
|
||||
# process_command returns True = keep running
|
||||
assert result is True
|
||||
assert cli._delete_session_on_exit is False
|
||||
|
||||
def test_unknown_exit_argument_prints_help(self):
|
||||
cli = _make_cli()
|
||||
# _cprint goes through module-level print, so capture via console.
|
||||
# We can't patch _cprint directly without import juggling; the
|
||||
# previous assertion already proves the unknown-arg branch is
|
||||
# reached (result True + flag False).
|
||||
result = cli.process_command("/exit garbage")
|
||||
assert result is True
|
||||
assert cli._delete_session_on_exit is False
|
||||
|
||||
|
||||
class TestCommandRegistry:
|
||||
def test_quit_command_advertises_delete_flag(self):
|
||||
"""The CommandDef args_hint should surface `--delete` in /help and
|
||||
CLI autocomplete."""
|
||||
from hermes_cli.commands import resolve_command
|
||||
cmd = resolve_command("quit")
|
||||
assert cmd is not None
|
||||
assert cmd.args_hint == "[--delete]"
|
||||
|
||||
def test_exit_alias_resolves_to_quit_with_hint(self):
|
||||
from hermes_cli.commands import resolve_command
|
||||
cmd = resolve_command("exit")
|
||||
assert cmd is not None
|
||||
assert cmd.name == "quit"
|
||||
assert cmd.args_hint == "[--delete]"
|
||||
|
|
@ -321,6 +321,93 @@ class TestPauseResumeJob:
|
|||
assert resumed["paused_reason"] is None
|
||||
|
||||
|
||||
class TestResolveJobRef:
|
||||
"""Name-based job lookup for CLI/tool callers (PR #2627, @buntingszn)."""
|
||||
|
||||
def test_resolve_by_exact_id(self, tmp_cron_dir):
|
||||
from cron.jobs import resolve_job_ref
|
||||
|
||||
job = create_job(prompt="A", schedule="1h", name="alpha")
|
||||
assert resolve_job_ref(job["id"])["id"] == job["id"]
|
||||
|
||||
def test_resolve_by_name(self, tmp_cron_dir):
|
||||
from cron.jobs import resolve_job_ref
|
||||
|
||||
job = create_job(prompt="A", schedule="1h", name="alpha")
|
||||
assert resolve_job_ref("alpha")["id"] == job["id"]
|
||||
|
||||
def test_resolve_by_name_case_insensitive(self, tmp_cron_dir):
|
||||
from cron.jobs import resolve_job_ref
|
||||
|
||||
job = create_job(prompt="A", schedule="1h", name="MyJob")
|
||||
assert resolve_job_ref("myjob")["id"] == job["id"]
|
||||
assert resolve_job_ref("MYJOB")["id"] == job["id"]
|
||||
|
||||
def test_resolve_returns_none_when_not_found(self, tmp_cron_dir):
|
||||
from cron.jobs import resolve_job_ref
|
||||
|
||||
create_job(prompt="A", schedule="1h", name="alpha")
|
||||
assert resolve_job_ref("does-not-exist") is None
|
||||
assert resolve_job_ref("") is None
|
||||
|
||||
def test_resolve_id_wins_over_name(self, tmp_cron_dir):
|
||||
"""If a job's name happens to equal another job's ID, ID match wins."""
|
||||
from cron.jobs import resolve_job_ref
|
||||
|
||||
j1 = create_job(prompt="A", schedule="1h")
|
||||
# Create a second job whose name is j1's ID
|
||||
j2 = create_job(prompt="B", schedule="1h", name=j1["id"])
|
||||
# Looking up j1["id"] must return j1, not the colliding-name job j2
|
||||
assert resolve_job_ref(j1["id"])["id"] == j1["id"]
|
||||
assert resolve_job_ref(j1["id"])["id"] != j2["id"]
|
||||
|
||||
def test_resolve_ambiguous_name_raises(self, tmp_cron_dir):
|
||||
"""Two jobs sharing a name → refuse to pick, surface both IDs."""
|
||||
from cron.jobs import AmbiguousJobReference, resolve_job_ref
|
||||
|
||||
j1 = create_job(prompt="A", schedule="1h", name="dup")
|
||||
j2 = create_job(prompt="B", schedule="1h", name="dup")
|
||||
with pytest.raises(AmbiguousJobReference) as exc_info:
|
||||
resolve_job_ref("dup")
|
||||
ids = {m["id"] for m in exc_info.value.matches}
|
||||
assert ids == {j1["id"], j2["id"]}
|
||||
# Error message mentions both IDs so the user can pick one
|
||||
assert j1["id"] in str(exc_info.value)
|
||||
assert j2["id"] in str(exc_info.value)
|
||||
|
||||
def test_trigger_by_name(self, tmp_cron_dir):
|
||||
from cron.jobs import trigger_job
|
||||
|
||||
job = create_job(prompt="A", schedule="1h", name="alpha")
|
||||
result = trigger_job("alpha")
|
||||
assert result is not None
|
||||
assert result["id"] == job["id"]
|
||||
|
||||
def test_pause_by_name(self, tmp_cron_dir):
|
||||
job = create_job(prompt="A", schedule="1h", name="alpha")
|
||||
result = pause_job("alpha", reason="manual")
|
||||
assert result is not None
|
||||
assert result["id"] == job["id"]
|
||||
assert result["state"] == "paused"
|
||||
|
||||
def test_remove_by_name(self, tmp_cron_dir):
|
||||
job = create_job(prompt="A", schedule="1h", name="alpha")
|
||||
assert remove_job("alpha") is True
|
||||
assert get_job(job["id"]) is None
|
||||
|
||||
def test_mutations_refuse_ambiguous_name(self, tmp_cron_dir):
|
||||
"""pause/resume/trigger/remove must refuse to act on an ambiguous name."""
|
||||
from cron.jobs import AmbiguousJobReference, trigger_job
|
||||
|
||||
create_job(prompt="A", schedule="1h", name="dup")
|
||||
create_job(prompt="B", schedule="1h", name="dup")
|
||||
for fn in (pause_job, resume_job, trigger_job):
|
||||
with pytest.raises(AmbiguousJobReference):
|
||||
fn("dup")
|
||||
with pytest.raises(AmbiguousJobReference):
|
||||
remove_job("dup")
|
||||
|
||||
|
||||
class TestMarkJobRun:
|
||||
def test_increments_completed(self, tmp_cron_dir):
|
||||
job = create_job(prompt="Test", schedule="every 1h")
|
||||
|
|
|
|||
|
|
@ -66,6 +66,9 @@ def _ensure_discord_mock():
|
|||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.Forbidden = type("Forbidden", (Exception,), {})
|
||||
discord_mod.MessageType = SimpleNamespace(default=0, reply=19)
|
||||
discord_mod.Object = lambda *, id: SimpleNamespace(id=id)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
|
|
|
|||
152
tests/gateway/test_active_session_text_merge.py
Normal file
152
tests/gateway/test_active_session_text_merge.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
"""Regression test for #4469.
|
||||
|
||||
When the agent is actively running (session present in
|
||||
``adapter._active_sessions``) and the user fires off multiple TEXT
|
||||
follow-ups in rapid succession, the previous behaviour was a single-slot
|
||||
replacement at ``gateway/platforms/base.py``:
|
||||
|
||||
self._pending_messages[session_key] = event
|
||||
|
||||
So three rapid messages ``A``, ``B``, ``C`` arriving while the agent was
|
||||
still working on the initial turn produced a pending slot containing only
|
||||
``C``; ``A`` and ``B`` were silently dropped.
|
||||
|
||||
The fix routes the follow-up through ``merge_pending_message_event(...,
|
||||
merge_text=True)`` so TEXT events accumulate into the existing pending
|
||||
event's text instead of clobbering it. Photo / media bursts continue to
|
||||
merge through the same helper (they always did).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Minimal telegram stub so importing gateway.platforms.base does not pull
|
||||
# in the real python-telegram-bot dependency.
|
||||
_tg = sys.modules.get("telegram") or types.ModuleType("telegram")
|
||||
_tg.constants = sys.modules.get("telegram.constants") or types.ModuleType("telegram.constants")
|
||||
_ct = MagicMock()
|
||||
_ct.PRIVATE = "private"
|
||||
_ct.GROUP = "group"
|
||||
_ct.SUPERGROUP = "supergroup"
|
||||
_tg.constants.ChatType = _ct
|
||||
sys.modules.setdefault("telegram", _tg)
|
||||
sys.modules.setdefault("telegram.constants", _tg.constants)
|
||||
sys.modules.setdefault("telegram.ext", types.ModuleType("telegram.ext"))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
)
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_event(text: str, chat_id: str = "12345") -> MessageEvent:
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
chat_type="dm",
|
||||
user_id="u1",
|
||||
)
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id=f"msg-{text[:8]}",
|
||||
)
|
||||
|
||||
|
||||
def _make_adapter() -> BasePlatformAdapter:
|
||||
"""Build a BasePlatformAdapter without running its heavy __init__.
|
||||
|
||||
We only need the bits ``handle_message`` touches on the active-session
|
||||
path: ``_active_sessions``, ``_pending_messages``,
|
||||
``_message_handler``, ``_busy_session_handler``, ``config``, ``platform``.
|
||||
"""
|
||||
|
||||
class _DummyAdapter(BasePlatformAdapter): # type: ignore[misc]
|
||||
async def connect(self):
|
||||
pass
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return None
|
||||
|
||||
async def send(self, *args, **kwargs):
|
||||
return MagicMock(success=True, message_id="x", retryable=False)
|
||||
|
||||
adapter = object.__new__(_DummyAdapter)
|
||||
adapter.config = PlatformConfig(enabled=True, token="***")
|
||||
adapter.platform = Platform.TELEGRAM
|
||||
adapter._message_handler = AsyncMock(return_value=None)
|
||||
adapter._busy_session_handler = None
|
||||
adapter._active_sessions = {}
|
||||
adapter._pending_messages = {}
|
||||
adapter._session_tasks = {}
|
||||
adapter._background_tasks = set()
|
||||
adapter._post_delivery_callbacks = {}
|
||||
adapter._expected_cancelled_tasks = set()
|
||||
adapter._fatal_error_code = None
|
||||
adapter._fatal_error_message = None
|
||||
adapter._fatal_error_retryable = True
|
||||
adapter._fatal_error_handler = None
|
||||
adapter._running = True
|
||||
adapter._auto_tts_default = False
|
||||
adapter._auto_tts_enabled_chats = set()
|
||||
adapter._auto_tts_disabled_chats = set()
|
||||
adapter._typing_paused = set()
|
||||
return adapter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_text_followups_accumulate_instead_of_replacing():
|
||||
"""Three rapid TEXT follow-ups during an active session must all
|
||||
survive in ``adapter._pending_messages[session_key].text``."""
|
||||
adapter = _make_adapter()
|
||||
first = _make_event("part one")
|
||||
session_key = build_session_key(first.source)
|
||||
|
||||
# Mark the session as active so subsequent messages take the
|
||||
# "already running" branch in handle_message.
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
second = _make_event("part two")
|
||||
third = _make_event("part three")
|
||||
|
||||
await adapter.handle_message(second)
|
||||
await adapter.handle_message(third)
|
||||
|
||||
# Both rapid follow-ups must be preserved, not just the last one.
|
||||
pending = adapter._pending_messages[session_key]
|
||||
assert pending.text == "part two\npart three", (
|
||||
f"expected accumulated text, got {pending.text!r}"
|
||||
)
|
||||
# Interrupt event must be signalled exactly like before.
|
||||
assert adapter._active_sessions[session_key].is_set()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_followup_is_stored_as_is():
|
||||
"""One TEXT follow-up still lands as the event object itself
|
||||
(no spurious wrapping / mutation) — guards against the merge path
|
||||
breaking the simple case."""
|
||||
adapter = _make_adapter()
|
||||
first = _make_event("only one")
|
||||
session_key = build_session_key(first.source)
|
||||
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
await adapter.handle_message(first)
|
||||
|
||||
pending = adapter._pending_messages[session_key]
|
||||
assert pending is first
|
||||
assert pending.text == "only one"
|
||||
assert adapter._active_sessions[session_key].is_set()
|
||||
|
|
@ -105,6 +105,29 @@ class TestResponseStore:
|
|||
store = ResponseStore(max_size=10)
|
||||
assert store.delete("resp_missing") is False
|
||||
|
||||
def test_delete_clears_conversation_mapping(self):
|
||||
"""Deleting a response also removes conversation mappings that reference it."""
|
||||
store = ResponseStore(max_size=10)
|
||||
store.put("resp_1", {"output": "hello"})
|
||||
store.set_conversation("chat-a", "resp_1")
|
||||
assert store.get_conversation("chat-a") == "resp_1"
|
||||
store.delete("resp_1")
|
||||
assert store.get_conversation("chat-a") is None
|
||||
|
||||
def test_eviction_clears_conversation_mapping(self):
|
||||
"""LRU eviction also removes conversation mappings for evicted responses."""
|
||||
store = ResponseStore(max_size=2)
|
||||
store.put("resp_1", {"output": "one"})
|
||||
store.set_conversation("chat-a", "resp_1")
|
||||
store.put("resp_2", {"output": "two"})
|
||||
store.set_conversation("chat-b", "resp_2")
|
||||
# Adding a 3rd should evict resp_1 and its conversation mapping
|
||||
store.put("resp_3", {"output": "three"})
|
||||
assert store.get("resp_1") is None
|
||||
assert store.get_conversation("chat-a") is None
|
||||
# resp_2 mapping should still be intact
|
||||
assert store.get_conversation("chat-b") == "resp_2"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _IdempotencyCache
|
||||
|
|
@ -422,7 +445,12 @@ class TestHealthEndpoint:
|
|||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.get("/health")
|
||||
assert resp.status == 200
|
||||
assert resp.headers.get("Content-Security-Policy") == "default-src 'none'; frame-ancestors 'none'"
|
||||
assert resp.headers.get("Permissions-Policy") == "camera=(), microphone=(), geolocation=()"
|
||||
assert resp.headers.get("Strict-Transport-Security") == "max-age=31536000; includeSubDomains"
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert resp.headers.get("X-Frame-Options") == "DENY"
|
||||
assert resp.headers.get("X-XSS-Protection") == "0"
|
||||
assert resp.headers.get("Referrer-Policy") == "no-referrer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -681,6 +709,37 @@ class TestChatCompletionsEndpoint:
|
|||
assert "[DONE]" in body
|
||||
assert "Hello!" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_string_false_returns_json_completion(self, adapter):
|
||||
"""Quoted false must not route chat completions into SSE mode."""
|
||||
mock_result = {
|
||||
"final_response": "Hello! How can I help you today?",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
mock_result,
|
||||
{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "hermes-agent",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"stream": "false",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status == 200
|
||||
assert "text/event-stream" not in resp.headers.get("Content-Type", "")
|
||||
data = await resp.json()
|
||||
assert data["object"] == "chat.completion"
|
||||
assert data["choices"][0]["message"]["content"] == mock_result["final_response"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_task_done_callback_enqueues_eos_for_chat_completions(self, adapter):
|
||||
"""Regression guard for #24451: completion callback must signal SSE EOS."""
|
||||
|
|
@ -1632,6 +1691,31 @@ class TestResponsesEndpoint:
|
|||
# The response has an ID but it shouldn't be retrievable
|
||||
assert adapter._response_store.get(data["id"]) is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_string_false_does_not_store(self, adapter):
|
||||
"""Quoted false must preserve ephemeral store=false semantics."""
|
||||
mock_result = {"final_response": "OK", "messages": [], "api_calls": 1}
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
mock_result,
|
||||
{"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
|
||||
)
|
||||
resp = await cli.post(
|
||||
"/v1/responses",
|
||||
json={
|
||||
"model": "hermes-agent",
|
||||
"input": "Hello",
|
||||
"store": "false",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert adapter._response_store.get(data["id"]) is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_instructions_inherited_from_previous(self, adapter):
|
||||
"""If no instructions provided, carry forward from previous response."""
|
||||
|
|
@ -1726,6 +1810,37 @@ class TestResponsesStreaming:
|
|||
assert "Hello" in body
|
||||
assert " world" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_string_false_returns_json_response(self, adapter):
|
||||
"""Quoted false must not route Responses API requests into SSE mode."""
|
||||
mock_result = {
|
||||
"final_response": "Paris is the capital of France.",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
mock_result,
|
||||
{"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
|
||||
)
|
||||
resp = await cli.post(
|
||||
"/v1/responses",
|
||||
json={
|
||||
"model": "hermes-agent",
|
||||
"input": "What is the capital of France?",
|
||||
"stream": "false",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status == 200
|
||||
assert "text/event-stream" not in resp.headers.get("Content-Type", "")
|
||||
data = await resp.json()
|
||||
assert data["object"] == "response"
|
||||
assert data["output"][0]["content"][0]["text"] == mock_result["final_response"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_task_done_callback_enqueues_eos_for_responses(self, adapter):
|
||||
"""Regression guard for #24451 on /v1/responses streaming path."""
|
||||
|
|
@ -2870,6 +2985,45 @@ class TestConversationParameter:
|
|||
# Conversation mapping should NOT be set since store=false
|
||||
assert adapter._response_store.get_conversation("ephemeral-chat") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_reuse_after_eviction_no_404(self, adapter):
|
||||
"""After eviction clears a conversation mapping, reusing that name starts fresh (no 404)."""
|
||||
adapter._response_store = ResponseStore(max_size=1)
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
{"final_response": "First", "messages": [], "api_calls": 1},
|
||||
{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
# Create conversation -> resp stored
|
||||
resp1 = await cli.post("/v1/responses", json={
|
||||
"input": "hello",
|
||||
"conversation": "my-chat",
|
||||
})
|
||||
assert resp1.status == 200
|
||||
|
||||
# Evict by adding another response
|
||||
mock_run.return_value = (
|
||||
{"final_response": "Other", "messages": [], "api_calls": 1},
|
||||
{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
await cli.post("/v1/responses", json={"input": "other"})
|
||||
|
||||
# Conversation mapping should have been cleaned by eviction
|
||||
assert adapter._response_store.get_conversation("my-chat") is None
|
||||
|
||||
# Reuse conversation name — should start fresh, not 404
|
||||
mock_run.return_value = (
|
||||
{"final_response": "Restarted", "messages": [], "api_calls": 1},
|
||||
{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
resp3 = await cli.post("/v1/responses", json={
|
||||
"input": "hello again",
|
||||
"conversation": "my-chat",
|
||||
})
|
||||
assert resp3.status == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# X-Hermes-Session-Id header (session continuity)
|
||||
|
|
|
|||
|
|
@ -335,6 +335,28 @@ class TestRunEvents:
|
|||
"approval_not_pending",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approval_string_false_does_not_resolve_all(self, adapter):
|
||||
"""Quoted false must not fan out approval resolution across the queue."""
|
||||
app = _create_runs_app(adapter)
|
||||
run_id = "run_bool_parse"
|
||||
adapter._run_statuses[run_id] = {"run_id": run_id, "status": "running"}
|
||||
adapter._run_approval_sessions[run_id] = "session-123"
|
||||
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
|
||||
approval_resp = await cli.post(
|
||||
f"/v1/runs/{run_id}/approval",
|
||||
json={"choice": "once", "all": "false"},
|
||||
)
|
||||
|
||||
assert approval_resp.status == 200
|
||||
mock_resolve.assert_called_once_with(
|
||||
"session-123",
|
||||
"once",
|
||||
resolve_all=False,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_events_not_found_returns_404(self, adapter):
|
||||
app = _create_runs_app(adapter)
|
||||
|
|
|
|||
|
|
@ -101,6 +101,11 @@ class TestBlueBubblesHelpers:
|
|||
adapter = _make_adapter(monkeypatch)
|
||||
assert adapter.format_message("**Hello** `world`") == "Hello world"
|
||||
|
||||
def test_format_message_preserves_underscores_in_identifiers(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
text = "Use /api_v2 with FEATURE_FLAG_NAME and config_file.json"
|
||||
assert adapter.format_message(text) == text
|
||||
|
||||
def test_strip_markdown_headers(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
assert adapter.format_message("## Heading\ntext") == "Heading\ntext"
|
||||
|
|
|
|||
|
|
@ -384,3 +384,148 @@ class TestIncomingDocumentHandling:
|
|||
assert event.message_type == MessageType.PHOTO
|
||||
assert event.media_urls == ["/tmp/cached_image.png"]
|
||||
assert event.media_types == ["image/png"]
|
||||
|
||||
|
||||
class TestAllowAnyAttachment:
|
||||
"""Cover the discord.allow_any_attachment config flag.
|
||||
|
||||
With the flag off (default), unknown file types are dropped. With it on,
|
||||
they get cached and surfaced to the agent as DOCUMENT events with
|
||||
application/octet-stream MIME so gateway/run.py emits a path-pointing
|
||||
context note.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_type_skipped_by_default(self, adapter):
|
||||
"""Default (flag off): unknown extension is dropped.
|
||||
|
||||
With no text + no cached media, the adapter may legitimately decline
|
||||
to dispatch the event at all, so we don't assert on call_args here —
|
||||
we just verify the file wasn't cached.
|
||||
"""
|
||||
with _mock_aiohttp_download(b"should not be cached"):
|
||||
msg = make_message([
|
||||
make_attachment(filename="weird.xyz", content_type="application/x-custom")
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
if adapter.handle_message.call_args is not None:
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.media_urls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_type_cached_when_flag_on(self, adapter):
|
||||
"""Flag on: unknown extension is cached as application/octet-stream."""
|
||||
adapter.config.extra["allow_any_attachment"] = True
|
||||
|
||||
with _mock_aiohttp_download(b"\x00\x01\x02 binary payload"):
|
||||
msg = make_message([
|
||||
make_attachment(filename="weird.xyz", content_type="application/x-custom")
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
# Falls back to the source content_type when we have one.
|
||||
assert event.media_types == ["application/x-custom"]
|
||||
assert event.message_type == MessageType.DOCUMENT
|
||||
# We deliberately do NOT inline arbitrary bytes — run.py emits the
|
||||
# path-pointing note based on DOCUMENT + octet-stream MIME.
|
||||
assert "[Content of" not in (event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_type_no_content_type_becomes_octet_stream(self, adapter):
|
||||
"""Flag on + no content_type from discord: MIME falls back to octet-stream."""
|
||||
adapter.config.extra["allow_any_attachment"] = True
|
||||
|
||||
with _mock_aiohttp_download(b"raw bytes"):
|
||||
msg = make_message([
|
||||
make_attachment(filename="mystery.bin", content_type=None)
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.message_type == MessageType.DOCUMENT
|
||||
assert event.media_types == ["application/octet-stream"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_attachment_bytes_caps_uploads(self, adapter):
|
||||
"""discord.max_attachment_bytes overrides the historical 32 MiB cap."""
|
||||
adapter.config.extra["allow_any_attachment"] = True
|
||||
adapter.config.extra["max_attachment_bytes"] = 1024 # 1 KiB
|
||||
|
||||
msg = make_message([
|
||||
make_attachment(
|
||||
filename="too_big.xyz",
|
||||
content_type="application/x-custom",
|
||||
size=2048,
|
||||
)
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.media_urls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_attachment_bytes_zero_means_unlimited(self, adapter):
|
||||
"""max_attachment_bytes=0 disables the size cap entirely."""
|
||||
adapter.config.extra["allow_any_attachment"] = True
|
||||
adapter.config.extra["max_attachment_bytes"] = 0
|
||||
|
||||
# 64 MiB — would normally exceed the historical 32 MiB hardcoded cap.
|
||||
with _mock_aiohttp_download(b"x" * 16):
|
||||
msg = make_message([
|
||||
make_attachment(
|
||||
filename="huge.xyz",
|
||||
content_type="application/x-custom",
|
||||
size=64 * 1024 * 1024,
|
||||
)
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert len(event.media_urls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allowlisted_doc_unchanged_when_flag_on(self, adapter):
|
||||
"""Flag on must not change handling of types already in SUPPORTED_DOCUMENT_TYPES.
|
||||
|
||||
A .txt should still get its content inlined (the historical behavior),
|
||||
and the MIME should still be the canonical text/plain — not whatever
|
||||
discord guessed.
|
||||
"""
|
||||
adapter.config.extra["allow_any_attachment"] = True
|
||||
file_content = b"still a text file"
|
||||
|
||||
with _mock_aiohttp_download(file_content):
|
||||
msg = make_message(
|
||||
attachments=[make_attachment(filename="notes.txt", content_type="text/plain")],
|
||||
content="check this",
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "[Content of notes.txt]:" in event.text
|
||||
assert "still a text file" in event.text
|
||||
assert event.media_types == ["text/plain"]
|
||||
|
||||
def test_helper_reads_env_fallback(self, adapter, monkeypatch):
|
||||
"""Helper falls back to DISCORD_ALLOW_ANY_ATTACHMENT env var."""
|
||||
assert adapter._discord_allow_any_attachment() is False
|
||||
monkeypatch.setenv("DISCORD_ALLOW_ANY_ATTACHMENT", "true")
|
||||
assert adapter._discord_allow_any_attachment() is True
|
||||
monkeypatch.setenv("DISCORD_ALLOW_ANY_ATTACHMENT", "no")
|
||||
assert adapter._discord_allow_any_attachment() is False
|
||||
|
||||
def test_helper_config_overrides_env(self, adapter, monkeypatch):
|
||||
"""config.yaml setting wins over env var."""
|
||||
monkeypatch.setenv("DISCORD_ALLOW_ANY_ATTACHMENT", "true")
|
||||
adapter.config.extra["allow_any_attachment"] = False
|
||||
assert adapter._discord_allow_any_attachment() is False
|
||||
|
||||
def test_max_bytes_helper_invalid_value_falls_back(self, adapter):
|
||||
"""Garbage in max_attachment_bytes config falls back to 32 MiB."""
|
||||
adapter.config.extra["max_attachment_bytes"] = "not-a-number"
|
||||
assert adapter._discord_max_attachment_bytes() == 32 * 1024 * 1024
|
||||
|
||||
|
|
|
|||
122
tests/gateway/test_memory_monitor.py
Normal file
122
tests/gateway/test_memory_monitor.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""Tests for gateway.memory_monitor — periodic process memory logging.
|
||||
|
||||
Ported from cline/cline#10343. The module logs a structured
|
||||
``[MEMORY] rss=...MB ...`` line periodically so long-running gateway
|
||||
leaks show up as a time series in agent.log / gateway.log.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway import memory_monitor as mm
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _ensure_monitor_stopped():
|
||||
"""Every test starts from a clean state and leaves one behind."""
|
||||
mm.stop_memory_monitoring(timeout=1.0)
|
||||
yield
|
||||
mm.stop_memory_monitoring(timeout=1.0)
|
||||
|
||||
|
||||
def test_log_memory_usage_emits_memory_line(caplog):
|
||||
caplog.set_level(logging.INFO, logger="gateway.memory_monitor")
|
||||
mm.log_memory_usage()
|
||||
memory_lines = [r for r in caplog.records if "[MEMORY]" in r.getMessage()]
|
||||
assert memory_lines, "expected at least one [MEMORY] log record"
|
||||
|
||||
|
||||
def test_log_memory_usage_has_grep_friendly_format(caplog):
|
||||
caplog.set_level(logging.INFO, logger="gateway.memory_monitor")
|
||||
mm.log_memory_usage()
|
||||
msg = caplog.records[-1].getMessage()
|
||||
# Grep-friendly contract: line starts with [MEMORY] and carries RSS
|
||||
# (or 'unavailable'), GC counts, thread count, uptime.
|
||||
assert msg.startswith("[MEMORY]"), msg
|
||||
assert "rss=" in msg
|
||||
assert "gc=" in msg
|
||||
assert "threads=" in msg
|
||||
assert "uptime=" in msg
|
||||
|
||||
|
||||
def test_log_memory_usage_with_prefix(caplog):
|
||||
caplog.set_level(logging.INFO, logger="gateway.memory_monitor")
|
||||
mm.log_memory_usage(prefix="baseline")
|
||||
msg = caplog.records[-1].getMessage()
|
||||
assert "[MEMORY] baseline " in msg
|
||||
|
||||
|
||||
def test_start_logs_baseline_and_returns_true(caplog):
|
||||
caplog.set_level(logging.INFO, logger="gateway.memory_monitor")
|
||||
# Large interval so the background timer never fires during the test —
|
||||
# we're only checking the synchronous baseline behavior here.
|
||||
started = mm.start_memory_monitoring(interval_seconds=3600.0)
|
||||
assert started is True
|
||||
assert mm.is_running() is True
|
||||
|
||||
messages = [r.getMessage() for r in caplog.records]
|
||||
assert any("[MEMORY] baseline " in m for m in messages), messages
|
||||
assert any("Periodic memory monitoring started" in m for m in messages), messages
|
||||
|
||||
|
||||
def test_double_start_is_noop():
|
||||
assert mm.start_memory_monitoring(interval_seconds=3600.0) is True
|
||||
assert mm.start_memory_monitoring(interval_seconds=3600.0) is False
|
||||
assert mm.is_running() is True
|
||||
|
||||
|
||||
def test_stop_logs_shutdown_snapshot(caplog):
|
||||
mm.start_memory_monitoring(interval_seconds=3600.0)
|
||||
caplog.clear()
|
||||
caplog.set_level(logging.INFO, logger="gateway.memory_monitor")
|
||||
mm.stop_memory_monitoring(timeout=1.0)
|
||||
assert mm.is_running() is False
|
||||
|
||||
messages = [r.getMessage() for r in caplog.records]
|
||||
assert any("[MEMORY] shutdown " in m for m in messages), messages
|
||||
assert any("Periodic memory monitoring stopped" in m for m in messages), messages
|
||||
|
||||
|
||||
def test_stop_without_start_is_noop():
|
||||
# Must not raise, must not log shutdown snapshot.
|
||||
mm.stop_memory_monitoring(timeout=0.5)
|
||||
assert mm.is_running() is False
|
||||
|
||||
|
||||
def test_periodic_timer_fires(caplog):
|
||||
caplog.set_level(logging.INFO, logger="gateway.memory_monitor")
|
||||
# Short interval so we can observe multiple ticks inside the test budget.
|
||||
mm.start_memory_monitoring(interval_seconds=0.1)
|
||||
time.sleep(0.45)
|
||||
mm.stop_memory_monitoring(timeout=1.0)
|
||||
|
||||
periodic = [
|
||||
r for r in caplog.records
|
||||
if r.getMessage().startswith("[MEMORY] rss=") or r.getMessage().startswith("[MEMORY] rss=unavailable")
|
||||
]
|
||||
# baseline + at least 2 periodic + shutdown — but shutdown has the
|
||||
# "shutdown " prefix so it won't match the strict "[MEMORY] rss=" start.
|
||||
# We expect >= 3 bare "[MEMORY] rss=..." lines.
|
||||
assert len(periodic) >= 3, [r.getMessage() for r in caplog.records]
|
||||
|
||||
|
||||
def test_thread_is_daemon():
|
||||
mm.start_memory_monitoring(interval_seconds=3600.0)
|
||||
assert mm._monitor_thread is not None
|
||||
assert mm._monitor_thread.daemon is True, (
|
||||
"memory monitor thread must be daemon so it can never block process exit"
|
||||
)
|
||||
|
||||
|
||||
def test_unavailable_rss_warns_and_does_not_start(caplog, monkeypatch):
|
||||
# Force both backends to claim unavailable; start should bail.
|
||||
monkeypatch.setattr(mm, "_get_rss_mb", lambda: None)
|
||||
caplog.set_level(logging.WARNING, logger="gateway.memory_monitor")
|
||||
started = mm.start_memory_monitoring(interval_seconds=3600.0)
|
||||
assert started is False
|
||||
assert mm.is_running() is False
|
||||
assert any("Memory monitoring unavailable" in r.getMessage() for r in caplog.records)
|
||||
|
|
@ -294,15 +294,63 @@ class TestPlatformReconnectWatcher:
|
|||
assert runner._failed_platforms[Platform.TELEGRAM]["attempts"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_gives_up_after_max_attempts(self):
|
||||
"""After max attempts, platform should be removed from retry queue."""
|
||||
async def test_reconnect_pauses_after_circuit_breaker_threshold(self):
|
||||
"""After enough consecutive retryable failures, the watcher should
|
||||
*pause* the platform (keep it in the queue but stop hammering it),
|
||||
not drop it. The user resumes via /platform resume.
|
||||
"""
|
||||
runner = _make_runner()
|
||||
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
# 9 prior attempts — the next failure will be the 10th and should
|
||||
# trip the circuit breaker.
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 9,
|
||||
"next_retry": time.monotonic() - 1,
|
||||
}
|
||||
|
||||
fail_adapter = StubAdapter(
|
||||
succeed=False, fatal_error="DNS failure", fatal_retryable=True
|
||||
)
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter", return_value=fail_adapter):
|
||||
async def run_one_iteration():
|
||||
runner._running = True
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 1:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_one_iteration()
|
||||
|
||||
# Platform stays in queue — paused, not dropped
|
||||
assert Platform.TELEGRAM in runner._failed_platforms
|
||||
info = runner._failed_platforms[Platform.TELEGRAM]
|
||||
assert info["paused"] is True
|
||||
assert info["attempts"] == 10
|
||||
assert "pause_reason" in info
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_skips_paused_platforms(self):
|
||||
"""A paused platform should not be retried by the watcher tick."""
|
||||
runner = _make_runner()
|
||||
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 20, # At max
|
||||
"next_retry": time.monotonic() - 1,
|
||||
"attempts": 10,
|
||||
"next_retry": time.monotonic() - 1, # would normally retry now
|
||||
"paused": True,
|
||||
"pause_reason": "paused via /platform pause",
|
||||
}
|
||||
|
||||
real_sleep = asyncio.sleep
|
||||
|
|
@ -324,8 +372,10 @@ class TestPlatformReconnectWatcher:
|
|||
|
||||
await run_one_iteration()
|
||||
|
||||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
mock_create.assert_not_called() # Should give up without trying
|
||||
# Paused platform stays queued and was never touched
|
||||
assert Platform.TELEGRAM in runner._failed_platforms
|
||||
assert runner._failed_platforms[Platform.TELEGRAM]["paused"] is True
|
||||
mock_create.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_skips_when_not_time_yet(self):
|
||||
|
|
@ -459,11 +509,12 @@ class TestRuntimeDisconnectQueuing:
|
|||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retryable_error_exits_for_service_restart_when_all_down(self):
|
||||
"""Gateway should exit with failure when all platforms fail with retryable errors.
|
||||
|
||||
This lets systemd Restart=on-failure restart the process, which is more
|
||||
reliable than in-process background reconnection after exhausted retries.
|
||||
async def test_retryable_error_keeps_gateway_alive_when_all_down(self):
|
||||
"""When all adapters fail at runtime with retryable errors, the
|
||||
gateway should stay alive and let the reconnect watcher recover them
|
||||
in the background. (Previously this exited-with-failure to trigger
|
||||
a systemd restart — that converted transient outages into infinite
|
||||
restart loops and killed in-process state.)
|
||||
"""
|
||||
runner = _make_runner()
|
||||
runner.stop = AsyncMock()
|
||||
|
|
@ -474,9 +525,9 @@ class TestRuntimeDisconnectQueuing:
|
|||
|
||||
await runner._handle_adapter_fatal_error(adapter)
|
||||
|
||||
# stop() SHOULD be called — gateway exits for systemd restart
|
||||
runner.stop.assert_called_once()
|
||||
assert runner._exit_with_failure is True
|
||||
# stop() should NOT be called — gateway stays alive for the watcher
|
||||
runner.stop.assert_not_called()
|
||||
assert runner._exit_with_failure is False
|
||||
assert Platform.TELEGRAM in runner._failed_platforms
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -512,3 +563,154 @@ class TestRuntimeDisconnectQueuing:
|
|||
await runner._handle_adapter_fatal_error(adapter)
|
||||
|
||||
runner.stop.assert_called_once()
|
||||
|
||||
|
||||
# --- Pause / resume circuit breaker ---
|
||||
|
||||
|
||||
class TestPauseResume:
|
||||
"""Test the per-platform pause/resume helpers and slash command."""
|
||||
|
||||
def test_pause_marks_platform_paused(self):
|
||||
runner = _make_runner()
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": PlatformConfig(enabled=True, token="t"),
|
||||
"attempts": 3,
|
||||
"next_retry": time.monotonic() + 30,
|
||||
}
|
||||
runner._pause_failed_platform(Platform.TELEGRAM, reason="manual")
|
||||
info = runner._failed_platforms[Platform.TELEGRAM]
|
||||
assert info["paused"] is True
|
||||
assert info["pause_reason"] == "manual"
|
||||
assert info["next_retry"] == float("inf")
|
||||
|
||||
def test_pause_is_idempotent(self):
|
||||
runner = _make_runner()
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": PlatformConfig(enabled=True, token="t"),
|
||||
"attempts": 3,
|
||||
"next_retry": time.monotonic() + 30,
|
||||
"paused": True,
|
||||
"pause_reason": "first reason",
|
||||
}
|
||||
runner._pause_failed_platform(Platform.TELEGRAM, reason="second reason")
|
||||
# Reason should not be overwritten on a second pause call.
|
||||
assert (
|
||||
runner._failed_platforms[Platform.TELEGRAM]["pause_reason"]
|
||||
== "first reason"
|
||||
)
|
||||
|
||||
def test_pause_no_op_when_platform_not_queued(self):
|
||||
runner = _make_runner()
|
||||
# No exception even when the platform isn't in _failed_platforms.
|
||||
runner._pause_failed_platform(Platform.TELEGRAM, reason="x")
|
||||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
|
||||
def test_resume_clears_paused_and_resets_attempts(self):
|
||||
runner = _make_runner()
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": PlatformConfig(enabled=True, token="t"),
|
||||
"attempts": 10,
|
||||
"next_retry": float("inf"),
|
||||
"paused": True,
|
||||
"pause_reason": "auto-paused",
|
||||
}
|
||||
assert runner._resume_paused_platform(Platform.TELEGRAM) is True
|
||||
info = runner._failed_platforms[Platform.TELEGRAM]
|
||||
assert info["paused"] is False
|
||||
assert info["attempts"] == 0
|
||||
assert info["next_retry"] != float("inf")
|
||||
assert "pause_reason" not in info
|
||||
|
||||
def test_resume_returns_false_when_not_paused(self):
|
||||
runner = _make_runner()
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": PlatformConfig(enabled=True, token="t"),
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() + 30,
|
||||
}
|
||||
assert runner._resume_paused_platform(Platform.TELEGRAM) is False
|
||||
|
||||
def test_resume_returns_false_when_not_queued(self):
|
||||
runner = _make_runner()
|
||||
assert runner._resume_paused_platform(Platform.TELEGRAM) is False
|
||||
|
||||
|
||||
class TestPlatformSlashCommand:
|
||||
"""Test the /platform list|pause|resume slash command handler."""
|
||||
|
||||
def _make_event(self, content: str):
|
||||
ev = MagicMock()
|
||||
ev.content = content
|
||||
return ev
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_shows_connected_and_paused(self):
|
||||
runner = _make_runner()
|
||||
runner.adapters[Platform.DISCORD] = StubAdapter(platform=Platform.DISCORD)
|
||||
runner._failed_platforms[Platform.WHATSAPP] = {
|
||||
"config": PlatformConfig(enabled=True, token="t"),
|
||||
"attempts": 10,
|
||||
"next_retry": float("inf"),
|
||||
"paused": True,
|
||||
"pause_reason": "not paired",
|
||||
}
|
||||
out = await runner._handle_platform_command(self._make_event("/platform list"))
|
||||
assert "discord" in out
|
||||
assert "whatsapp" in out
|
||||
assert "PAUSED" in out
|
||||
assert "not paired" in out
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_command_pauses_queued_platform(self):
|
||||
runner = _make_runner()
|
||||
runner._failed_platforms[Platform.WHATSAPP] = {
|
||||
"config": PlatformConfig(enabled=True, token="t"),
|
||||
"attempts": 2,
|
||||
"next_retry": time.monotonic() + 30,
|
||||
}
|
||||
out = await runner._handle_platform_command(
|
||||
self._make_event("/platform pause whatsapp")
|
||||
)
|
||||
assert "paused" in out.lower()
|
||||
assert runner._failed_platforms[Platform.WHATSAPP]["paused"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_rejects_unqueued_platform(self):
|
||||
runner = _make_runner()
|
||||
out = await runner._handle_platform_command(
|
||||
self._make_event("/platform pause whatsapp")
|
||||
)
|
||||
assert "not in the retry queue" in out
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_command_resumes_paused_platform(self):
|
||||
runner = _make_runner()
|
||||
runner._failed_platforms[Platform.WHATSAPP] = {
|
||||
"config": PlatformConfig(enabled=True, token="t"),
|
||||
"attempts": 10,
|
||||
"next_retry": float("inf"),
|
||||
"paused": True,
|
||||
"pause_reason": "x",
|
||||
}
|
||||
out = await runner._handle_platform_command(
|
||||
self._make_event("/platform resume whatsapp")
|
||||
)
|
||||
assert "resumed" in out.lower()
|
||||
assert runner._failed_platforms[Platform.WHATSAPP]["paused"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_platform_name(self):
|
||||
runner = _make_runner()
|
||||
out = await runner._handle_platform_command(
|
||||
self._make_event("/platform pause notarealplatform")
|
||||
)
|
||||
assert "Unknown platform" in out
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_platform_shows_usage_with_list(self):
|
||||
# An empty /platform call defaults to "list".
|
||||
runner = _make_runner()
|
||||
out = await runner._handle_platform_command(self._make_event("/platform"))
|
||||
assert "Gateway platforms" in out
|
||||
|
||||
|
|
|
|||
|
|
@ -68,7 +68,11 @@ async def test_runner_requests_clean_exit_for_nonretryable_startup_conflict(monk
|
|||
@pytest.mark.asyncio
|
||||
async def test_runner_queues_retryable_runtime_fatal_for_reconnection(monkeypatch, tmp_path):
|
||||
"""Retryable runtime fatal errors queue the platform for reconnection
|
||||
instead of shutting down the gateway."""
|
||||
AND keep the gateway alive — the background reconnect watcher recovers
|
||||
the platform when the underlying issue clears. (Previously this
|
||||
exited-with-failure to trigger a systemd restart; that converted
|
||||
transient failures into infinite restart loops.)
|
||||
"""
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.WHATSAPP: PlatformConfig(enabled=True, token="token")
|
||||
|
|
@ -89,8 +93,8 @@ async def test_runner_queues_retryable_runtime_fatal_for_reconnection(monkeypatc
|
|||
|
||||
await runner._handle_adapter_fatal_error(adapter)
|
||||
|
||||
# Should shut down with failure — systemd Restart=on-failure will restart
|
||||
runner.stop.assert_awaited_once()
|
||||
assert runner._exit_with_failure is True
|
||||
# Gateway stays alive — watcher will retry in background
|
||||
runner.stop.assert_not_awaited()
|
||||
assert runner._exit_with_failure is False
|
||||
assert Platform.WHATSAPP in runner._failed_platforms
|
||||
assert runner._failed_platforms[Platform.WHATSAPP]["attempts"] == 0
|
||||
|
|
|
|||
|
|
@ -64,7 +64,14 @@ class _SuccessfulAdapter(BasePlatformAdapter):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch, tmp_path):
|
||||
async def test_runner_stays_alive_for_retryable_startup_errors(monkeypatch, tmp_path):
|
||||
"""Retryable startup errors should leave the gateway running in
|
||||
degraded mode so the reconnect watcher can recover the platform when
|
||||
the underlying problem clears. Previously this returned False from
|
||||
``start()`` and exited the process, which converted a single broken
|
||||
platform (e.g. unpaired WhatsApp, DNS blip on Telegram) into a
|
||||
systemd restart loop and killed cron jobs in the meantime.
|
||||
"""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
|
|
@ -78,11 +85,13 @@ async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch,
|
|||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is False
|
||||
# Gateway stays alive in degraded mode; reconnect watcher takes over.
|
||||
assert ok is True
|
||||
assert runner.should_exit_cleanly is False
|
||||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "startup_failed"
|
||||
assert "temporary DNS resolution failure" in state["exit_reason"]
|
||||
assert state["gateway_state"] in {"degraded", "running"}
|
||||
# Telegram was queued for retry, not given up on.
|
||||
assert Platform.TELEGRAM in runner._failed_platforms
|
||||
assert state["platforms"]["telegram"]["state"] == "retrying"
|
||||
assert state["platforms"]["telegram"]["error_code"] == "telegram_connect_error"
|
||||
|
||||
|
|
|
|||
|
|
@ -205,3 +205,78 @@ class TestResetPolicyNotify:
|
|||
assert restored.notify == original.notify
|
||||
assert restored.notify_exclude_platforms == original.notify_exclude_platforms
|
||||
assert restored.mode == original.mode
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionEntry to_dict / from_dict roundtrip for auto-reset fields
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSessionEntryAutoResetRoundtrip:
|
||||
def test_was_auto_reset_persists_across_roundtrip(self, tmp_path):
|
||||
"""was_auto_reset=True survives to_dict() → from_dict() (gateway restart)."""
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="idle", idle_minutes=1),
|
||||
tmp_path,
|
||||
)
|
||||
source = _make_source()
|
||||
|
||||
entry = store.get_or_create_session(source)
|
||||
entry.updated_at = datetime.now() - timedelta(minutes=5)
|
||||
store._save()
|
||||
|
||||
entry2 = store.get_or_create_session(source)
|
||||
assert entry2.was_auto_reset is True
|
||||
assert entry2.auto_reset_reason == "idle"
|
||||
assert entry2.session_id != entry.session_id
|
||||
|
||||
# Simulate gateway restart: reload from disk
|
||||
store._loaded = False
|
||||
store._entries.clear()
|
||||
store._ensure_loaded()
|
||||
|
||||
reloaded = store._entries.get(entry2.session_key)
|
||||
assert reloaded is not None
|
||||
assert reloaded.was_auto_reset is True
|
||||
assert reloaded.auto_reset_reason == "idle"
|
||||
|
||||
def test_reset_had_activity_persists_across_roundtrip(self, tmp_path):
|
||||
"""reset_had_activity survives to_dict() → from_dict() (gateway restart)."""
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="idle", idle_minutes=1),
|
||||
tmp_path,
|
||||
)
|
||||
source = _make_source()
|
||||
|
||||
entry = store.get_or_create_session(source)
|
||||
entry.total_tokens = 1000
|
||||
entry.updated_at = datetime.now() - timedelta(minutes=5)
|
||||
store._save()
|
||||
|
||||
entry2 = store.get_or_create_session(source)
|
||||
assert entry2.reset_had_activity is True
|
||||
|
||||
store._loaded = False
|
||||
store._entries.clear()
|
||||
store._ensure_loaded()
|
||||
|
||||
reloaded = store._entries.get(entry2.session_key)
|
||||
assert reloaded is not None
|
||||
assert reloaded.reset_had_activity is True
|
||||
|
||||
def test_auto_reset_reason_none_roundtrip(self, tmp_path):
|
||||
"""auto_reset_reason=None (no reset) survives roundtrip cleanly."""
|
||||
store = _make_store(tmp_path=tmp_path)
|
||||
source = _make_source()
|
||||
|
||||
entry = store.get_or_create_session(source)
|
||||
assert entry.was_auto_reset is False
|
||||
|
||||
store._loaded = False
|
||||
store._entries.clear()
|
||||
store._ensure_loaded()
|
||||
|
||||
reloaded = store._entries.get(entry.session_key)
|
||||
assert reloaded is not None
|
||||
assert reloaded.was_auto_reset is False
|
||||
assert reloaded.auto_reset_reason is None
|
||||
assert reloaded.reset_had_activity is False
|
||||
|
|
|
|||
|
|
@ -1794,3 +1794,162 @@ class TestSignalContentlessEnvelope:
|
|||
|
||||
assert "event" in captured, "Normal message should NOT be skipped"
|
||||
assert captured["event"].text == "hello world"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Envelope handling — group routing (legacy groupInfo vs modern groupV2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalGroupV2Routing:
|
||||
"""Regression coverage for groupV2 envelope handling.
|
||||
|
||||
signal-cli's JSON-RPC ``subscribeReceive`` envelope shape has drifted across
|
||||
versions: some forward the underlying libsignal V2 envelope as
|
||||
``dataMessage.groupV2.id`` while older / normalized paths still use
|
||||
``dataMessage.groupInfo.groupId``. The adapter must read groupV2 first and
|
||||
fall back to groupInfo so V2-only groups aren't misrouted as DMs.
|
||||
|
||||
Ported from qwibitai/nanoclaw#1962 (V2 adapter improvements).
|
||||
"""
|
||||
|
||||
def _base_envelope(self, data_message: dict) -> dict:
|
||||
return {
|
||||
"envelope": {
|
||||
"sourceNumber": "+15559998888",
|
||||
"sourceUuid": "uuid-sender",
|
||||
"sourceName": "Alice",
|
||||
"timestamp": 1700000000000,
|
||||
"dataMessage": data_message,
|
||||
}
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_v2_id_routes_as_group(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch, group_allowed="*")
|
||||
captured = []
|
||||
|
||||
async def _capture(event):
|
||||
captured.append(event)
|
||||
|
||||
adapter.handle_message = _capture
|
||||
|
||||
env = self._base_envelope({
|
||||
"message": "hello v2",
|
||||
"groupV2": {"id": "v2group=="},
|
||||
})
|
||||
|
||||
await adapter._handle_envelope(env)
|
||||
|
||||
assert len(captured) == 1
|
||||
assert captured[0].source.chat_id == "group:v2group=="
|
||||
assert captured[0].source.chat_type == "group"
|
||||
assert captured[0].text == "hello v2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_group_info_still_works(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch, group_allowed="*")
|
||||
captured = []
|
||||
|
||||
async def _capture(event):
|
||||
captured.append(event)
|
||||
|
||||
adapter.handle_message = _capture
|
||||
|
||||
env = self._base_envelope({
|
||||
"message": "hello v1",
|
||||
"groupInfo": {"groupId": "legacy=="},
|
||||
})
|
||||
|
||||
await adapter._handle_envelope(env)
|
||||
|
||||
assert len(captured) == 1
|
||||
assert captured[0].source.chat_id == "group:legacy=="
|
||||
assert captured[0].source.chat_type == "group"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_v2_preferred_over_group_info(self, monkeypatch):
|
||||
"""When both fields are present, groupV2 wins — it's the authoritative V2 id."""
|
||||
adapter = _make_signal_adapter(monkeypatch, group_allowed="*")
|
||||
captured = []
|
||||
|
||||
async def _capture(event):
|
||||
captured.append(event)
|
||||
|
||||
adapter.handle_message = _capture
|
||||
|
||||
env = self._base_envelope({
|
||||
"message": "hello",
|
||||
"groupV2": {"id": "v2=="},
|
||||
"groupInfo": {"groupId": "v1=="},
|
||||
})
|
||||
|
||||
await adapter._handle_envelope(env)
|
||||
|
||||
assert len(captured) == 1
|
||||
assert captured[0].source.chat_id == "group:v2=="
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_group_fields_routes_as_dm(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
captured = []
|
||||
|
||||
async def _capture(event):
|
||||
captured.append(event)
|
||||
|
||||
adapter.handle_message = _capture
|
||||
|
||||
env = self._base_envelope({"message": "direct message"})
|
||||
|
||||
await adapter._handle_envelope(env)
|
||||
|
||||
assert len(captured) == 1
|
||||
assert captured[0].source.chat_type == "dm"
|
||||
assert captured[0].source.chat_id == "+15559998888"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_v2_respects_allowlist(self, monkeypatch):
|
||||
"""V2 group ids flow through the same SIGNAL_GROUP_ALLOWED_USERS filter."""
|
||||
adapter = _make_signal_adapter(monkeypatch, group_allowed="allowed-v2==")
|
||||
captured = []
|
||||
|
||||
async def _capture(event):
|
||||
captured.append(event)
|
||||
|
||||
adapter.handle_message = _capture
|
||||
|
||||
# Blocked group (not in allowlist)
|
||||
await adapter._handle_envelope(self._base_envelope({
|
||||
"message": "blocked",
|
||||
"groupV2": {"id": "blocked-v2=="},
|
||||
}))
|
||||
assert len(captured) == 0
|
||||
|
||||
# Allowed group
|
||||
await adapter._handle_envelope(self._base_envelope({
|
||||
"message": "allowed",
|
||||
"groupV2": {"id": "allowed-v2=="},
|
||||
}))
|
||||
assert len(captured) == 1
|
||||
assert captured[0].source.chat_id == "group:allowed-v2=="
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_malformed_group_fields_fall_through_to_dm(self, monkeypatch):
|
||||
"""Non-dict groupV2 / groupInfo shouldn't crash — treat as DM."""
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
captured = []
|
||||
|
||||
async def _capture(event):
|
||||
captured.append(event)
|
||||
|
||||
adapter.handle_message = _capture
|
||||
|
||||
env = self._base_envelope({
|
||||
"message": "malformed",
|
||||
"groupV2": "not-a-dict",
|
||||
"groupInfo": 42,
|
||||
})
|
||||
|
||||
await adapter._handle_envelope(env)
|
||||
|
||||
assert len(captured) == 1
|
||||
assert captured[0].source.chat_type == "dm"
|
||||
|
|
|
|||
347
tests/gateway/test_simplex_plugin.py
Normal file
347
tests/gateway/test_simplex_plugin.py
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
"""Tests for the SimpleX Chat platform-plugin adapter.
|
||||
|
||||
Loaded via the ``_plugin_adapter_loader`` helper so this lives under
|
||||
``plugin_adapter_simplex`` in ``sys.modules`` and cannot collide with
|
||||
sibling platform-plugin tests on the same xdist worker.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.gateway._plugin_adapter_loader import load_plugin_adapter
|
||||
|
||||
_simplex = load_plugin_adapter("simplex")
|
||||
|
||||
SimplexAdapter = _simplex.SimplexAdapter
|
||||
check_requirements = _simplex.check_requirements
|
||||
validate_config = _simplex.validate_config
|
||||
is_connected = _simplex.is_connected
|
||||
register = _simplex.register
|
||||
_env_enablement = _simplex._env_enablement
|
||||
_standalone_send = _simplex._standalone_send
|
||||
_guess_extension = _simplex._guess_extension
|
||||
_is_image_ext = _simplex._is_image_ext
|
||||
_is_audio_ext = _simplex._is_audio_ext
|
||||
_CORR_PREFIX = _simplex._CORR_PREFIX
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Platform enum (plugin-discovered, not bundled)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_platform_enum_resolves_via_plugin_scan():
|
||||
"""The plugin filesystem scan should expose Platform("simplex")."""
|
||||
from gateway.config import Platform
|
||||
p = Platform("simplex")
|
||||
assert p.value == "simplex"
|
||||
# Identity stability — repeated lookups return the same pseudo-member
|
||||
assert Platform("simplex") is p
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. check_requirements / validate_config / is_connected
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_check_requirements_needs_url(monkeypatch):
|
||||
monkeypatch.delenv("SIMPLEX_WS_URL", raising=False)
|
||||
assert check_requirements() is False
|
||||
|
||||
|
||||
def test_check_requirements_true_when_configured(monkeypatch):
|
||||
monkeypatch.setenv("SIMPLEX_WS_URL", "ws://127.0.0.1:5225")
|
||||
# websockets is a dev dep in this repo via the test plugins; the
|
||||
# check_requirements() gate also asserts the package imports.
|
||||
websockets_present = True
|
||||
try:
|
||||
import websockets # noqa: F401
|
||||
except ImportError:
|
||||
websockets_present = False
|
||||
assert check_requirements() is websockets_present
|
||||
|
||||
|
||||
def test_validate_config_uses_env_or_extra():
|
||||
from gateway.config import PlatformConfig
|
||||
# Empty extra + no env → invalid
|
||||
cfg = PlatformConfig(enabled=True)
|
||||
assert validate_config(cfg) is False
|
||||
# extra-only path → valid
|
||||
cfg2 = PlatformConfig(enabled=True, extra={"ws_url": "ws://localhost:5225"})
|
||||
assert validate_config(cfg2) is True
|
||||
|
||||
|
||||
def test_is_connected_mirrors_validate(monkeypatch):
|
||||
from gateway.config import PlatformConfig
|
||||
monkeypatch.delenv("SIMPLEX_WS_URL", raising=False)
|
||||
cfg = PlatformConfig(enabled=True, extra={"ws_url": "ws://x"})
|
||||
assert is_connected(cfg) is True
|
||||
assert is_connected(PlatformConfig(enabled=True)) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. _env_enablement seeds PlatformConfig.extra
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_env_enablement_none_when_unset(monkeypatch):
|
||||
monkeypatch.delenv("SIMPLEX_WS_URL", raising=False)
|
||||
assert _env_enablement() is None
|
||||
|
||||
|
||||
def test_env_enablement_seeds_ws_url(monkeypatch):
|
||||
monkeypatch.setenv("SIMPLEX_WS_URL", "ws://127.0.0.1:5225")
|
||||
monkeypatch.delenv("SIMPLEX_HOME_CHANNEL", raising=False)
|
||||
seed = _env_enablement()
|
||||
assert seed == {"ws_url": "ws://127.0.0.1:5225"}
|
||||
|
||||
|
||||
def test_env_enablement_seeds_home_channel(monkeypatch):
|
||||
monkeypatch.setenv("SIMPLEX_WS_URL", "ws://127.0.0.1:5225")
|
||||
monkeypatch.setenv("SIMPLEX_HOME_CHANNEL", "42")
|
||||
monkeypatch.setenv("SIMPLEX_HOME_CHANNEL_NAME", "Personal")
|
||||
seed = _env_enablement()
|
||||
assert seed["home_channel"] == {"chat_id": "42", "name": "Personal"}
|
||||
|
||||
|
||||
def test_env_enablement_home_channel_defaults_name_to_id(monkeypatch):
|
||||
monkeypatch.setenv("SIMPLEX_WS_URL", "ws://127.0.0.1:5225")
|
||||
monkeypatch.setenv("SIMPLEX_HOME_CHANNEL", "42")
|
||||
monkeypatch.delenv("SIMPLEX_HOME_CHANNEL_NAME", raising=False)
|
||||
seed = _env_enablement()
|
||||
assert seed["home_channel"] == {"chat_id": "42", "name": "42"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Adapter init
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_adapter_init_custom_url():
|
||||
from gateway.config import PlatformConfig
|
||||
cfg = PlatformConfig(enabled=True, extra={"ws_url": "ws://localhost:5225"})
|
||||
adapter = SimplexAdapter(cfg)
|
||||
assert adapter.ws_url == "ws://localhost:5225"
|
||||
assert adapter._running is False
|
||||
assert adapter._ws is None
|
||||
|
||||
|
||||
def test_adapter_init_default_url():
|
||||
from gateway.config import PlatformConfig
|
||||
cfg = PlatformConfig(enabled=True)
|
||||
adapter = SimplexAdapter(cfg)
|
||||
assert adapter.ws_url == "ws://127.0.0.1:5225"
|
||||
|
||||
|
||||
def test_adapter_platform_identity():
|
||||
"""Adapter should expose Platform("simplex") identity."""
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
cfg = PlatformConfig(enabled=True)
|
||||
adapter = SimplexAdapter(cfg)
|
||||
assert adapter.platform is Platform("simplex")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Helper functions (magic-byte detection)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_guess_extension_png():
|
||||
assert _guess_extension(b"\x89PNG\r\n\x1a\n") == ".png"
|
||||
|
||||
|
||||
def test_guess_extension_jpg():
|
||||
assert _guess_extension(b"\xff\xd8\xff\xe0") == ".jpg"
|
||||
|
||||
|
||||
def test_guess_extension_ogg():
|
||||
assert _guess_extension(b"OggS\x00\x02") == ".ogg"
|
||||
|
||||
|
||||
def test_guess_extension_unknown():
|
||||
assert _guess_extension(b"\x00\x01\x02\x03") == ".bin"
|
||||
|
||||
|
||||
def test_is_image_ext():
|
||||
assert _is_image_ext(".png") is True
|
||||
assert _is_image_ext(".webp") is True
|
||||
assert _is_image_ext(".ogg") is False
|
||||
|
||||
|
||||
def test_is_audio_ext():
|
||||
assert _is_audio_ext(".ogg") is True
|
||||
assert _is_audio_ext(".mp3") is True
|
||||
assert _is_audio_ext(".pdf") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Correlation IDs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_corr_id_starts_with_prefix_and_tracks_pending():
|
||||
from gateway.config import PlatformConfig
|
||||
cfg = PlatformConfig(enabled=True, extra={"ws_url": "ws://localhost:5225"})
|
||||
adapter = SimplexAdapter(cfg)
|
||||
corr_id = adapter._make_corr_id()
|
||||
assert corr_id.startswith(_CORR_PREFIX)
|
||||
assert corr_id in adapter._pending_corr_ids
|
||||
|
||||
|
||||
def test_corr_id_pending_set_self_trims():
|
||||
from gateway.config import PlatformConfig
|
||||
cfg = PlatformConfig(enabled=True, extra={"ws_url": "ws://localhost:5225"})
|
||||
adapter = SimplexAdapter(cfg)
|
||||
adapter._max_pending_corr = 4
|
||||
for _ in range(10):
|
||||
adapter._make_corr_id()
|
||||
# After many additions, the pending set should be bounded by the trim
|
||||
# logic — at most one trim window above the cap.
|
||||
assert len(adapter._pending_corr_ids) <= adapter._max_pending_corr + 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. Outbound send (mocked WS)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_dm():
|
||||
from gateway.config import PlatformConfig
|
||||
cfg = PlatformConfig(enabled=True, extra={"ws_url": "ws://localhost:5225"})
|
||||
adapter = SimplexAdapter(cfg)
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
adapter._ws = mock_ws
|
||||
|
||||
result = await adapter.send("contact-42", "Hello, SimpleX!")
|
||||
mock_ws.send.assert_called_once()
|
||||
payload = json.loads(mock_ws.send.call_args[0][0])
|
||||
assert payload["cmd"] == "@[contact-42] Hello, SimpleX!"
|
||||
assert payload["corrId"].startswith(_CORR_PREFIX)
|
||||
assert result.success is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_group():
|
||||
from gateway.config import PlatformConfig
|
||||
cfg = PlatformConfig(enabled=True, extra={"ws_url": "ws://localhost:5225"})
|
||||
adapter = SimplexAdapter(cfg)
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
adapter._ws = mock_ws
|
||||
|
||||
result = await adapter.send("group:grp-99", "Hello, group!")
|
||||
payload = json.loads(mock_ws.send.call_args[0][0])
|
||||
assert payload["cmd"] == "#[grp-99] Hello, group!"
|
||||
assert result.success is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_when_ws_not_connected_does_not_crash():
|
||||
from gateway.config import PlatformConfig
|
||||
cfg = PlatformConfig(enabled=True, extra={"ws_url": "ws://localhost:5225"})
|
||||
adapter = SimplexAdapter(cfg)
|
||||
# No _ws assigned — _send_ws should drop quietly
|
||||
result = await adapter.send("contact-42", "hi")
|
||||
assert result.success is True # send() always returns success — fire-and-forget
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Inbound: filter own-echo by corrId prefix
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_event_filters_own_corr_id():
|
||||
from gateway.config import PlatformConfig
|
||||
cfg = PlatformConfig(enabled=True, extra={"ws_url": "ws://localhost:5225"})
|
||||
adapter = SimplexAdapter(cfg)
|
||||
# Pretend we sent a command with this corrId
|
||||
own = adapter._make_corr_id()
|
||||
handler_mock = AsyncMock()
|
||||
adapter._handle_new_chat_item = handler_mock # type: ignore
|
||||
|
||||
await adapter._handle_event({"corrId": own, "type": "newChatItem"})
|
||||
handler_mock.assert_not_called()
|
||||
assert own not in adapter._pending_corr_ids # discarded
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 9. Standalone (out-of-process) send for cron
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_standalone_send_missing_websockets(monkeypatch):
|
||||
"""When websockets is unimportable, return a clean error dict.
|
||||
|
||||
Implementation detail: the standalone path does ``import websockets``
|
||||
inside the function body. We simulate the package being absent by
|
||||
pulling it out of ``sys.modules`` and pointing the finder at None.
|
||||
"""
|
||||
import sys
|
||||
saved_websockets = sys.modules.pop("websockets", None)
|
||||
saved_meta = list(sys.meta_path)
|
||||
|
||||
class _Blocker:
|
||||
@staticmethod
|
||||
def find_spec(name, path=None, target=None):
|
||||
if name == "websockets" or name.startswith("websockets."):
|
||||
raise ImportError("websockets blocked for test")
|
||||
return None
|
||||
|
||||
sys.meta_path.insert(0, _Blocker())
|
||||
try:
|
||||
pconfig = MagicMock()
|
||||
pconfig.extra = {"ws_url": "ws://localhost:5225"}
|
||||
result = await _standalone_send(pconfig, "contact-42", "hi")
|
||||
assert isinstance(result, dict)
|
||||
assert "error" in result
|
||||
assert "websockets" in result["error"]
|
||||
finally:
|
||||
sys.meta_path[:] = saved_meta
|
||||
if saved_websockets is not None:
|
||||
sys.modules["websockets"] = saved_websockets
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_standalone_send_missing_url(monkeypatch):
|
||||
monkeypatch.delenv("SIMPLEX_WS_URL", raising=False)
|
||||
pconfig = MagicMock()
|
||||
pconfig.extra = {}
|
||||
# We expect the URL fallback (extra+env both empty) to be empty string,
|
||||
# producing an error. We also need websockets to be importable for the
|
||||
# url-check branch to be reached, so skip when it's not.
|
||||
try:
|
||||
import websockets.client # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("websockets not installed")
|
||||
|
||||
result = await _standalone_send(pconfig, "contact-42", "hi")
|
||||
assert isinstance(result, dict)
|
||||
# Either error about URL or a connection attempt failure — both are valid
|
||||
# signals that the standalone path requires configuration.
|
||||
assert "error" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 10. register() — plugin-side metadata
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_register_calls_register_platform():
|
||||
ctx = MagicMock()
|
||||
register(ctx)
|
||||
ctx.register_platform.assert_called_once()
|
||||
kwargs = ctx.register_platform.call_args.kwargs
|
||||
assert kwargs["name"] == "simplex"
|
||||
assert kwargs["label"] == "SimpleX Chat"
|
||||
assert kwargs["required_env"] == ["SIMPLEX_WS_URL"]
|
||||
assert kwargs["allowed_users_env"] == "SIMPLEX_ALLOWED_USERS"
|
||||
assert kwargs["allow_all_env"] == "SIMPLEX_ALLOW_ALL_USERS"
|
||||
assert kwargs["cron_deliver_env_var"] == "SIMPLEX_HOME_CHANNEL"
|
||||
assert callable(kwargs["check_fn"])
|
||||
assert callable(kwargs["validate_config"])
|
||||
assert callable(kwargs["is_connected"])
|
||||
assert callable(kwargs["env_enablement_fn"])
|
||||
assert callable(kwargs["standalone_sender_fn"])
|
||||
assert callable(kwargs["adapter_factory"])
|
||||
assert callable(kwargs["setup_fn"])
|
||||
# SimpleX uses opaque IDs only — no PII to redact.
|
||||
assert kwargs["pii_safe"] is True
|
||||
|
|
@ -283,6 +283,17 @@ class TestTeamsAdapterInit:
|
|||
adapter = TeamsAdapter(_make_config(client_id="id", client_secret="secret", tenant_id="tenant"))
|
||||
assert adapter._port == 5000
|
||||
|
||||
def test_invalid_port_from_extra_falls_back_to_default(self):
|
||||
adapter = TeamsAdapter(
|
||||
_make_config(client_id="id", client_secret="secret", tenant_id="tenant", port="abc")
|
||||
)
|
||||
assert adapter._port == 3978
|
||||
|
||||
def test_invalid_port_from_env_falls_back_to_default(self, monkeypatch):
|
||||
monkeypatch.setenv("TEAMS_PORT", "abc")
|
||||
adapter = TeamsAdapter(_make_config(client_id="id", client_secret="secret", tenant_id="tenant"))
|
||||
assert adapter._port == 3978
|
||||
|
||||
def test_platform_value(self):
|
||||
adapter = TeamsAdapter(_make_config(client_id="id", client_secret="secret", tenant_id="tenant"))
|
||||
assert adapter.platform.value == "teams"
|
||||
|
|
|
|||
|
|
@ -236,14 +236,13 @@ async def test_send_typing_does_not_fall_back_to_root_for_dm_topic():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_typing_skips_api_call_for_dm_topic_reply_fallback():
|
||||
"""Hermes-created DM topic lanes have no working Bot API typing route.
|
||||
async def test_send_typing_attempts_api_call_for_dm_topic_reply_fallback():
|
||||
"""Hermes-created DM topic lanes should still attempt scoped typing.
|
||||
|
||||
``send_chat_action`` only accepts ``message_thread_id``, which Telegram's
|
||||
Bot API 10.0 rejects for these lanes — the call would silently fail and
|
||||
log a "thread not found" warning every typing tick (every 2s). Skipping
|
||||
the call entirely keeps logs clean while preserving the user-visible
|
||||
behavior (no typing indicator either way for these lanes).
|
||||
Some private DM topic lanes route message sends through reply-anchor
|
||||
fallback, but live Telegram testing shows sendChatAction accepts the lane's
|
||||
message_thread_id. If Telegram rejects a stale or invalid thread later,
|
||||
send_typing already swallows that failure as non-fatal.
|
||||
"""
|
||||
adapter = _make_adapter()
|
||||
call_log = []
|
||||
|
|
@ -262,7 +261,9 @@ async def test_send_typing_skips_api_call_for_dm_topic_reply_fallback():
|
|||
},
|
||||
)
|
||||
|
||||
assert call_log == []
|
||||
assert call_log == [
|
||||
{"chat_id": 12345, "action": "typing", "message_thread_id": 20197},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -237,6 +237,8 @@ class TestUpdateCommandGatewayFlag:
|
|||
cmd_string = call_args[-1] if isinstance(call_args, list) else str(call_args)
|
||||
assert "--gateway" in cmd_string
|
||||
assert "PYTHONUNBUFFERED" in cmd_string
|
||||
assert "rc=$?" in cmd_string
|
||||
assert "status=$?" not in cmd_string
|
||||
assert "stream progress" in result
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -611,3 +611,93 @@ class TestHttpSessionLifecycle:
|
|||
|
||||
mock_task.cancel.assert_not_called()
|
||||
assert adapter._poll_task is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-flight: refuse to start the bridge when creds.json is missing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNoCredsPreflight:
|
||||
"""Verify ``connect()`` fast-fails as non-retryable when WhatsApp is
|
||||
enabled but the user never finished pairing (no ``creds.json``).
|
||||
|
||||
Without this guard, every gateway boot:
|
||||
• spawned the bridge subprocess (npm install if needed)
|
||||
• waited 30s for status:connected (never happens without creds)
|
||||
• queued WhatsApp for indefinite retries that would just repeat
|
||||
With the guard, ``connect()`` returns False immediately with a
|
||||
non-retryable fatal error so the reconnect watcher drops the platform
|
||||
and the gateway gets a single clear log line telling the user to run
|
||||
``hermes whatsapp``.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_returns_false_when_no_creds(self, tmp_path):
|
||||
from gateway.platforms.whatsapp import WhatsAppAdapter
|
||||
|
||||
adapter = WhatsAppAdapter.__new__(WhatsAppAdapter)
|
||||
adapter.platform = Platform.WHATSAPP
|
||||
adapter.config = MagicMock()
|
||||
adapter._bridge_port = 19876
|
||||
# Point bridge_script at a real existing file so the earlier
|
||||
# bridge-missing check doesn't trip — we want to exercise the
|
||||
# creds.json check specifically.
|
||||
bridge = tmp_path / "bridge.js"
|
||||
bridge.write_text("// stub")
|
||||
adapter._bridge_script = str(bridge)
|
||||
adapter._session_path = tmp_path / "session" # no creds.json inside
|
||||
adapter._session_path.mkdir()
|
||||
adapter._bridge_log_fh = None
|
||||
adapter._fatal_error_code = None
|
||||
adapter._fatal_error_message = None
|
||||
adapter._fatal_error_retryable = True
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.whatsapp.check_whatsapp_requirements",
|
||||
return_value=True,
|
||||
):
|
||||
result = await adapter.connect()
|
||||
|
||||
assert result is False
|
||||
# Non-retryable so the reconnect watcher drops it cleanly
|
||||
assert adapter._fatal_error_code == "whatsapp_not_paired"
|
||||
assert adapter._fatal_error_retryable is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_proceeds_when_creds_present(self, tmp_path):
|
||||
"""When creds.json exists, the preflight check is bypassed and
|
||||
connect() proceeds to the bridge bootstrap path. We don't fully
|
||||
simulate the bridge here — we just verify no fast-fail occurs.
|
||||
"""
|
||||
from gateway.platforms.whatsapp import WhatsAppAdapter
|
||||
|
||||
adapter = WhatsAppAdapter.__new__(WhatsAppAdapter)
|
||||
adapter.platform = Platform.WHATSAPP
|
||||
adapter.config = MagicMock()
|
||||
adapter._bridge_port = 19877
|
||||
bridge = tmp_path / "bridge.js"
|
||||
bridge.write_text("// stub")
|
||||
adapter._bridge_script = str(bridge)
|
||||
session_dir = tmp_path / "session"
|
||||
session_dir.mkdir()
|
||||
(session_dir / "creds.json").write_text("{}")
|
||||
adapter._session_path = session_dir
|
||||
adapter._bridge_log_fh = None
|
||||
adapter._fatal_error_code = None
|
||||
adapter._fatal_error_message = None
|
||||
adapter._fatal_error_retryable = True
|
||||
# Stub _acquire_platform_lock to return False so connect() exits
|
||||
# cleanly *after* the preflight, without spawning subprocesses.
|
||||
adapter._acquire_platform_lock = MagicMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.whatsapp.check_whatsapp_requirements",
|
||||
return_value=True,
|
||||
):
|
||||
result = await adapter.connect()
|
||||
|
||||
# Preflight passed — exits because we faked lock acquisition,
|
||||
# but the fatal-error code is NOT the "not paired" one.
|
||||
assert result is False
|
||||
assert adapter._fatal_error_code != "whatsapp_not_paired"
|
||||
|
|
|
|||
95
tests/hermes_cli/test_auth_loopback_ssh_hint.py
Normal file
95
tests/hermes_cli/test_auth_loopback_ssh_hint.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
"""Unit tests for _print_loopback_ssh_hint() in hermes_cli/auth.py.
|
||||
|
||||
The helper exists to warn users that loopback OAuth flows (xAI Grok OAuth,
|
||||
Spotify) don't work over SSH unless they set up an `ssh -L` port forward
|
||||
between their laptop's browser and the remote host's loopback listener.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import contextlib
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli import auth as auth_mod
|
||||
|
||||
|
||||
def _cap(fn):
|
||||
buf = io.StringIO()
|
||||
with contextlib.redirect_stdout(buf):
|
||||
fn()
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def test_loopback_ssh_hint_silent_when_not_remote(monkeypatch):
|
||||
monkeypatch.setattr(auth_mod, "_is_remote_session", lambda: False)
|
||||
out = _cap(lambda: auth_mod._print_loopback_ssh_hint(
|
||||
"http://127.0.0.1:56121/callback", docs_url=auth_mod.XAI_OAUTH_DOCS_URL
|
||||
))
|
||||
assert out == ""
|
||||
|
||||
|
||||
def test_loopback_ssh_hint_prints_tunnel_command_on_ssh(monkeypatch):
|
||||
monkeypatch.setattr(auth_mod, "_is_remote_session", lambda: True)
|
||||
out = _cap(lambda: auth_mod._print_loopback_ssh_hint(
|
||||
"http://127.0.0.1:56121/callback", docs_url=auth_mod.XAI_OAUTH_DOCS_URL
|
||||
))
|
||||
# Must include the exact ssh -L command with the port from the redirect URI
|
||||
assert "ssh -N -L 56121:127.0.0.1:56121" in out
|
||||
# Must include the provider-specific docs URL
|
||||
assert auth_mod.XAI_OAUTH_DOCS_URL in out
|
||||
# Must always include the cross-provider SSH guide
|
||||
assert auth_mod.OAUTH_OVER_SSH_DOCS_URL in out
|
||||
|
||||
|
||||
def test_loopback_ssh_hint_uses_actual_bound_port(monkeypatch):
|
||||
"""When the preferred port is busy, _xai_start_callback_server falls back to
|
||||
an OS-assigned port. The hint must echo whichever port actually got bound,
|
||||
not the hardcoded constant."""
|
||||
monkeypatch.setattr(auth_mod, "_is_remote_session", lambda: True)
|
||||
out = _cap(lambda: auth_mod._print_loopback_ssh_hint(
|
||||
"http://127.0.0.1:51234/callback", docs_url=auth_mod.XAI_OAUTH_DOCS_URL
|
||||
))
|
||||
assert "ssh -N -L 51234:127.0.0.1:51234" in out
|
||||
assert "56121" not in out
|
||||
|
||||
|
||||
def test_loopback_ssh_hint_silent_for_non_loopback_uri(monkeypatch):
|
||||
"""Defense in depth: if a future caller passes a non-loopback redirect URI
|
||||
by mistake, we don't tell the user to forward an external port."""
|
||||
monkeypatch.setattr(auth_mod, "_is_remote_session", lambda: True)
|
||||
out = _cap(lambda: auth_mod._print_loopback_ssh_hint(
|
||||
"https://example.com/callback", docs_url=auth_mod.XAI_OAUTH_DOCS_URL
|
||||
))
|
||||
assert out == ""
|
||||
|
||||
|
||||
def test_loopback_ssh_hint_silent_for_malformed_uri(monkeypatch):
|
||||
monkeypatch.setattr(auth_mod, "_is_remote_session", lambda: True)
|
||||
out = _cap(lambda: auth_mod._print_loopback_ssh_hint(
|
||||
"not-a-uri", docs_url=auth_mod.XAI_OAUTH_DOCS_URL
|
||||
))
|
||||
assert out == ""
|
||||
|
||||
|
||||
def test_loopback_ssh_hint_works_without_provider_docs_url(monkeypatch):
|
||||
monkeypatch.setattr(auth_mod, "_is_remote_session", lambda: True)
|
||||
out = _cap(lambda: auth_mod._print_loopback_ssh_hint(
|
||||
"http://127.0.0.1:43827/spotify/callback"
|
||||
))
|
||||
assert "ssh -N -L 43827:127.0.0.1:43827" in out
|
||||
# Generic SSH guide is always present even without a provider-specific URL
|
||||
assert auth_mod.OAUTH_OVER_SSH_DOCS_URL in out
|
||||
# Should not falsely show "Provider docs:" when no docs_url was passed
|
||||
assert "Provider docs:" not in out
|
||||
|
||||
|
||||
def test_loopback_ssh_hint_accepts_localhost_hostname(monkeypatch):
|
||||
"""The constant is 127.0.0.1, but parsing tolerates `localhost` too in case
|
||||
a future caller normalizes the URI differently."""
|
||||
monkeypatch.setattr(auth_mod, "_is_remote_session", lambda: True)
|
||||
out = _cap(lambda: auth_mod._print_loopback_ssh_hint(
|
||||
"http://localhost:56121/callback"
|
||||
))
|
||||
assert "ssh -N -L 56121:127.0.0.1:56121" in out
|
||||
1605
tests/hermes_cli/test_auth_xai_oauth_provider.py
Normal file
1605
tests/hermes_cli/test_auth_xai_oauth_provider.py
Normal file
File diff suppressed because it is too large
Load diff
35
tests/hermes_cli/test_banner_pip_update.py
Normal file
35
tests/hermes_cli/test_banner_pip_update.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
|
||||
def testcheck_via_pypi_detects_update():
|
||||
"""check_via_pypi returns 1 when PyPI has newer version."""
|
||||
from hermes_cli.banner import check_via_pypi
|
||||
with patch("hermes_cli.banner.VERSION", "0.12.0"):
|
||||
with patch("hermes_cli.banner._fetch_pypi_latest", return_value="0.13.0"):
|
||||
result = check_via_pypi()
|
||||
assert result == 1
|
||||
|
||||
|
||||
def testcheck_via_pypi_up_to_date():
|
||||
"""check_via_pypi returns 0 when versions match."""
|
||||
from hermes_cli.banner import check_via_pypi
|
||||
with patch("hermes_cli.banner.VERSION", "0.13.0"):
|
||||
with patch("hermes_cli.banner._fetch_pypi_latest", return_value="0.13.0"):
|
||||
result = check_via_pypi()
|
||||
assert result == 0
|
||||
|
||||
|
||||
def testcheck_via_pypi_network_failure():
|
||||
"""check_via_pypi returns None on network error."""
|
||||
from hermes_cli.banner import check_via_pypi
|
||||
with patch("hermes_cli.banner._fetch_pypi_latest", return_value=None):
|
||||
result = check_via_pypi()
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_version_tuple_comparison():
|
||||
"""Version comparison works with multi-segment versions."""
|
||||
from hermes_cli.banner import _version_tuple
|
||||
assert _version_tuple("0.13.0") > _version_tuple("0.12.0")
|
||||
assert _version_tuple("0.13.0") == _version_tuple("0.13.0")
|
||||
assert _version_tuple("1.0.0") > _version_tuple("0.99.99")
|
||||
|
|
@ -130,17 +130,22 @@ class TestCmdUpdateBranchFallback:
|
|||
# 1. repo root — slash-command / TUI bridge deps
|
||||
# 2. ui-tui/ — Ink TUI deps
|
||||
# 3. web/ — install + "npm run build" for the web frontend
|
||||
full_flags = [
|
||||
#
|
||||
# Repo-root and ui-tui installs intentionally omit `--silent` and run
|
||||
# without `capture_output` so optional postinstall scripts (e.g.
|
||||
# `@askjo/camofox-browser`'s browser-binary fetch) print progress —
|
||||
# otherwise long downloads look like a hang (#18840). The web/ install
|
||||
# keeps `--silent` because its build step is short and noisy.
|
||||
update_flags = [
|
||||
"/usr/bin/npm",
|
||||
"ci",
|
||||
"--silent",
|
||||
"--no-fund",
|
||||
"--no-audit",
|
||||
"--progress=false",
|
||||
]
|
||||
assert npm_calls[:2] == [
|
||||
(full_flags, PROJECT_ROOT),
|
||||
(full_flags, PROJECT_ROOT / "ui-tui"),
|
||||
(update_flags, PROJECT_ROOT),
|
||||
(update_flags, PROJECT_ROOT / "ui-tui"),
|
||||
]
|
||||
if len(npm_calls) > 2:
|
||||
assert npm_calls[2:] == [
|
||||
|
|
@ -148,6 +153,24 @@ class TestCmdUpdateBranchFallback:
|
|||
(["/usr/bin/npm", "run", "build"], PROJECT_ROOT / "web"),
|
||||
]
|
||||
|
||||
# Regression for #18840: repo root + ui-tui installs must stream
|
||||
# output (capture_output=False) so postinstall progress is visible
|
||||
# to the user.
|
||||
repo_and_tui_calls = [
|
||||
call
|
||||
for call in mock_run.call_args_list
|
||||
if call.args
|
||||
and call.args[0][0] == "/usr/bin/npm"
|
||||
and call.args[0][1] == "ci"
|
||||
and call.kwargs.get("cwd") in (PROJECT_ROOT, PROJECT_ROOT / "ui-tui")
|
||||
]
|
||||
assert len(repo_and_tui_calls) == 2
|
||||
for call in repo_and_tui_calls:
|
||||
assert call.kwargs.get("capture_output") is False, (
|
||||
"repo-root / ui-tui npm install must stream output "
|
||||
"(no capture_output) so postinstall progress is visible"
|
||||
)
|
||||
|
||||
def test_update_non_interactive_runs_safe_config_migrations(self, mock_args, capsys):
|
||||
"""Dashboard/web updates apply non-interactive migrations before restart."""
|
||||
with patch("shutil.which", return_value=None), patch(
|
||||
|
|
|
|||
|
|
@ -8,9 +8,13 @@ import pytest
|
|||
|
||||
from hermes_cli.codex_runtime_plugin_migration import (
|
||||
MIGRATION_MARKER,
|
||||
MIGRATION_END_MARKER,
|
||||
MigrationReport,
|
||||
_build_hermes_tools_mcp_entry,
|
||||
_format_toml_value,
|
||||
_looks_like_test_tempdir,
|
||||
_strip_existing_managed_block,
|
||||
_strip_unmanaged_plugin_tables,
|
||||
_translate_one_server,
|
||||
migrate,
|
||||
render_codex_toml_section,
|
||||
|
|
@ -567,10 +571,31 @@ class TestMigrate:
|
|||
assert "[model]" in new_text
|
||||
assert 'profile = "default"' in new_text
|
||||
assert "[providers.openai]" in new_text
|
||||
# And new MCP block appended
|
||||
# And new MCP block inserted without breaking user tables
|
||||
assert "[mcp_servers.a]" in new_text
|
||||
assert MIGRATION_MARKER in new_text
|
||||
|
||||
def test_managed_root_keys_stay_top_level_when_config_ends_in_table(self, tmp_path):
|
||||
"""TOML has no explicit 'leave current table' syntax. If Hermes appends
|
||||
root keys like default_permissions after a user table such as [features],
|
||||
Codex parses them as features.default_permissions and rejects the config.
|
||||
The managed block must therefore be inserted before the first table."""
|
||||
import tomllib
|
||||
|
||||
target = tmp_path / "config.toml"
|
||||
target.write_text(
|
||||
'model = "gpt-5.5"\n'
|
||||
"\n"
|
||||
"[features]\n"
|
||||
"terminal_resize_reflow = true\n"
|
||||
)
|
||||
migrate({}, codex_home=tmp_path, discover_plugins=False, expose_hermes_tools=False)
|
||||
new_text = target.read_text()
|
||||
parsed = tomllib.loads(new_text)
|
||||
assert parsed["default_permissions"] == ":workspace"
|
||||
assert "default_permissions" not in parsed["features"]
|
||||
assert new_text.index(MIGRATION_MARKER) < new_text.index("[features]")
|
||||
|
||||
def test_preserves_user_mcp_server_outside_managed_block(self, tmp_path):
|
||||
"""Quirk #6: when a user adds their own MCP server entry directly
|
||||
to ~/.codex/config.toml outside Hermes' managed block, re-running
|
||||
|
|
@ -635,3 +660,206 @@ class TestMigrate:
|
|||
assert "Migrated 2 MCP server(s)" in summary
|
||||
assert "- a" in summary
|
||||
assert "- b" in summary
|
||||
|
||||
|
||||
# ---- Bug B: duplicate [plugins.X] tables ----
|
||||
|
||||
|
||||
class TestStripUnmanagedPluginTables:
|
||||
"""Regression tests for issue #26250 Bug B.
|
||||
|
||||
When codex itself writes ``[plugins."<name>@<marketplace>"]`` tables
|
||||
(via the user running ``codex plugins enable`` directly), re-running
|
||||
``hermes codex-runtime migrate`` would re-emit them inside the managed
|
||||
block and the resulting duplicate-table-header would crash codex.
|
||||
"""
|
||||
|
||||
def test_strips_plugin_tables_outside_managed_block(self):
|
||||
text = (
|
||||
'model = "gpt-5.5"\n'
|
||||
"\n"
|
||||
"[mcp_servers.user-thing]\n"
|
||||
'command = "x"\n'
|
||||
"\n"
|
||||
'[plugins."tasks@openai-curated"]\n'
|
||||
"enabled = true\n"
|
||||
"\n"
|
||||
'[plugins."web-search@openai-curated"]\n'
|
||||
"enabled = true\n"
|
||||
"\n"
|
||||
"[features]\n"
|
||||
"terminal_resize_reflow = true\n"
|
||||
)
|
||||
stripped = _strip_unmanaged_plugin_tables(text)
|
||||
assert "[plugins." not in stripped
|
||||
# Non-plugin content preserved
|
||||
assert "[mcp_servers.user-thing]" in stripped
|
||||
assert "[features]" in stripped
|
||||
assert "terminal_resize_reflow = true" in stripped
|
||||
|
||||
def test_preserves_content_when_no_plugin_tables(self):
|
||||
text = (
|
||||
'model = "gpt-5.5"\n'
|
||||
"\n"
|
||||
"[mcp_servers.x]\n"
|
||||
'command = "y"\n'
|
||||
)
|
||||
assert _strip_unmanaged_plugin_tables(text) == text
|
||||
|
||||
def test_multi_line_array_in_plugin_table_does_not_leak(self):
|
||||
"""A multi-line TOML array inside a [plugins.X] table whose
|
||||
continuation lines start with ``[`` (e.g. nested arrays) must NOT
|
||||
prematurely exit the strip region — otherwise array fragments
|
||||
leak into top-level output and produce invalid TOML on the next
|
||||
codex startup. Regression guard for #26260 review.
|
||||
"""
|
||||
text = (
|
||||
'[plugins."tasks@openai-curated"]\n'
|
||||
"allowed = [\n"
|
||||
' "a",\n'
|
||||
' ["nested"],\n'
|
||||
"]\n"
|
||||
"[features]\n"
|
||||
"x = 1\n"
|
||||
)
|
||||
stripped = _strip_unmanaged_plugin_tables(text)
|
||||
# Everything inside the plugin table — including the multi-line
|
||||
# array's continuation lines starting with `[` — should be gone.
|
||||
assert '["nested"]' not in stripped
|
||||
assert "allowed" not in stripped
|
||||
# Sibling user table survives intact.
|
||||
assert "[features]" in stripped
|
||||
assert "x = 1" in stripped
|
||||
# Result is still valid TOML.
|
||||
import tomllib
|
||||
tomllib.loads(stripped)
|
||||
|
||||
def test_migrate_dedups_codex_owned_plugin_tables(self, tmp_path, monkeypatch):
|
||||
"""End-to-end: codex's pre-existing [plugins.X] tables get replaced by
|
||||
the managed block's re-emission rather than duplicated."""
|
||||
target = tmp_path / "config.toml"
|
||||
target.write_text(
|
||||
"[mcp_servers.user-server]\n"
|
||||
'command = "x"\n'
|
||||
"\n"
|
||||
'[plugins."tasks@openai-curated"]\n'
|
||||
"enabled = true\n"
|
||||
)
|
||||
|
||||
# Simulate codex's plugin/list reporting the same plugin tasks@openai-curated.
|
||||
def fake_query(codex_home=None, timeout=8.0):
|
||||
return (
|
||||
[{"name": "tasks", "marketplace": "openai-curated", "enabled": True}],
|
||||
None,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.codex_runtime_plugin_migration._query_codex_plugins",
|
||||
fake_query,
|
||||
)
|
||||
migrate({}, codex_home=tmp_path, discover_plugins=True, expose_hermes_tools=False)
|
||||
new_text = target.read_text()
|
||||
# Only ONE [plugins."tasks@openai-curated"] header should remain — inside
|
||||
# the managed block — not the original outside-the-block copy.
|
||||
assert new_text.count('[plugins."tasks@openai-curated"]') == 1
|
||||
# And the surviving one is inside our managed section.
|
||||
managed_start = new_text.index(MIGRATION_MARKER)
|
||||
managed_end = new_text.index(MIGRATION_END_MARKER)
|
||||
plugin_idx = new_text.index('[plugins."tasks@openai-curated"]')
|
||||
assert managed_start < plugin_idx < managed_end
|
||||
# File parses cleanly as TOML (the original duplicate-key error is gone).
|
||||
import tomllib
|
||||
tomllib.loads(new_text)
|
||||
|
||||
def test_migrate_preserves_plugin_tables_when_plugin_list_fails(self, tmp_path, monkeypatch):
|
||||
"""If plugin/list RPC fails, we can't re-emit plugins authoritatively,
|
||||
so we must NOT strip the user's existing [plugins.X] tables — that
|
||||
would silently lose them."""
|
||||
target = tmp_path / "config.toml"
|
||||
target.write_text(
|
||||
'[plugins."tasks@openai-curated"]\n'
|
||||
"enabled = true\n"
|
||||
)
|
||||
|
||||
def fake_query(codex_home=None, timeout=8.0):
|
||||
return ([], "plugin/list query failed: codex not installed")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.codex_runtime_plugin_migration._query_codex_plugins",
|
||||
fake_query,
|
||||
)
|
||||
migrate({}, codex_home=tmp_path, discover_plugins=True, expose_hermes_tools=False)
|
||||
new_text = target.read_text()
|
||||
# User's plugin table preserved verbatim — we can't re-emit it.
|
||||
assert '[plugins."tasks@openai-curated"]' in new_text
|
||||
|
||||
|
||||
# ---- Bug C: HERMES_HOME tempdir leak into ~/.codex/config.toml ----
|
||||
|
||||
|
||||
class TestHermesHomeLeakGuard:
|
||||
"""Regression tests for issue #26250 Bug C.
|
||||
|
||||
Previously ``_build_hermes_tools_mcp_entry()`` read ``HERMES_HOME``
|
||||
directly from ``os.environ``, so a pytest ``monkeypatch.setenv`` would
|
||||
leak a transient tempdir path into the user's real ``~/.codex/config.toml``
|
||||
once codex spawned the hermes-tools MCP subprocess.
|
||||
"""
|
||||
|
||||
def test_tempdir_detector_recognizes_pytest_paths(self):
|
||||
assert _looks_like_test_tempdir(
|
||||
"/private/var/folders/abc/pytest-of-kshitij/pytest-137/popen-gw2/test_X/hermes_test"
|
||||
)
|
||||
assert _looks_like_test_tempdir(
|
||||
"/tmp/pytest-of-user/pytest-12/test_X/hermes"
|
||||
)
|
||||
assert _looks_like_test_tempdir(
|
||||
"/private/var/folders/zz/T/pytest-of-bob/pytest-1"
|
||||
)
|
||||
|
||||
def test_tempdir_detector_accepts_real_hermes_home(self):
|
||||
assert not _looks_like_test_tempdir("/Users/alice/.hermes")
|
||||
assert not _looks_like_test_tempdir("/home/bob/.hermes")
|
||||
assert not _looks_like_test_tempdir("/opt/hermes")
|
||||
assert not _looks_like_test_tempdir("")
|
||||
|
||||
def test_pytest_tempdir_not_burned_into_mcp_env(self, monkeypatch):
|
||||
"""The headline regression: even when HERMES_HOME points at a pytest
|
||||
tempdir, _build_hermes_tools_mcp_entry() must NOT propagate it."""
|
||||
monkeypatch.setenv(
|
||||
"HERMES_HOME",
|
||||
"/private/var/folders/xx/pytest-of-user/pytest-99/test_x/hermes_test",
|
||||
)
|
||||
entry = _build_hermes_tools_mcp_entry()
|
||||
env = entry.get("env", {})
|
||||
assert "HERMES_HOME" not in env, (
|
||||
f"pytest-tempdir HERMES_HOME leaked into codex MCP entry: "
|
||||
f"{env.get('HERMES_HOME')!r}"
|
||||
)
|
||||
|
||||
def test_real_hermes_home_propagates(self, monkeypatch, tmp_path):
|
||||
"""A legitimate HERMES_HOME (not a tempdir path) DOES propagate so the
|
||||
MCP subprocess sees the same config as the parent CLI."""
|
||||
# Use a path that looks real — under /Users or /home, not /var/folders.
|
||||
# We can't easily create one in the test, so just use a stable path
|
||||
# outside any tempdir-detector needle. The detector checks for tempdir
|
||||
# markers, not for path existence.
|
||||
real_path = "/Users/alice/.hermes"
|
||||
monkeypatch.setenv("HERMES_HOME", real_path)
|
||||
entry = _build_hermes_tools_mcp_entry()
|
||||
env = entry.get("env", {})
|
||||
assert env.get("HERMES_HOME") == real_path
|
||||
|
||||
def test_unset_hermes_home_omits_env_key(self, monkeypatch):
|
||||
"""When HERMES_HOME is unset in the environment, the MCP entry MUST
|
||||
NOT bake in a resolved-default path. The codex subprocess should
|
||||
inherit whatever HERMES_HOME its launcher (systemd, gateway, shell)
|
||||
sets at runtime, rather than being pinned to migrate-time defaults.
|
||||
Regression guard for issue #26250 follow-up review."""
|
||||
monkeypatch.delenv("HERMES_HOME", raising=False)
|
||||
entry = _build_hermes_tools_mcp_entry()
|
||||
env = entry.get("env", {})
|
||||
assert "HERMES_HOME" not in env, (
|
||||
f"HERMES_HOME should not be set when env var is unset, got: "
|
||||
f"{env.get('HERMES_HOME')!r}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -114,8 +114,15 @@ class TestApply:
|
|||
def persist(c):
|
||||
persisted.update(c)
|
||||
|
||||
# Patch migrate so this test doesn't reach into the user's real
|
||||
# ~/.codex/config.toml. See issue #26250 Bug C — without this patch,
|
||||
# crs.apply() invokes the real migrate() which writes to
|
||||
# Path.home() / ".codex" using whatever HERMES_HOME the running pytest
|
||||
# session has set, leaking pytest tempdir paths into the user's
|
||||
# codex config.
|
||||
with patch.object(crs, "check_codex_binary_ok",
|
||||
return_value=(True, "0.130.0")):
|
||||
return_value=(True, "0.130.0")), \
|
||||
patch("hermes_cli.codex_runtime_plugin_migration.migrate"):
|
||||
r = crs.apply(cfg, "codex_app_server", persist_callback=persist)
|
||||
assert r.success
|
||||
assert r.new_value == "codex_app_server"
|
||||
|
|
|
|||
43
tests/hermes_cli/test_dep_ensure.py
Normal file
43
tests/hermes_cli/test_dep_ensure.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def test_ensure_dependency_skips_when_present():
|
||||
"""ensure_dependency is a no-op when the dep is already available."""
|
||||
from hermes_cli.dep_ensure import ensure_dependency
|
||||
with patch("hermes_cli.dep_ensure.shutil") as mock_shutil:
|
||||
mock_shutil.which.return_value = "/usr/bin/node"
|
||||
result = ensure_dependency("node", interactive=False)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_ensure_dependency_returns_false_when_missing_noninteractive():
|
||||
"""ensure_dependency returns False for missing dep in non-interactive mode."""
|
||||
from hermes_cli.dep_ensure import ensure_dependency
|
||||
with patch("hermes_cli.dep_ensure.shutil") as mock_shutil:
|
||||
mock_shutil.which.return_value = None
|
||||
with patch("hermes_cli.dep_ensure._find_install_script", return_value=None):
|
||||
result = ensure_dependency("node", interactive=False)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_find_install_script_from_checkout(tmp_path):
|
||||
"""_find_install_script finds scripts/install.sh in a git checkout."""
|
||||
from hermes_cli.dep_ensure import _find_install_script
|
||||
scripts_dir = tmp_path / "scripts"
|
||||
scripts_dir.mkdir()
|
||||
(scripts_dir / "install.sh").write_text("#!/bin/bash", encoding="utf-8")
|
||||
result = _find_install_script(package_dir=tmp_path / "hermes_cli", repo_root=tmp_path)
|
||||
assert result is not None
|
||||
assert result.name == "install.sh"
|
||||
|
||||
|
||||
def test_find_install_script_from_wheel(tmp_path):
|
||||
"""_find_install_script finds bundled install.sh in a wheel."""
|
||||
from hermes_cli.dep_ensure import _find_install_script
|
||||
bundled = tmp_path / "hermes_cli" / "scripts"
|
||||
bundled.mkdir(parents=True)
|
||||
(bundled / "install.sh").write_text("#!/bin/bash", encoding="utf-8")
|
||||
result = _find_install_script(package_dir=tmp_path / "hermes_cli", repo_root=tmp_path)
|
||||
assert result is not None
|
||||
assert result.name == "install.sh"
|
||||
|
|
@ -839,3 +839,108 @@ class TestGitHubTokenCheck:
|
|||
|
||||
assert "gh auth" in str(call_log) or any(c[0] == "gh" for c in call_log), f"gh not called: {call_log}"
|
||||
assert "GitHub authenticated via gh CLI" in out or "token configured" in out
|
||||
|
||||
|
||||
def _run_doctor_with_healthy_oauth_fallback(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
*,
|
||||
env_key: str,
|
||||
bad_key: str,
|
||||
failing_host: str,
|
||||
gemini_oauth_status: dict,
|
||||
minimax_oauth_status: dict,
|
||||
) -> str:
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir(parents=True, exist_ok=True)
|
||||
(home / "config.yaml").write_text(
|
||||
"model:\n"
|
||||
" provider: nous\n"
|
||||
" default: moonshotai/kimi-k2.6\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
project = tmp_path / "project"
|
||||
project.mkdir(exist_ok=True)
|
||||
|
||||
monkeypatch.setattr(doctor_mod, "HERMES_HOME", home)
|
||||
monkeypatch.setattr(doctor_mod, "PROJECT_ROOT", project)
|
||||
monkeypatch.setattr(doctor_mod, "_DHH", str(home))
|
||||
monkeypatch.setenv(env_key, bad_key)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
|
||||
monkeypatch.delenv("MINIMAX_API_KEY", raising=False)
|
||||
monkeypatch.delenv("MINIMAX_CN_API_KEY", raising=False)
|
||||
monkeypatch.setenv(env_key, bad_key)
|
||||
|
||||
fake_model_tools = types.SimpleNamespace(
|
||||
check_tool_availability=lambda *a, **kw: ([], []),
|
||||
TOOLSET_REQUIREMENTS={},
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "model_tools", fake_model_tools)
|
||||
|
||||
from hermes_cli import auth as _auth_mod
|
||||
|
||||
monkeypatch.setattr(_auth_mod, "get_nous_auth_status", lambda: {"logged_in": True})
|
||||
monkeypatch.setattr(_auth_mod, "get_codex_auth_status", lambda: {})
|
||||
monkeypatch.setattr(_auth_mod, "get_gemini_oauth_auth_status", lambda: gemini_oauth_status)
|
||||
monkeypatch.setattr(_auth_mod, "get_minimax_oauth_auth_status", lambda: minimax_oauth_status)
|
||||
|
||||
def fake_get(url, headers=None, timeout=None):
|
||||
status = 401 if failing_host in url else 200
|
||||
return types.SimpleNamespace(status_code=status)
|
||||
|
||||
import httpx
|
||||
|
||||
monkeypatch.setattr(httpx, "get", fake_get)
|
||||
|
||||
buf = io.StringIO()
|
||||
with contextlib.redirect_stdout(buf):
|
||||
doctor_mod.run_doctor(Namespace(fix=False))
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("env_key", "bad_key", "failing_host", "gemini_oauth_status", "minimax_oauth_status", "unexpected_issue"),
|
||||
[
|
||||
(
|
||||
"GOOGLE_API_KEY",
|
||||
"bad-gemini-key",
|
||||
"googleapis.com",
|
||||
{"logged_in": True, "email": "user@example.com"},
|
||||
{},
|
||||
"Check GOOGLE_API_KEY in .env",
|
||||
),
|
||||
(
|
||||
"MINIMAX_API_KEY",
|
||||
"bad-minimax-key",
|
||||
"minimax.io",
|
||||
{},
|
||||
{"logged_in": True, "region": "global"},
|
||||
"Check MINIMAX_API_KEY in .env",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_run_doctor_ignores_invalid_direct_keys_when_oauth_fallback_is_healthy(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
env_key,
|
||||
bad_key,
|
||||
failing_host,
|
||||
gemini_oauth_status,
|
||||
minimax_oauth_status,
|
||||
unexpected_issue,
|
||||
):
|
||||
out = _run_doctor_with_healthy_oauth_fallback(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
env_key=env_key,
|
||||
bad_key=bad_key,
|
||||
failing_host=failing_host,
|
||||
gemini_oauth_status=gemini_oauth_status,
|
||||
minimax_oauth_status=minimax_oauth_status,
|
||||
)
|
||||
|
||||
assert "invalid API key" in out
|
||||
assert unexpected_issue not in out
|
||||
|
|
|
|||
|
|
@ -559,3 +559,9 @@ class TestStopProfileGateway:
|
|||
assert calls["kill"] == 1 # one SIGTERM
|
||||
assert calls["alive_probes"] == 20 # 20 liveness polls over the 2s window
|
||||
assert calls["remove"] == 0
|
||||
|
||||
|
||||
def test_module_has_logger():
|
||||
"""Verify module has a logger instance (regression guard for #27154)."""
|
||||
assert hasattr(gateway, "logger")
|
||||
assert gateway.logger.name == "hermes_cli.gateway"
|
||||
|
|
|
|||
31
tests/hermes_cli/test_gateway_service_paths.py
Normal file
31
tests/hermes_cli/test_gateway_service_paths.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def test_service_path_skips_nonexistent_node_modules(tmp_path):
|
||||
"""Service PATH should not include node_modules/.bin if it doesn't exist."""
|
||||
from hermes_cli.gateway import _build_service_path_dirs
|
||||
with patch("hermes_cli.gateway.get_hermes_home", return_value=tmp_path / ".hermes"):
|
||||
dirs = _build_service_path_dirs(project_root=tmp_path)
|
||||
node_modules_bin = str(tmp_path / "node_modules" / ".bin")
|
||||
assert node_modules_bin not in dirs
|
||||
|
||||
|
||||
def test_service_path_includes_node_modules_when_present(tmp_path):
|
||||
"""Service PATH should include node_modules/.bin when it exists."""
|
||||
nm_bin = tmp_path / "node_modules" / ".bin"
|
||||
nm_bin.mkdir(parents=True)
|
||||
from hermes_cli.gateway import _build_service_path_dirs
|
||||
with patch("hermes_cli.gateway.get_hermes_home", return_value=tmp_path / ".hermes"):
|
||||
dirs = _build_service_path_dirs(project_root=tmp_path)
|
||||
assert str(nm_bin) in dirs
|
||||
|
||||
|
||||
def test_service_path_includes_hermes_home_node_modules(tmp_path):
|
||||
"""Service PATH should include ~/.hermes/node_modules/.bin when it exists."""
|
||||
hermes_nm = tmp_path / ".hermes" / "node_modules" / ".bin"
|
||||
hermes_nm.mkdir(parents=True)
|
||||
from hermes_cli.gateway import _build_service_path_dirs
|
||||
with patch("hermes_cli.gateway.get_hermes_home", return_value=tmp_path / ".hermes"):
|
||||
dirs = _build_service_path_dirs(project_root=tmp_path)
|
||||
assert str(hermes_nm) in dirs
|
||||
|
|
@ -103,6 +103,33 @@ class TestPluginPickerInjection:
|
|||
visible = tools_config._visible_providers(browser, {})
|
||||
assert all(p.get("image_gen_plugin_name") is None for p in visible)
|
||||
|
||||
def test_post_setup_propagated_when_declared(self, monkeypatch):
|
||||
from hermes_cli import tools_config
|
||||
|
||||
image_gen_registry.register_provider(_FakeProvider(
|
||||
"xai_img",
|
||||
schema={
|
||||
"name": "xAI Grok Imagine",
|
||||
"badge": "paid",
|
||||
"tag": "grok image",
|
||||
"env_vars": [],
|
||||
"post_setup": "xai_grok",
|
||||
},
|
||||
))
|
||||
|
||||
rows = tools_config._plugin_image_gen_providers()
|
||||
match = next(r for r in rows if r.get("image_gen_plugin_name") == "xai_img")
|
||||
assert match["post_setup"] == "xai_grok"
|
||||
|
||||
def test_post_setup_omitted_when_not_declared(self, monkeypatch):
|
||||
from hermes_cli import tools_config
|
||||
|
||||
image_gen_registry.register_provider(_FakeProvider("plain_img"))
|
||||
|
||||
rows = tools_config._plugin_image_gen_providers()
|
||||
match = next(r for r in rows if r.get("image_gen_plugin_name") == "plain_img")
|
||||
assert "post_setup" not in match
|
||||
|
||||
|
||||
class TestPluginCatalog:
|
||||
def test_plugin_catalog_returns_models(self):
|
||||
|
|
|
|||
|
|
@ -29,7 +29,8 @@ def test_format_managed_message_homebrew(monkeypatch):
|
|||
def test_recommended_update_command_defaults_to_hermes_update(monkeypatch):
|
||||
monkeypatch.delenv("HERMES_MANAGED", raising=False)
|
||||
|
||||
assert recommended_update_command() == "hermes update"
|
||||
with patch("hermes_cli.config.detect_install_method", return_value="git"):
|
||||
assert recommended_update_command() == "hermes update"
|
||||
|
||||
|
||||
def test_cmd_update_blocks_managed_homebrew(monkeypatch, capsys):
|
||||
|
|
|
|||
37
tests/hermes_cli/test_pip_install_detection.py
Normal file
37
tests/hermes_cli/test_pip_install_detection.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def test_pip_install_detected_when_no_git_dir(tmp_path):
|
||||
"""When PROJECT_ROOT has no .git, detect as pip install."""
|
||||
with patch("hermes_cli.config.get_managed_system", return_value=None):
|
||||
from hermes_cli.config import detect_install_method
|
||||
method = detect_install_method(project_root=tmp_path)
|
||||
assert method == "pip"
|
||||
|
||||
|
||||
def test_git_install_detected_when_git_dir_exists(tmp_path):
|
||||
"""When PROJECT_ROOT has .git, detect as git install."""
|
||||
(tmp_path / ".git").mkdir()
|
||||
with patch("hermes_cli.config.get_managed_system", return_value=None):
|
||||
from hermes_cli.config import detect_install_method
|
||||
method = detect_install_method(project_root=tmp_path)
|
||||
assert method == "git"
|
||||
|
||||
|
||||
def test_managed_install_takes_precedence(tmp_path):
|
||||
"""When HERMES_MANAGED is set, that takes precedence over git detection."""
|
||||
(tmp_path / ".git").mkdir()
|
||||
with patch("hermes_cli.config.get_managed_system", return_value="NixOS"):
|
||||
from hermes_cli.config import detect_install_method
|
||||
method = detect_install_method(project_root=tmp_path)
|
||||
assert method == "nixos"
|
||||
|
||||
|
||||
def test_recommended_update_command_pip():
|
||||
"""Pip installs recommend pip install --upgrade."""
|
||||
from hermes_cli.config import recommended_update_command_for_method
|
||||
cmd = recommended_update_command_for_method("pip")
|
||||
assert "pip install" in cmd or "uv pip install" in cmd
|
||||
assert "--upgrade" in cmd
|
||||
assert "hermes-agent" in cmd
|
||||
|
|
@ -662,6 +662,129 @@ class TestPluginContext:
|
|||
from tools.registry import registry
|
||||
assert "plugin_echo" in registry._tools
|
||||
|
||||
def test_register_tool_rejects_shadow_without_override(self, tmp_path, monkeypatch, caplog):
|
||||
"""Without override=True, registering a tool name claimed by a different toolset is rejected."""
|
||||
from tools.registry import registry
|
||||
|
||||
# Seed an existing entry from a non-plugin toolset.
|
||||
registry.register(
|
||||
name="shadow_target",
|
||||
toolset="terminal",
|
||||
schema={"name": "shadow_target", "description": "Built-in", "parameters": {"type": "object", "properties": {}}},
|
||||
handler=lambda args, **kw: "built-in",
|
||||
)
|
||||
original_handler = registry._tools["shadow_target"].handler
|
||||
try:
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
plugin_dir = plugins_dir / "shadow_plugin"
|
||||
plugin_dir.mkdir(parents=True)
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump({"name": "shadow_plugin"}))
|
||||
(plugin_dir / "__init__.py").write_text(
|
||||
'def register(ctx):\n'
|
||||
' ctx.register_tool(\n'
|
||||
' name="shadow_target",\n'
|
||||
' toolset="plugin_shadow_plugin",\n'
|
||||
' schema={"name": "shadow_target", "description": "Plugin", "parameters": {"type": "object", "properties": {}}},\n'
|
||||
' handler=lambda args, **kw: "plugin",\n'
|
||||
' )\n'
|
||||
)
|
||||
hermes_home = tmp_path / "hermes_test"
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
yaml.safe_dump({"plugins": {"enabled": ["shadow_plugin"]}})
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
with caplog.at_level(logging.ERROR, logger="tools.registry"):
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
# Original handler must still be in place — registration was rejected.
|
||||
assert registry._tools["shadow_target"].handler is original_handler
|
||||
assert registry._tools["shadow_target"].toolset == "terminal"
|
||||
# And an ERROR was logged explaining why and how to opt in.
|
||||
assert any("override=True" in r.message for r in caplog.records)
|
||||
finally:
|
||||
registry.deregister("shadow_target")
|
||||
|
||||
def test_register_tool_override_replaces_existing(self, tmp_path, monkeypatch, caplog):
|
||||
"""override=True lets a plugin replace an existing built-in tool."""
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name="override_target",
|
||||
toolset="terminal",
|
||||
schema={"name": "override_target", "description": "Built-in", "parameters": {"type": "object", "properties": {}}},
|
||||
handler=lambda args, **kw: "built-in",
|
||||
)
|
||||
try:
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
plugin_dir = plugins_dir / "override_plugin"
|
||||
plugin_dir.mkdir(parents=True)
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump({"name": "override_plugin"}))
|
||||
(plugin_dir / "__init__.py").write_text(
|
||||
'def register(ctx):\n'
|
||||
' ctx.register_tool(\n'
|
||||
' name="override_target",\n'
|
||||
' toolset="plugin_override_plugin",\n'
|
||||
' schema={"name": "override_target", "description": "Plugin", "parameters": {"type": "object", "properties": {}}},\n'
|
||||
' handler=lambda args, **kw: "plugin",\n'
|
||||
' override=True,\n'
|
||||
' )\n'
|
||||
)
|
||||
hermes_home = tmp_path / "hermes_test"
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
yaml.safe_dump({"plugins": {"enabled": ["override_plugin"]}})
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
with caplog.at_level(logging.INFO, logger="tools.registry"):
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
# Plugin handler replaced the built-in one.
|
||||
assert registry._tools["override_target"].toolset == "plugin_override_plugin"
|
||||
assert registry._tools["override_target"].handler({}, ) == "plugin"
|
||||
# Override is audit-logged at INFO.
|
||||
assert any(
|
||||
"overriding existing" in r.message and "override_target" in r.message
|
||||
for r in caplog.records
|
||||
)
|
||||
# Plugin tracks it.
|
||||
assert "override_target" in mgr._plugin_tool_names
|
||||
finally:
|
||||
registry.deregister("override_target")
|
||||
|
||||
def test_register_tool_override_on_new_name_is_noop_path(self, tmp_path, monkeypatch):
|
||||
"""override=True on a brand-new name still registers cleanly (no existing entry to replace)."""
|
||||
from tools.registry import registry
|
||||
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
plugin_dir = plugins_dir / "new_override_plugin"
|
||||
plugin_dir.mkdir(parents=True)
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump({"name": "new_override_plugin"}))
|
||||
(plugin_dir / "__init__.py").write_text(
|
||||
'def register(ctx):\n'
|
||||
' ctx.register_tool(\n'
|
||||
' name="brand_new_override_tool",\n'
|
||||
' toolset="plugin_new_override_plugin",\n'
|
||||
' schema={"name": "brand_new_override_tool", "description": "New", "parameters": {"type": "object", "properties": {}}},\n'
|
||||
' handler=lambda args, **kw: "ok",\n'
|
||||
' override=True,\n'
|
||||
' )\n'
|
||||
)
|
||||
hermes_home = tmp_path / "hermes_test"
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
yaml.safe_dump({"plugins": {"enabled": ["new_override_plugin"]}})
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
try:
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
assert "brand_new_override_tool" in registry._tools
|
||||
finally:
|
||||
registry.deregister("brand_new_override_tool")
|
||||
|
||||
|
||||
# ── TestPluginToolVisibility ───────────────────────────────────────────────
|
||||
|
||||
|
|
|
|||
|
|
@ -396,6 +396,117 @@ class TestCmdList:
|
|||
cmd_list()
|
||||
|
||||
|
||||
# ── _discover_all_plugins tests ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDiscoverAllPlugins:
|
||||
"""Exercise the recursive scan that powers ``hermes plugins list``.
|
||||
|
||||
Mirrors the layouts the runtime loader handles
|
||||
(:meth:`PluginManager._scan_directory_level`): flat plugins at the root,
|
||||
category-namespaced plugins one level deeper, and user-overrides-bundled
|
||||
on key collision.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _write_plugin(root: Path, segments: list, manifest_name: str = None) -> None:
|
||||
plugin_dir = root
|
||||
for seg in segments:
|
||||
plugin_dir = plugin_dir / seg
|
||||
plugin_dir.mkdir(parents=True, exist_ok=True)
|
||||
manifest = {
|
||||
"name": manifest_name or segments[-1],
|
||||
"version": "0.1.0",
|
||||
"description": f"Test plugin {'/'.join(segments)}",
|
||||
}
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump(manifest))
|
||||
|
||||
def _entries_by_key(self, tmp_path, monkeypatch) -> dict:
|
||||
from hermes_cli import plugins_cmd
|
||||
bundled = tmp_path / "bundled"
|
||||
user = tmp_path / "user"
|
||||
bundled.mkdir()
|
||||
user.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.plugins.get_bundled_plugins_dir", lambda: bundled
|
||||
)
|
||||
monkeypatch.setattr(plugins_cmd, "_plugins_dir", lambda: user)
|
||||
return bundled, user, lambda: {
|
||||
e[0]: e for e in plugins_cmd._discover_all_plugins()
|
||||
}
|
||||
|
||||
def test_flat_plugin_uses_manifest_name_as_key(self, tmp_path, monkeypatch):
|
||||
bundled, _, discover = self._entries_by_key(tmp_path, monkeypatch)
|
||||
self._write_plugin(bundled, ["disk-cleanup"])
|
||||
|
||||
entries = discover()
|
||||
assert "disk-cleanup" in entries
|
||||
assert entries["disk-cleanup"][3] == "bundled"
|
||||
|
||||
def test_category_namespaced_plugin_uses_path_derived_key(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
"""Regression test for the original bug — ``observability/langfuse``
|
||||
and ``image_gen/openai`` must surface under their path-derived key,
|
||||
not vanish because the category directory has no ``plugin.yaml``."""
|
||||
bundled, _, discover = self._entries_by_key(tmp_path, monkeypatch)
|
||||
# langfuse's real manifest declares ``name: langfuse`` (bare), but it
|
||||
# lives under ``observability/`` — the key must reflect the path.
|
||||
self._write_plugin(
|
||||
bundled, ["observability", "langfuse"], manifest_name="langfuse"
|
||||
)
|
||||
self._write_plugin(bundled, ["image_gen", "openai"])
|
||||
|
||||
entries = discover()
|
||||
assert "observability/langfuse" in entries
|
||||
assert "image_gen/openai" in entries
|
||||
# Bare manifest name must NOT leak through as a top-level key.
|
||||
assert "langfuse" not in entries
|
||||
assert "openai" not in entries
|
||||
|
||||
def test_user_overrides_bundled_on_key_collision(self, tmp_path, monkeypatch):
|
||||
bundled, user, discover = self._entries_by_key(tmp_path, monkeypatch)
|
||||
self._write_plugin(bundled, ["observability", "langfuse"])
|
||||
self._write_plugin(user, ["observability", "langfuse"])
|
||||
|
||||
entries = discover()
|
||||
assert entries["observability/langfuse"][3] == "user"
|
||||
|
||||
def test_depth_cap_skips_third_level(self, tmp_path, monkeypatch):
|
||||
"""Anything deeper than ``<root>/<category>/<plugin>/`` is ignored,
|
||||
matching the loader's depth cap."""
|
||||
bundled, _, discover = self._entries_by_key(tmp_path, monkeypatch)
|
||||
# plugins/a/b/c/plugin.yaml — too deep, must NOT be discovered.
|
||||
self._write_plugin(bundled, ["a", "b", "c"])
|
||||
|
||||
entries = discover()
|
||||
assert not any(k.startswith("a/") for k in entries), entries
|
||||
|
||||
def test_bundled_memory_and_context_engine_skipped(self, tmp_path, monkeypatch):
|
||||
"""``plugins/memory/`` and ``plugins/context_engine/`` use their own
|
||||
loaders; bundled entries inside them must not appear in the general
|
||||
list (matches the pre-refactor skip set)."""
|
||||
bundled, _, discover = self._entries_by_key(tmp_path, monkeypatch)
|
||||
self._write_plugin(bundled, ["memory", "honcho"])
|
||||
self._write_plugin(bundled, ["context_engine", "compressor"])
|
||||
self._write_plugin(bundled, ["observability", "langfuse"])
|
||||
|
||||
entries = discover()
|
||||
assert "memory/honcho" not in entries
|
||||
assert "context_engine/compressor" not in entries
|
||||
assert "observability/langfuse" in entries
|
||||
|
||||
def test_user_memory_subdir_is_still_scanned(self, tmp_path, monkeypatch):
|
||||
"""The memory/context_engine skip only applies to *bundled* — a user
|
||||
plugin at ``~/.hermes/plugins/memory/<x>/`` should still be discovered
|
||||
so the user can see what they installed."""
|
||||
bundled, user, discover = self._entries_by_key(tmp_path, monkeypatch)
|
||||
self._write_plugin(user, ["memory", "my-custom-store"])
|
||||
|
||||
entries = discover()
|
||||
assert "memory/my-custom-store" in entries
|
||||
|
||||
|
||||
# ── _copy_example_files tests ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
|
|
|||
400
tests/hermes_cli/test_send_cmd.py
Normal file
400
tests/hermes_cli/test_send_cmd.py
Normal file
|
|
@ -0,0 +1,400 @@
|
|||
"""Tests for the ``hermes send`` CLI subcommand.
|
||||
|
||||
Covers the argument parsing / stdin / file / list behavior of
|
||||
``hermes_cli.send_cmd``. The underlying ``send_message_tool`` is stubbed so
|
||||
no network I/O or gateway is required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli import send_cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse(argv):
|
||||
"""Build the top-level parser and return the parsed args for ``argv``."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(prog="hermes")
|
||||
subparsers = parser.add_subparsers(dest="command")
|
||||
send_cmd.register_send_subparser(subparsers)
|
||||
return parser.parse_args(["send", *argv])
|
||||
|
||||
|
||||
class _FakeTool:
|
||||
"""Replacement for ``tools.send_message_tool.send_message_tool``."""
|
||||
|
||||
def __init__(self, payload):
|
||||
self.payload = payload
|
||||
self.calls = []
|
||||
|
||||
def __call__(self, args, **_kw):
|
||||
self.calls.append(dict(args))
|
||||
return json.dumps(self.payload)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_tool(monkeypatch):
|
||||
"""Install a fake send_message_tool and return the stub for inspection."""
|
||||
import sys
|
||||
import types
|
||||
|
||||
fake = _FakeTool({"success": True, "message_id": "m123"})
|
||||
|
||||
mod = types.ModuleType("tools.send_message_tool")
|
||||
mod.send_message_tool = fake
|
||||
# Register the stub so ``from tools.send_message_tool import ...`` inside
|
||||
# cmd_send resolves to our fake. Also patch the parent ``tools`` package
|
||||
# entry so attribute lookup works.
|
||||
monkeypatch.setitem(sys.modules, "tools.send_message_tool", mod)
|
||||
return fake
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_positional_message_success(fake_tool, capsys):
|
||||
args = _parse(["--to", "telegram", "hello world"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 0
|
||||
assert fake_tool.calls == [
|
||||
{"action": "send", "target": "telegram", "message": "hello world"}
|
||||
]
|
||||
out = capsys.readouterr()
|
||||
assert "sent" in out.out or out.out == "" # "sent" is the default success banner
|
||||
|
||||
|
||||
def test_stdin_message(fake_tool, monkeypatch, capsys):
|
||||
# Piped stdin (not a tty) should be consumed as the message body.
|
||||
monkeypatch.setattr("sys.stdin", io.StringIO("piped body\n"))
|
||||
# Force isatty to return False so the CLI reads from stdin.
|
||||
monkeypatch.setattr("sys.stdin.isatty", lambda: False)
|
||||
args = _parse(["--to", "discord:#ops"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 0
|
||||
assert fake_tool.calls[0]["message"] == "piped body\n"
|
||||
assert fake_tool.calls[0]["target"] == "discord:#ops"
|
||||
|
||||
|
||||
def test_file_message(fake_tool, tmp_path):
|
||||
body = tmp_path / "msg.txt"
|
||||
body.write_text("from a file\n")
|
||||
args = _parse(["--to", "slack:#eng", "--file", str(body)])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 0
|
||||
assert fake_tool.calls[0]["message"] == "from a file\n"
|
||||
|
||||
|
||||
def test_file_dash_means_stdin(fake_tool, monkeypatch):
|
||||
monkeypatch.setattr("sys.stdin", io.StringIO("dash body"))
|
||||
args = _parse(["--to", "telegram", "--file", "-"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 0
|
||||
assert fake_tool.calls[0]["message"] == "dash body"
|
||||
|
||||
|
||||
def test_subject_prepends_header(fake_tool):
|
||||
args = _parse(["--to", "telegram", "--subject", "[CI]", "body text"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 0
|
||||
assert fake_tool.calls[0]["message"] == "[CI]\n\nbody text"
|
||||
|
||||
|
||||
def test_json_mode_emits_payload(fake_tool, capsys):
|
||||
args = _parse(["--to", "telegram", "--json", "hi"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 0
|
||||
out = capsys.readouterr().out
|
||||
payload = json.loads(out)
|
||||
assert payload.get("success") is True
|
||||
assert payload.get("message_id") == "m123"
|
||||
|
||||
|
||||
def test_quiet_suppresses_stdout(fake_tool, capsys):
|
||||
args = _parse(["--to", "telegram", "--quiet", "shh"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 0
|
||||
out = capsys.readouterr()
|
||||
assert out.out == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_missing_target(fake_tool, capsys, monkeypatch):
|
||||
# Ensure stdin is a tty so the CLI does not try to consume it as a body.
|
||||
monkeypatch.setattr("sys.stdin.isatty", lambda: True)
|
||||
args = _parse(["hello"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 2
|
||||
err = capsys.readouterr().err
|
||||
assert "--to" in err
|
||||
|
||||
|
||||
def test_missing_message(fake_tool, capsys, monkeypatch):
|
||||
monkeypatch.setattr("sys.stdin.isatty", lambda: True)
|
||||
args = _parse(["--to", "telegram"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 2
|
||||
err = capsys.readouterr().err
|
||||
assert "no message" in err.lower()
|
||||
|
||||
|
||||
def test_file_not_found_is_usage_error(fake_tool, capsys, monkeypatch):
|
||||
monkeypatch.setattr("sys.stdin.isatty", lambda: True)
|
||||
args = _parse(["--to", "telegram", "--file", "/nonexistent/does-not-exist.txt"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 2
|
||||
err = capsys.readouterr().err
|
||||
assert "cannot read" in err.lower()
|
||||
|
||||
|
||||
def test_file_decode_error_is_usage_error(fake_tool, capsys, monkeypatch, tmp_path):
|
||||
monkeypatch.setattr("sys.stdin.isatty", lambda: True)
|
||||
bad = tmp_path / "bad-bytes.bin"
|
||||
bad.write_bytes(b"\xff\xfe\x00")
|
||||
|
||||
args = _parse(["--to", "telegram", "--file", str(bad)])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 2
|
||||
err = capsys.readouterr().err
|
||||
assert "cannot read" in err.lower()
|
||||
|
||||
|
||||
def test_tool_error_returns_failure_exit(monkeypatch, capsys):
|
||||
import sys as _sys
|
||||
import types as _types
|
||||
|
||||
fake_mod = _types.ModuleType("tools.send_message_tool")
|
||||
|
||||
def _bad_tool(args, **_kw):
|
||||
return json.dumps({"error": "platform blew up"})
|
||||
|
||||
fake_mod.send_message_tool = _bad_tool
|
||||
monkeypatch.setitem(_sys.modules, "tools.send_message_tool", fake_mod)
|
||||
|
||||
args = _parse(["--to", "telegram", "nope"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 1
|
||||
err = capsys.readouterr().err
|
||||
assert "platform blew up" in err
|
||||
|
||||
|
||||
def test_skipped_result_is_success(monkeypatch):
|
||||
import sys as _sys
|
||||
import types as _types
|
||||
|
||||
fake_mod = _types.ModuleType("tools.send_message_tool")
|
||||
fake_mod.send_message_tool = lambda args, **_kw: json.dumps(
|
||||
{"success": True, "skipped": True, "reason": "duplicate"}
|
||||
)
|
||||
monkeypatch.setitem(_sys.modules, "tools.send_message_tool", fake_mod)
|
||||
|
||||
args = _parse(["--to", "telegram", "dup"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# --list
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_list_human_output(monkeypatch, capsys):
|
||||
import sys as _sys
|
||||
import types as _types
|
||||
|
||||
fake_dir = _types.ModuleType("gateway.channel_directory")
|
||||
fake_dir.format_directory_for_display = lambda: "Available messaging targets:\n\nTelegram:\n telegram:-100123\n"
|
||||
fake_dir.load_directory = lambda: {
|
||||
"platforms": {"telegram": [{"id": "-100123", "name": "Test Group"}]}
|
||||
}
|
||||
monkeypatch.setitem(_sys.modules, "gateway.channel_directory", fake_dir)
|
||||
|
||||
args = _parse(["--list"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 0
|
||||
out = capsys.readouterr().out
|
||||
assert "Telegram" in out
|
||||
|
||||
|
||||
def test_list_json(monkeypatch, capsys):
|
||||
import sys as _sys
|
||||
import types as _types
|
||||
|
||||
fake_dir = _types.ModuleType("gateway.channel_directory")
|
||||
fake_dir.format_directory_for_display = lambda: "(ignored in json mode)"
|
||||
fake_dir.load_directory = lambda: {
|
||||
"platforms": {"telegram": [{"id": "-100123", "name": "Test Group"}]}
|
||||
}
|
||||
monkeypatch.setitem(_sys.modules, "gateway.channel_directory", fake_dir)
|
||||
|
||||
args = _parse(["--list", "--json"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 0
|
||||
out = capsys.readouterr().out
|
||||
payload = json.loads(out)
|
||||
assert payload["platforms"]["telegram"][0]["name"] == "Test Group"
|
||||
|
||||
|
||||
def test_list_filter_platform(monkeypatch, capsys):
|
||||
import sys as _sys
|
||||
import types as _types
|
||||
|
||||
fake_dir = _types.ModuleType("gateway.channel_directory")
|
||||
fake_dir.format_directory_for_display = lambda: "(should not be called when filter set)"
|
||||
fake_dir.load_directory = lambda: {
|
||||
"platforms": {
|
||||
"telegram": [{"id": "-100123", "name": "TG Chat"}],
|
||||
"discord": [{"id": "555", "name": "bot-home"}],
|
||||
}
|
||||
}
|
||||
monkeypatch.setitem(_sys.modules, "gateway.channel_directory", fake_dir)
|
||||
|
||||
# When --list is set, argparse puts the optional bareword in the
|
||||
# `message` positional slot (where the send-mode body would go).
|
||||
args = _parse(["--list", "telegram"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 0
|
||||
out = capsys.readouterr().out
|
||||
assert "telegram" in out.lower()
|
||||
assert "discord" not in out.lower()
|
||||
|
||||
|
||||
def test_list_unknown_platform_fails(monkeypatch, capsys):
|
||||
import sys as _sys
|
||||
import types as _types
|
||||
|
||||
fake_dir = _types.ModuleType("gateway.channel_directory")
|
||||
fake_dir.format_directory_for_display = lambda: ""
|
||||
fake_dir.load_directory = lambda: {"platforms": {"telegram": []}}
|
||||
monkeypatch.setitem(_sys.modules, "gateway.channel_directory", fake_dir)
|
||||
|
||||
args = _parse(["--list", "pigeon-post"])
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
send_cmd.cmd_send(args)
|
||||
assert exc.value.code == 1
|
||||
err = capsys.readouterr().err
|
||||
assert "pigeon-post" in err
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parser registration contract
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_register_send_subparser_is_reusable():
|
||||
"""Sanity check: the registrar returns a parser and wires ``cmd_send``."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers(dest="command")
|
||||
send_parser = send_cmd.register_send_subparser(subparsers)
|
||||
assert send_parser is not None
|
||||
args = parser.parse_args(["send", "--to", "telegram", "hi"])
|
||||
assert args.func is send_cmd.cmd_send
|
||||
assert args.to == "telegram"
|
||||
assert args.message == "hi"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Env loader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_load_hermes_env_bridges_config_yaml_scalars(tmp_path, monkeypatch):
|
||||
"""Top-level config.yaml scalars should be bridged into os.environ.
|
||||
|
||||
This mirrors the gateway/run.py bootstrap behavior: without this, running
|
||||
``hermes send`` from a fresh shell cannot resolve the home channel
|
||||
because ``TELEGRAM_HOME_CHANNEL`` (saved by ``hermes config set``) lives
|
||||
in config.yaml, not in .env — and the gateway's config loader reads via
|
||||
``os.getenv(...)``.
|
||||
"""
|
||||
import os
|
||||
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / ".env").write_text("SOME_TOKEN=abc123\n")
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
"TELEGRAM_HOME_CHANNEL: '5550001111'\nnested:\n ignored: true\n"
|
||||
)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.delenv("TELEGRAM_HOME_CHANNEL", raising=False)
|
||||
monkeypatch.delenv("SOME_TOKEN", raising=False)
|
||||
|
||||
# Force get_hermes_home() to re-resolve under the patched env.
|
||||
from importlib import reload
|
||||
|
||||
import hermes_cli.config as _hc_config
|
||||
reload(_hc_config)
|
||||
|
||||
send_cmd._load_hermes_env()
|
||||
|
||||
assert os.environ.get("SOME_TOKEN") == "abc123"
|
||||
assert os.environ.get("TELEGRAM_HOME_CHANNEL") == "5550001111"
|
||||
|
||||
|
||||
def test_load_hermes_env_does_not_override_existing(tmp_path, monkeypatch):
|
||||
"""Existing env vars must not be clobbered by config.yaml values."""
|
||||
import os
|
||||
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text("TELEGRAM_HOME_CHANNEL: yaml_value\n")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setenv("TELEGRAM_HOME_CHANNEL", "env_value")
|
||||
|
||||
from importlib import reload
|
||||
import hermes_cli.config as _hc_config
|
||||
reload(_hc_config)
|
||||
|
||||
send_cmd._load_hermes_env()
|
||||
|
||||
assert os.environ.get("TELEGRAM_HOME_CHANNEL") == "env_value"
|
||||
|
||||
|
||||
def test_load_hermes_env_handles_missing_files(tmp_path, monkeypatch):
|
||||
"""No .env or config.yaml should be a silent no-op, not an exception."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
from importlib import reload
|
||||
import hermes_cli.config as _hc_config
|
||||
reload(_hc_config)
|
||||
|
||||
# Should not raise.
|
||||
send_cmd._load_hermes_env()
|
||||
180
tests/hermes_cli/test_session_recap.py
Normal file
180
tests/hermes_cli/test_session_recap.py
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
"""Unit tests for hermes_cli.session_recap."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.session_recap import build_recap
|
||||
|
||||
|
||||
def _user(text):
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
|
||||
def _assistant(text=None, tool_calls=None):
|
||||
msg = {"role": "assistant", "content": text}
|
||||
if tool_calls:
|
||||
msg["tool_calls"] = tool_calls
|
||||
return msg
|
||||
|
||||
|
||||
def _tool_call(name, args):
|
||||
return {
|
||||
"id": f"call_{name}",
|
||||
"type": "function",
|
||||
"function": {"name": name, "arguments": json.dumps(args)},
|
||||
}
|
||||
|
||||
|
||||
def _tool_result(content="ok"):
|
||||
return {"role": "tool", "content": content}
|
||||
|
||||
|
||||
def test_empty_history():
|
||||
out = build_recap([])
|
||||
assert "Session recap" in out
|
||||
assert "nothing to recap" in out
|
||||
|
||||
|
||||
def test_header_shows_title_when_provided():
|
||||
out = build_recap([_user("hello")], session_title="Refactor the adapter")
|
||||
assert "Refactor the adapter" in out.splitlines()[0]
|
||||
|
||||
|
||||
def test_header_shows_short_id_when_no_title():
|
||||
out = build_recap([_user("hello")], session_id="abcdef1234567890")
|
||||
assert "abcdef12" in out.splitlines()[0]
|
||||
|
||||
|
||||
def test_counts_recent_turns():
|
||||
msgs = [
|
||||
_user("one"),
|
||||
_assistant("first reply"),
|
||||
_user("two"),
|
||||
_assistant("second reply"),
|
||||
]
|
||||
out = build_recap(msgs)
|
||||
assert "2 user turn" in out
|
||||
assert "assistant repl" in out
|
||||
|
||||
|
||||
def test_last_ask_and_reply_are_surfaced():
|
||||
msgs = [
|
||||
_user("old question"),
|
||||
_assistant("old answer"),
|
||||
_user("summarise the docs"),
|
||||
_assistant("here is the summary of the docs you asked for"),
|
||||
]
|
||||
out = build_recap(msgs)
|
||||
assert "summarise the docs" in out
|
||||
assert "summary of the docs" in out
|
||||
|
||||
|
||||
def test_tool_counts_and_files():
|
||||
msgs = [
|
||||
_user("edit the readme and run tests"),
|
||||
_assistant(
|
||||
tool_calls=[
|
||||
_tool_call("read_file", {"path": "README.md"}),
|
||||
_tool_call("patch", {"path": "README.md"}),
|
||||
]
|
||||
),
|
||||
_tool_result(),
|
||||
_tool_result(),
|
||||
_assistant(
|
||||
tool_calls=[
|
||||
_tool_call("terminal", {"command": "pytest"}),
|
||||
]
|
||||
),
|
||||
_tool_result("tests ok"),
|
||||
_assistant("All green."),
|
||||
]
|
||||
out = build_recap(msgs)
|
||||
assert "patch×1" in out
|
||||
assert "terminal×1" in out
|
||||
assert "read_file×1" in out
|
||||
# README.md should appear (may include cwd-relative prefix stripping).
|
||||
assert "README.md" in out
|
||||
|
||||
|
||||
def test_tool_preview_length_truncates_long_user_prompt():
|
||||
long = "x " * 500
|
||||
out = build_recap([_user(long)])
|
||||
ask_line = [l for l in out.splitlines() if "Last ask" in l][0]
|
||||
assert len(ask_line) < 300 # truncated with ellipsis
|
||||
assert "…" in ask_line
|
||||
|
||||
|
||||
def test_respects_recent_window():
|
||||
# 30 turns of user+assistant; only the most recent 20 should be summarised.
|
||||
msgs = []
|
||||
for i in range(30):
|
||||
msgs.append(_user(f"question {i}"))
|
||||
msgs.append(_assistant(f"answer {i}"))
|
||||
out = build_recap(msgs)
|
||||
# We scoped to the 20-turn window but show "of 30/30 total".
|
||||
assert "of 30/30 total" in out
|
||||
|
||||
|
||||
def test_multimodal_content_blocks_flattened():
|
||||
msgs = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "check this file"},
|
||||
{"type": "image_url", "image_url": {"url": "..."}},
|
||||
],
|
||||
},
|
||||
_assistant("Looked at your image."),
|
||||
]
|
||||
out = build_recap(msgs)
|
||||
assert "check this file" in out
|
||||
assert "Looked at your image" in out
|
||||
|
||||
|
||||
def test_handles_arguments_as_dict_not_string():
|
||||
# Some providers return arguments already as a dict.
|
||||
msgs = [
|
||||
_user("go"),
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "patch",
|
||||
"arguments": {"path": "foo.py"},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
out = build_recap(msgs)
|
||||
assert "patch×1" in out
|
||||
assert "foo.py" in out
|
||||
|
||||
|
||||
def test_no_assistant_activity_hint():
|
||||
out = build_recap([_user("just sent my first message")])
|
||||
assert "no assistant activity" in out or "Last ask" in out
|
||||
|
||||
|
||||
def test_tool_message_count_reported():
|
||||
msgs = [
|
||||
_user("go"),
|
||||
_assistant(tool_calls=[_tool_call("read_file", {"path": "a"})]),
|
||||
_tool_result(),
|
||||
_tool_result(),
|
||||
_assistant("done"),
|
||||
]
|
||||
out = build_recap(msgs)
|
||||
assert "2 tool result" in out
|
||||
|
||||
|
||||
def test_ignores_non_mapping_entries_gracefully():
|
||||
msgs = [None, "stray", _user("hi"), _assistant("hello")]
|
||||
# Should not raise.
|
||||
out = build_recap(msgs)
|
||||
assert "Session recap" in out
|
||||
21
tests/hermes_cli/test_tui_bundled.py
Normal file
21
tests/hermes_cli/test_tui_bundled.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
def test_tui_finds_bundled_entry_js(tmp_path):
|
||||
"""_find_bundled_tui finds entry.js bundled in the package."""
|
||||
tui_dist = tmp_path / "hermes_cli" / "tui_dist"
|
||||
tui_dist.mkdir(parents=True)
|
||||
entry = tui_dist / "entry.js"
|
||||
entry.write_text("// bundled TUI", encoding="utf-8")
|
||||
|
||||
from hermes_cli.main import _find_bundled_tui
|
||||
result = _find_bundled_tui(hermes_cli_dir=tmp_path / "hermes_cli")
|
||||
assert result is not None
|
||||
assert result.name == "entry.js"
|
||||
|
||||
|
||||
def test_tui_returns_none_when_no_bundle(tmp_path):
|
||||
"""_find_bundled_tui returns None when no bundle exists."""
|
||||
from hermes_cli.main import _find_bundled_tui
|
||||
result = _find_bundled_tui(hermes_cli_dir=tmp_path / "hermes_cli")
|
||||
assert result is None
|
||||
|
|
@ -523,6 +523,34 @@ def test_launch_tui_exports_model_provider_and_toolsets(monkeypatch, main_mod):
|
|||
assert env["NODE_ENV"] == "production"
|
||||
|
||||
|
||||
def test_make_tui_argv_dev_prebuilds_hermes_ink(monkeypatch, main_mod, tmp_path):
|
||||
tui_dir = tmp_path / "ui-tui"
|
||||
tsx = tui_dir / "node_modules" / ".bin" / "tsx"
|
||||
ink_dir = tui_dir / "packages" / "hermes-ink"
|
||||
tsx.parent.mkdir(parents=True)
|
||||
ink_dir.mkdir(parents=True)
|
||||
tsx.write_text("#!/usr/bin/env node\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(main_mod, "_ensure_tui_node", lambda: None)
|
||||
monkeypatch.setattr(main_mod, "_tui_need_npm_install", lambda _tui_dir: False)
|
||||
monkeypatch.delenv("HERMES_TUI_DIR", raising=False)
|
||||
monkeypatch.setattr(main_mod.shutil, "which", lambda bin_name: f"/usr/bin/{bin_name}")
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_run(cmd, cwd=None, **_kwargs):
|
||||
calls.append((cmd, cwd))
|
||||
return types.SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
|
||||
monkeypatch.setattr(main_mod.subprocess, "run", fake_run)
|
||||
|
||||
argv, cwd = main_mod._make_tui_argv(tui_dir, tui_dev=True)
|
||||
|
||||
assert argv == [str(tsx), "src/entry.tsx"]
|
||||
assert cwd == tui_dir
|
||||
assert calls == [(["/usr/bin/npm", "run", "build"], str(ink_dir))]
|
||||
|
||||
|
||||
def test_print_tui_exit_summary_includes_resume_and_token_totals(monkeypatch, capsys):
|
||||
import hermes_cli.main as main_mod
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ def test_check_for_updates_expired_cache(tmp_path, monkeypatch):
|
|||
|
||||
|
||||
def test_check_for_updates_no_git_dir(tmp_path, monkeypatch):
|
||||
"""Returns None when .git directory doesn't exist anywhere."""
|
||||
"""Falls back to PyPI check when .git directory doesn't exist anywhere."""
|
||||
import hermes_cli.banner as banner
|
||||
|
||||
# Create a fake banner.py so the fallback path also has no .git
|
||||
|
|
@ -70,8 +70,9 @@ def test_check_for_updates_no_git_dir(tmp_path, monkeypatch):
|
|||
monkeypatch.setattr(banner, "__file__", str(fake_banner))
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
with patch("hermes_cli.banner.subprocess.run") as mock_run:
|
||||
result = banner.check_for_updates()
|
||||
assert result is None
|
||||
with patch("hermes_cli.banner.check_via_pypi", return_value=0):
|
||||
result = banner.check_for_updates()
|
||||
assert result == 0
|
||||
mock_run.assert_not_called()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -178,8 +178,11 @@ class TestLaunchdPlistPath:
|
|||
raise AssertionError("PATH key not found in plist")
|
||||
|
||||
def test_plist_path_includes_node_modules_bin(self):
|
||||
node_bin_dir = gateway_cli.PROJECT_ROOT / "node_modules" / ".bin"
|
||||
if not node_bin_dir.is_dir():
|
||||
pytest.skip("node_modules/.bin not present in this checkout")
|
||||
plist = gateway_cli.generate_launchd_plist()
|
||||
node_bin = str(gateway_cli.PROJECT_ROOT / "node_modules" / ".bin")
|
||||
node_bin = str(node_bin_dir)
|
||||
lines = plist.splitlines()
|
||||
for i, line in enumerate(lines):
|
||||
if "<key>PATH</key>" in line.strip():
|
||||
|
|
|
|||
|
|
@ -146,3 +146,92 @@ class TestReconfigureWritesProvider:
|
|||
assert config["video_gen"]["provider"] == "noenv_video"
|
||||
assert config["video_gen"]["model"] == "noenv_video-video-v1"
|
||||
assert config["video_gen"]["use_gateway"] is False
|
||||
|
||||
|
||||
class TestPluginVideoProvidersRow:
|
||||
"""Tests for _plugin_video_gen_providers row contents."""
|
||||
|
||||
def test_post_setup_propagated_when_declared(self, monkeypatch):
|
||||
from hermes_cli import tools_config
|
||||
|
||||
video_gen_registry.register_provider(_FakeVideoProvider(
|
||||
"xai_video",
|
||||
schema={
|
||||
"name": "xAI Grok Imagine",
|
||||
"badge": "paid",
|
||||
"tag": "grok video",
|
||||
"env_vars": [],
|
||||
"post_setup": "xai_grok",
|
||||
},
|
||||
))
|
||||
|
||||
rows = tools_config._plugin_video_gen_providers()
|
||||
match = next(r for r in rows if r.get("video_gen_plugin_name") == "xai_video")
|
||||
assert match["post_setup"] == "xai_grok"
|
||||
|
||||
def test_post_setup_omitted_when_not_declared(self, monkeypatch):
|
||||
from hermes_cli import tools_config
|
||||
|
||||
video_gen_registry.register_provider(_FakeVideoProvider("plain_video"))
|
||||
|
||||
rows = tools_config._plugin_video_gen_providers()
|
||||
match = next(r for r in rows if r.get("video_gen_plugin_name") == "plain_video")
|
||||
assert "post_setup" not in match
|
||||
|
||||
|
||||
class TestVideoPluginProviderActive:
|
||||
"""Tests for _is_provider_active recognizing video_gen_plugin_name."""
|
||||
|
||||
def test_active_when_video_gen_provider_matches(self):
|
||||
from hermes_cli import tools_config
|
||||
|
||||
config = {"video_gen": {"provider": "xai"}}
|
||||
row = {"name": "xAI Grok Imagine", "video_gen_plugin_name": "xai"}
|
||||
|
||||
assert tools_config._is_provider_active(row, config) is True
|
||||
|
||||
def test_inactive_when_video_gen_provider_differs(self):
|
||||
from hermes_cli import tools_config
|
||||
|
||||
config = {"video_gen": {"provider": "fal"}}
|
||||
row = {"name": "xAI Grok Imagine", "video_gen_plugin_name": "xai"}
|
||||
|
||||
assert tools_config._is_provider_active(row, config) is False
|
||||
|
||||
def test_inactive_when_video_gen_section_missing(self):
|
||||
from hermes_cli import tools_config
|
||||
|
||||
row = {"name": "xAI Grok Imagine", "video_gen_plugin_name": "xai"}
|
||||
assert tools_config._is_provider_active(row, {}) is False
|
||||
|
||||
def test_detect_active_index_picks_video_plugin_match(self, monkeypatch):
|
||||
"""When xAI is the configured video_gen provider, the picker should
|
||||
default to the xAI row even if FAL_KEY happens to be set in env.
|
||||
|
||||
Regression: previously _detect_active_provider_index() saw
|
||||
_is_provider_active(xai) return False (no video_gen branch),
|
||||
skipped xAI (empty env_vars), and matched the FAL row via the
|
||||
env-var fallback — so the picker visually defaulted to FAL even
|
||||
though the user picked xAI. The xAI row uses empty env_vars
|
||||
because authentication is handled via xAI Grok OAuth (post_setup
|
||||
hook).
|
||||
"""
|
||||
from hermes_cli import tools_config
|
||||
|
||||
monkeypatch.setattr(
|
||||
tools_config,
|
||||
"get_env_value",
|
||||
lambda key: "fal-key" if key == "FAL_KEY" else "",
|
||||
)
|
||||
|
||||
config = {"video_gen": {"provider": "xai"}}
|
||||
providers = [
|
||||
{"name": "xAI Grok Imagine", "env_vars": [], "video_gen_plugin_name": "xai"},
|
||||
{
|
||||
"name": "FAL.ai",
|
||||
"env_vars": [{"key": "FAL_KEY", "prompt": "FAL"}],
|
||||
"video_gen_plugin_name": "fal",
|
||||
},
|
||||
]
|
||||
|
||||
assert tools_config._detect_active_provider_index(providers, config) == 0
|
||||
|
|
|
|||
140
tests/hermes_cli/test_whatsapp_setup_ordering.py
Normal file
140
tests/hermes_cli/test_whatsapp_setup_ordering.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
"""Regression tests for ``cmd_whatsapp`` env-var write ordering.
|
||||
|
||||
Before the fix, ``hermes whatsapp`` wrote ``WHATSAPP_ENABLED=true`` at
|
||||
step 2 — before npm install (step 4) and before QR pairing (step 6).
|
||||
If the user Ctrl+C'd at any later step, ``.env`` claimed WhatsApp was
|
||||
ready when the bridge still had no ``creds.json``. Every subsequent
|
||||
``hermes gateway`` then paid a 30s bridge-bootstrap timeout and queued
|
||||
WhatsApp for indefinite retries — looking like "the gateway is broken."
|
||||
|
||||
The fix: only set ``WHATSAPP_ENABLED=true`` once pairing actually
|
||||
succeeds (creds.json exists). Aborted setup leaves no enabled state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
from contextlib import redirect_stdout
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_home(tmp_path, monkeypatch):
|
||||
home = tmp_path / "home"
|
||||
hermes = home / ".hermes"
|
||||
hermes.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: home)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes))
|
||||
# Ensure get_env_value cache doesn't carry stale state.
|
||||
for key in list(os.environ):
|
||||
if key.startswith("WHATSAPP_"):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
return hermes
|
||||
|
||||
|
||||
def _env_value(hermes_home: Path, key: str) -> str | None:
|
||||
env_file = hermes_home / ".env"
|
||||
if not env_file.exists():
|
||||
return None
|
||||
for line in env_file.read_text().splitlines():
|
||||
if "=" not in line:
|
||||
continue
|
||||
k, _, v = line.partition("=")
|
||||
if k.strip() == key:
|
||||
return v.strip().strip('"').strip("'")
|
||||
return None
|
||||
|
||||
|
||||
def test_aborted_setup_does_not_enable_whatsapp(isolated_home, monkeypatch):
|
||||
"""User picks mode 1, then Ctrl+C's at the allowed-users prompt.
|
||||
|
||||
WHATSAPP_ENABLED must NOT be present in .env after abort.
|
||||
"""
|
||||
from hermes_cli.main import cmd_whatsapp
|
||||
|
||||
# First input() = mode choice, second input() = allowed-users prompt
|
||||
# We raise KeyboardInterrupt on the second call to simulate abort.
|
||||
inputs = iter(["1"])
|
||||
|
||||
def fake_input(_prompt=""):
|
||||
try:
|
||||
return next(inputs)
|
||||
except StopIteration:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
monkeypatch.setattr("builtins.input", fake_input)
|
||||
# _require_tty calls sys.stdin.isatty — make it pass.
|
||||
monkeypatch.setattr("hermes_cli.main._require_tty", lambda *_a, **_kw: None)
|
||||
# No node, no bridge script — we shouldn't reach those steps anyway.
|
||||
|
||||
buf = io.StringIO()
|
||||
with redirect_stdout(buf):
|
||||
try:
|
||||
cmd_whatsapp(MagicMock())
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
assert _env_value(isolated_home, "WHATSAPP_ENABLED") is None, (
|
||||
"Setup aborted before pairing — WHATSAPP_ENABLED must not be set. "
|
||||
f"Got .env: {(isolated_home / '.env').read_text() if (isolated_home / '.env').exists() else '(missing)'}"
|
||||
)
|
||||
|
||||
|
||||
def test_existing_pairing_skip_branch_enables_whatsapp(isolated_home, monkeypatch):
|
||||
"""User runs ``hermes whatsapp`` with an existing paired session and
|
||||
chooses "no, keep my session" at the re-pair prompt. The env var
|
||||
should be (re-)written to true so the gateway picks WhatsApp back up,
|
||||
even if the var was lost since the original pairing.
|
||||
"""
|
||||
from hermes_cli.main import cmd_whatsapp
|
||||
|
||||
# Pre-create a paired session WITHOUT WHATSAPP_ENABLED in .env.
|
||||
session = isolated_home / "whatsapp" / "session"
|
||||
session.mkdir(parents=True)
|
||||
(session / "creds.json").write_text("{}")
|
||||
monkeypatch.setenv("WHATSAPP_MODE", "bot")
|
||||
monkeypatch.setenv("WHATSAPP_ALLOWED_USERS", "15551234567")
|
||||
|
||||
# mode already set → skip mode prompt; users already set → skip update
|
||||
# prompt with "no"; pairing exists → "no, keep session" → return.
|
||||
inputs = iter(["n", "n"])
|
||||
|
||||
def fake_input(_prompt=""):
|
||||
try:
|
||||
return next(inputs)
|
||||
except StopIteration:
|
||||
return "n"
|
||||
|
||||
monkeypatch.setattr("builtins.input", fake_input)
|
||||
monkeypatch.setattr("hermes_cli.main._require_tty", lambda *_a, **_kw: None)
|
||||
# Skip the bridge npm install — we're testing setup-ordering, not bridge
|
||||
# bootstrapping. Pretend node_modules exists (Path.exists -> True for that
|
||||
# specific check is hard to scope, so instead pretend npm install would
|
||||
# succeed silently if reached).
|
||||
monkeypatch.setattr(
|
||||
"subprocess.run",
|
||||
lambda *_a, **_kw: MagicMock(returncode=0, stderr=""),
|
||||
)
|
||||
monkeypatch.setattr("shutil.which", lambda _name: "/usr/bin/npm")
|
||||
# Patch (bridge_dir / "node_modules").exists() by stubbing Path.exists
|
||||
# to True for that one specific subpath. Easier: pre-create it as a
|
||||
# symlink to /tmp. But we can't write to the repo. Instead, stub
|
||||
# Path.exists wholesale to True for node_modules; the creds.json check
|
||||
# in the same function still works because we wrote it ourselves.
|
||||
_orig_exists = Path.exists
|
||||
def _stub_exists(self):
|
||||
if self.name == "node_modules":
|
||||
return True
|
||||
return _orig_exists(self)
|
||||
monkeypatch.setattr(Path, "exists", _stub_exists)
|
||||
|
||||
buf = io.StringIO()
|
||||
with redirect_stdout(buf):
|
||||
cmd_whatsapp(MagicMock())
|
||||
|
||||
# The skip-rebar branch should have set the env var on its way out.
|
||||
assert _env_value(isolated_home, "WHATSAPP_ENABLED") == "true"
|
||||
|
|
@ -72,10 +72,13 @@ class TestXAIImageGenProvider:
|
|||
|
||||
provider = XAIImageGenProvider()
|
||||
schema = provider.get_setup_schema()
|
||||
assert schema["name"] == "xAI (Grok)"
|
||||
assert schema["name"] == "xAI Grok Imagine (image)"
|
||||
assert schema["badge"] == "paid"
|
||||
assert len(schema["env_vars"]) == 1
|
||||
assert schema["env_vars"][0]["key"] == "XAI_API_KEY"
|
||||
# Auth resolution is delegated to the shared "xai_grok" post_setup
|
||||
# hook so the picker doesn't blindly prompt for XAI_API_KEY when the
|
||||
# user is already signed in via xAI Grok OAuth.
|
||||
assert schema["env_vars"] == []
|
||||
assert schema["post_setup"] == "xai_grok"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
207
tests/plugins/model_providers/test_deepseek_profile.py
Normal file
207
tests/plugins/model_providers/test_deepseek_profile.py
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
"""Unit tests for the DeepSeek provider profile's thinking-mode wiring.
|
||||
|
||||
DeepSeek V4 (and the legacy ``deepseek-reasoner``) expects every request to
|
||||
carry an explicit ``extra_body.thinking`` parameter. Omitting it makes the
|
||||
server default to thinking-mode ON, which then enforces the
|
||||
``reasoning_content``-must-be-echoed-back contract on subsequent turns and
|
||||
breaks the conversation with HTTP 400 (#15700, #17212, #17825).
|
||||
|
||||
These tests pin the profile's wire-shape contract so DeepSeek requests stay
|
||||
correctly shaped without going live.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def deepseek_profile():
|
||||
"""Resolve the registered DeepSeek profile.
|
||||
|
||||
Going through ``providers.get_provider_profile`` keeps the test honest —
|
||||
if someone later replaces the registered class with a plain
|
||||
``ProviderProfile``, every assertion below collapses.
|
||||
"""
|
||||
# ``model_tools`` triggers plugin discovery on import, which is what
|
||||
# registers the DeepSeek profile in the global provider registry.
|
||||
import model_tools # noqa: F401
|
||||
import providers
|
||||
|
||||
profile = providers.get_provider_profile("deepseek")
|
||||
assert profile is not None, "deepseek provider profile must be registered"
|
||||
return profile
|
||||
|
||||
|
||||
class TestDeepSeekThinkingWireShape:
|
||||
"""``build_api_kwargs_extras`` produces DeepSeek's exact wire format."""
|
||||
|
||||
def test_v4_pro_default_enables_thinking_without_effort(self, deepseek_profile):
|
||||
"""No reasoning_config → thinking enabled, server picks default effort."""
|
||||
extra_body, top_level = deepseek_profile.build_api_kwargs_extras(
|
||||
reasoning_config=None, model="deepseek-v4-pro"
|
||||
)
|
||||
assert extra_body == {"thinking": {"type": "enabled"}}
|
||||
assert top_level == {}
|
||||
|
||||
def test_v4_pro_enabled_with_high_effort(self, deepseek_profile):
|
||||
extra_body, top_level = deepseek_profile.build_api_kwargs_extras(
|
||||
reasoning_config={"enabled": True, "effort": "high"},
|
||||
model="deepseek-v4-pro",
|
||||
)
|
||||
assert extra_body == {"thinking": {"type": "enabled"}}
|
||||
assert top_level == {"reasoning_effort": "high"}
|
||||
|
||||
@pytest.mark.parametrize("effort", ["low", "medium", "high"])
|
||||
def test_standard_efforts_pass_through(self, deepseek_profile, effort):
|
||||
_, top_level = deepseek_profile.build_api_kwargs_extras(
|
||||
reasoning_config={"enabled": True, "effort": effort},
|
||||
model="deepseek-v4-pro",
|
||||
)
|
||||
assert top_level == {"reasoning_effort": effort}
|
||||
|
||||
@pytest.mark.parametrize("effort", ["xhigh", "max", "MAX", " Max "])
|
||||
def test_xhigh_and_max_normalize_to_max(self, deepseek_profile, effort):
|
||||
_, top_level = deepseek_profile.build_api_kwargs_extras(
|
||||
reasoning_config={"enabled": True, "effort": effort},
|
||||
model="deepseek-v4-pro",
|
||||
)
|
||||
assert top_level == {"reasoning_effort": "max"}
|
||||
|
||||
def test_explicitly_disabled_sends_disabled_marker(self, deepseek_profile):
|
||||
"""``reasoning_config.enabled=False`` → ``thinking.type=disabled``.
|
||||
|
||||
The crucial bit is that the parameter is *sent* at all — DeepSeek
|
||||
defaults to thinking-on when ``thinking`` is absent.
|
||||
"""
|
||||
extra_body, top_level = deepseek_profile.build_api_kwargs_extras(
|
||||
reasoning_config={"enabled": False}, model="deepseek-v4-pro"
|
||||
)
|
||||
assert extra_body == {"thinking": {"type": "disabled"}}
|
||||
# No effort when disabled — DeepSeek rejects it.
|
||||
assert top_level == {}
|
||||
|
||||
def test_disabled_ignores_effort_field(self, deepseek_profile):
|
||||
"""Effort silently dropped when thinking is off."""
|
||||
_, top_level = deepseek_profile.build_api_kwargs_extras(
|
||||
reasoning_config={"enabled": False, "effort": "high"},
|
||||
model="deepseek-v4-pro",
|
||||
)
|
||||
assert top_level == {}
|
||||
|
||||
def test_unknown_effort_omits_top_level(self, deepseek_profile):
|
||||
"""Garbage effort → omit reasoning_effort so DeepSeek applies its default."""
|
||||
_, top_level = deepseek_profile.build_api_kwargs_extras(
|
||||
reasoning_config={"enabled": True, "effort": "garbage"},
|
||||
model="deepseek-v4-pro",
|
||||
)
|
||||
assert top_level == {}
|
||||
|
||||
def test_empty_effort_omits_top_level(self, deepseek_profile):
|
||||
_, top_level = deepseek_profile.build_api_kwargs_extras(
|
||||
reasoning_config={"enabled": True, "effort": ""},
|
||||
model="deepseek-v4-pro",
|
||||
)
|
||||
assert top_level == {}
|
||||
|
||||
|
||||
class TestDeepSeekModelGating:
|
||||
"""V4 family + ``deepseek-reasoner`` get thinking; V3 stays untouched."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"deepseek-v4-pro",
|
||||
"deepseek-v4-flash",
|
||||
"deepseek-v4-future-variant",
|
||||
"deepseek-reasoner",
|
||||
"DEEPSEEK-V4-PRO", # case-insensitive
|
||||
],
|
||||
)
|
||||
def test_thinking_capable_models_emit_thinking(self, deepseek_profile, model):
|
||||
extra_body, _ = deepseek_profile.build_api_kwargs_extras(
|
||||
reasoning_config=None, model=model
|
||||
)
|
||||
assert extra_body == {"thinking": {"type": "enabled"}}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"deepseek-chat", # V3 alias
|
||||
"deepseek-v3-0324", # explicit V3
|
||||
"deepseek-v3.1", # V3 minor revisions
|
||||
"", # bare/unknown
|
||||
None, # missing
|
||||
"deepseek-unknown", # unrecognized
|
||||
],
|
||||
)
|
||||
def test_non_thinking_models_emit_nothing(self, deepseek_profile, model):
|
||||
extra_body, top_level = deepseek_profile.build_api_kwargs_extras(
|
||||
reasoning_config={"enabled": True, "effort": "high"}, model=model
|
||||
)
|
||||
assert extra_body == {}
|
||||
assert top_level == {}
|
||||
|
||||
|
||||
class TestDeepSeekFullKwargsIntegration:
|
||||
"""End-to-end: the transport's full kwargs match DeepSeek's live wire format.
|
||||
|
||||
The live test harness in ``tests/run_agent/test_deepseek_v4_thinking_live.py``
|
||||
sends ``{"reasoning_effort": "high", "extra_body": {"thinking": {"type":
|
||||
"enabled"}}}``. Confirm the transport produces that exact shape when wired
|
||||
through the registered DeepSeek profile.
|
||||
"""
|
||||
|
||||
def test_full_kwargs_match_live_wire_shape(self, deepseek_profile):
|
||||
from agent.transports.chat_completions import ChatCompletionsTransport
|
||||
|
||||
kwargs = ChatCompletionsTransport().build_kwargs(
|
||||
model="deepseek-v4-pro",
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
tools=None,
|
||||
provider_profile=deepseek_profile,
|
||||
reasoning_config={"enabled": True, "effort": "high"},
|
||||
base_url="https://api.deepseek.com/v1",
|
||||
provider_name="deepseek",
|
||||
)
|
||||
assert kwargs["model"] == "deepseek-v4-pro"
|
||||
assert kwargs["reasoning_effort"] == "high"
|
||||
assert kwargs["extra_body"] == {"thinking": {"type": "enabled"}}
|
||||
|
||||
def test_v3_chat_full_kwargs_omit_thinking(self, deepseek_profile):
|
||||
from agent.transports.chat_completions import ChatCompletionsTransport
|
||||
|
||||
kwargs = ChatCompletionsTransport().build_kwargs(
|
||||
model="deepseek-chat",
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
tools=None,
|
||||
provider_profile=deepseek_profile,
|
||||
reasoning_config={"enabled": True, "effort": "high"},
|
||||
base_url="https://api.deepseek.com/v1",
|
||||
provider_name="deepseek",
|
||||
)
|
||||
assert "reasoning_effort" not in kwargs
|
||||
assert "extra_body" not in kwargs or "thinking" not in kwargs.get("extra_body", {})
|
||||
|
||||
|
||||
class TestDeepSeekAuxModel:
|
||||
"""DeepSeek aux model is set on the profile so users stop seeing the
|
||||
bogus 'No auxiliary LLM provider configured' warning (#26924).
|
||||
|
||||
Pinned at the profile layer rather than the legacy
|
||||
`_API_KEY_PROVIDER_AUX_MODELS_FALLBACK` dict — new providers are
|
||||
expected to set `default_aux_model` on `ProviderProfile`, and the
|
||||
fallback dict only exists for providers that predate the profiles
|
||||
system.
|
||||
"""
|
||||
|
||||
def test_profile_advertises_deepseek_chat(self, deepseek_profile):
|
||||
assert deepseek_profile.default_aux_model == "deepseek-chat"
|
||||
|
||||
def test_consumer_api_returns_deepseek_chat(self):
|
||||
from agent.auxiliary_client import _get_aux_model_for_provider
|
||||
assert _get_aux_model_for_provider("deepseek") == "deepseek-chat"
|
||||
|
||||
def test_consumer_api_returns_non_empty(self):
|
||||
from agent.auxiliary_client import _get_aux_model_for_provider
|
||||
assert _get_aux_model_for_provider("deepseek") != ""
|
||||
|
|
@ -2,6 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -164,7 +165,542 @@ class TestHooksInert:
|
|||
|
||||
# Each hook should just return; no exceptions.
|
||||
mod.on_pre_llm_call(task_id="t", session_id="s", messages=[{"role": "user", "content": "hi"}])
|
||||
mod.on_pre_llm_request(task_id="t", session_id="s", api_call_count=1, messages=[])
|
||||
mod.on_pre_llm_request(task_id="t", session_id="s", api_call_count=1, request_messages=[])
|
||||
mod.on_post_llm_call(task_id="t", session_id="s", api_call_count=1)
|
||||
mod.on_pre_tool_call(tool_name="read_file", args={}, task_id="t", session_id="s")
|
||||
mod.on_post_tool_call(tool_name="read_file", args={}, result="ok", task_id="t", session_id="s")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Placeholder-credential guard (#23823).
|
||||
#
|
||||
# Regression coverage for the silent-failure bug: when an operator leaves
|
||||
# HERMES_LANGFUSE_PUBLIC_KEY / SECRET_KEY at a template value like
|
||||
# "placeholder", "test-key", or "your-langfuse-key", the SDK accepts the
|
||||
# credentials at construction time (it does no server-side validation
|
||||
# eagerly) but drops every trace at flush time, with no signal in the
|
||||
# Hermes logs. The fix in `_get_langfuse()` validates the documented
|
||||
# `pk-lf-` / `sk-lf-` prefix Langfuse always issues, surfaces a one-shot
|
||||
# warning naming the offending env var(s), and short-circuits via the
|
||||
# same `_INIT_FAILED` path used for missing credentials so subsequent
|
||||
# hook invocations don't re-log.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeLangfuse:
|
||||
"""Stand-in for the real :class:`langfuse.Langfuse` so tests don't
|
||||
need the optional ``langfuse`` SDK installed. The plugin's runtime
|
||||
gate refuses to proceed past ``if Langfuse is None`` when the SDK
|
||||
is missing, which would short-circuit before the placeholder check
|
||||
can fire. Patching ``plugin.Langfuse`` with this class lets the
|
||||
placeholder validator exercise its full code path."""
|
||||
|
||||
instances: list["_FakeLangfuse"] = []
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
_FakeLangfuse.instances.append(self)
|
||||
|
||||
|
||||
class TestPlaceholderKeyDetection:
|
||||
LOGGER_NAME = "plugins.observability.langfuse"
|
||||
|
||||
def _fresh_plugin(self, monkeypatch=None):
|
||||
mod_name = "plugins.observability.langfuse"
|
||||
sys.modules.pop(mod_name, None)
|
||||
mod = importlib.import_module(mod_name)
|
||||
if monkeypatch is not None:
|
||||
# Pretend the SDK is installed so `_get_langfuse()` actually
|
||||
# reaches the placeholder check. Real SDK calls are never
|
||||
# made because the placeholder/missing-credentials paths
|
||||
# return before constructing a client.
|
||||
_FakeLangfuse.instances.clear()
|
||||
monkeypatch.setattr(mod, "Langfuse", _FakeLangfuse, raising=False)
|
||||
return mod
|
||||
|
||||
@staticmethod
|
||||
def _clear_env(monkeypatch):
|
||||
for k in (
|
||||
"HERMES_LANGFUSE_PUBLIC_KEY", "HERMES_LANGFUSE_SECRET_KEY",
|
||||
"LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY",
|
||||
):
|
||||
monkeypatch.delenv(k, raising=False)
|
||||
|
||||
# -- helper unit tests (no SDK stub needed: these don't go through
|
||||
# _get_langfuse, they exercise the pure-Python helpers directly) ------
|
||||
|
||||
def test_redact_key_preview_empty(self, monkeypatch):
|
||||
self._clear_env(monkeypatch)
|
||||
plugin = self._fresh_plugin()
|
||||
assert plugin._redact_key_preview("") == "<empty>"
|
||||
|
||||
def test_redact_key_preview_short_value_echoed(self, monkeypatch):
|
||||
"""Short placeholder strings are echoed in full so the operator
|
||||
can see exactly which template they forgot to replace."""
|
||||
self._clear_env(monkeypatch)
|
||||
plugin = self._fresh_plugin()
|
||||
assert plugin._redact_key_preview("placeholder") == "'placeholder'"
|
||||
assert plugin._redact_key_preview("test-key") == "'test-key'"
|
||||
|
||||
def test_redact_key_preview_long_value_truncated(self, monkeypatch):
|
||||
"""If an operator pasted a real secret into the wrong env var the
|
||||
preview must NOT echo it in full — only the leading 6 chars."""
|
||||
self._clear_env(monkeypatch)
|
||||
plugin = self._fresh_plugin()
|
||||
result = plugin._redact_key_preview("sk-lf-abcdefghijklmnop")
|
||||
assert "abcdefghij" not in result
|
||||
assert result.startswith("'sk-lf-")
|
||||
assert result.endswith("...'")
|
||||
|
||||
def test_validate_langfuse_key_accepts_documented_prefix(self, monkeypatch):
|
||||
self._clear_env(monkeypatch)
|
||||
plugin = self._fresh_plugin()
|
||||
assert plugin._validate_langfuse_key(
|
||||
"HERMES_LANGFUSE_PUBLIC_KEY", "pk-lf-real-public-xyz"
|
||||
) is None
|
||||
assert plugin._validate_langfuse_key(
|
||||
"HERMES_LANGFUSE_SECRET_KEY", "sk-lf-real-secret-xyz"
|
||||
) is None
|
||||
|
||||
def test_validate_langfuse_key_rejects_wrong_prefix(self, monkeypatch):
|
||||
self._clear_env(monkeypatch)
|
||||
plugin = self._fresh_plugin()
|
||||
msg = plugin._validate_langfuse_key(
|
||||
"HERMES_LANGFUSE_PUBLIC_KEY", "placeholder"
|
||||
)
|
||||
assert msg is not None
|
||||
assert "HERMES_LANGFUSE_PUBLIC_KEY" in msg
|
||||
assert "pk-lf-" in msg
|
||||
|
||||
def test_validate_langfuse_key_unknown_name_passes(self, monkeypatch):
|
||||
"""Defensive: an env var with no registered prefix is trusted."""
|
||||
self._clear_env(monkeypatch)
|
||||
plugin = self._fresh_plugin()
|
||||
assert plugin._validate_langfuse_key("HERMES_LANGFUSE_BASE_URL", "anything") is None
|
||||
|
||||
# -- end-to-end _get_langfuse() behaviour --------------------------------
|
||||
# These tests pass `monkeypatch` to _fresh_plugin() so the helper can
|
||||
# stub out `Langfuse` (the optional SDK). Without that, every call
|
||||
# short-circuits at `if Langfuse is None` before reaching the
|
||||
# placeholder validator — masking the very behaviour we're testing.
|
||||
|
||||
def test_placeholder_public_key_warns_and_skips(self, monkeypatch, caplog):
|
||||
self._clear_env(monkeypatch)
|
||||
monkeypatch.setenv("HERMES_LANGFUSE_PUBLIC_KEY", "placeholder")
|
||||
monkeypatch.setenv("HERMES_LANGFUSE_SECRET_KEY", "sk-lf-real-secret-xyz")
|
||||
plugin = self._fresh_plugin(monkeypatch)
|
||||
with caplog.at_level(logging.WARNING, logger=self.LOGGER_NAME):
|
||||
assert plugin._get_langfuse() is None
|
||||
text = caplog.text
|
||||
assert "HERMES_LANGFUSE_PUBLIC_KEY" in text
|
||||
assert "'placeholder'" in text
|
||||
assert "pk-lf-" in text
|
||||
# The valid secret value must NOT appear (the var NAME does, in
|
||||
# the "or unset ..." hint, but the value preview shouldn't).
|
||||
assert "'sk-lf-" not in text
|
||||
# Never constructed the SDK client — short-circuited before that.
|
||||
assert _FakeLangfuse.instances == []
|
||||
|
||||
def test_placeholder_secret_key_warns_and_skips(self, monkeypatch, caplog):
|
||||
self._clear_env(monkeypatch)
|
||||
monkeypatch.setenv("HERMES_LANGFUSE_PUBLIC_KEY", "pk-lf-real-public-xyz")
|
||||
monkeypatch.setenv("HERMES_LANGFUSE_SECRET_KEY", "test-key")
|
||||
plugin = self._fresh_plugin(monkeypatch)
|
||||
with caplog.at_level(logging.WARNING, logger=self.LOGGER_NAME):
|
||||
assert plugin._get_langfuse() is None
|
||||
text = caplog.text
|
||||
assert "HERMES_LANGFUSE_SECRET_KEY" in text
|
||||
assert "'test-key'" in text
|
||||
assert "sk-lf-" in text
|
||||
# The valid public value must NOT appear.
|
||||
assert "'pk-lf-" not in text
|
||||
assert _FakeLangfuse.instances == []
|
||||
|
||||
def test_both_placeholders_one_warning_with_both_keys(self, monkeypatch, caplog):
|
||||
self._clear_env(monkeypatch)
|
||||
monkeypatch.setenv("HERMES_LANGFUSE_PUBLIC_KEY", "placeholder")
|
||||
monkeypatch.setenv("HERMES_LANGFUSE_SECRET_KEY", "placeholder")
|
||||
plugin = self._fresh_plugin(monkeypatch)
|
||||
with caplog.at_level(logging.WARNING, logger=self.LOGGER_NAME):
|
||||
assert plugin._get_langfuse() is None
|
||||
warnings = [r for r in caplog.records if r.levelname == "WARNING"
|
||||
and r.name == self.LOGGER_NAME]
|
||||
assert len(warnings) == 1, (
|
||||
f"Expected a single combined warning; got {len(warnings)}:\n"
|
||||
+ "\n".join(r.getMessage() for r in warnings)
|
||||
)
|
||||
text = warnings[0].getMessage()
|
||||
assert "HERMES_LANGFUSE_PUBLIC_KEY" in text
|
||||
assert "HERMES_LANGFUSE_SECRET_KEY" in text
|
||||
|
||||
def test_repeated_calls_do_not_re_warn(self, monkeypatch, caplog):
|
||||
"""The cached ``_INIT_FAILED`` sentinel must short-circuit
|
||||
subsequent calls so each hook invocation isn't a fresh log
|
||||
line — otherwise a busy gateway will spam the operator's
|
||||
terminal."""
|
||||
self._clear_env(monkeypatch)
|
||||
monkeypatch.setenv("HERMES_LANGFUSE_PUBLIC_KEY", "placeholder")
|
||||
monkeypatch.setenv("HERMES_LANGFUSE_SECRET_KEY", "placeholder")
|
||||
plugin = self._fresh_plugin(monkeypatch)
|
||||
with caplog.at_level(logging.WARNING, logger=self.LOGGER_NAME):
|
||||
for _ in range(15):
|
||||
assert plugin._get_langfuse() is None
|
||||
warnings = [r for r in caplog.records if r.levelname == "WARNING"
|
||||
and r.name == self.LOGGER_NAME]
|
||||
assert len(warnings) == 1, (
|
||||
f"Warning fired {len(warnings)} times across 15 calls; "
|
||||
"expected 1 (cached via _INIT_FAILED)"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("placeholder", [
|
||||
"placeholder",
|
||||
"test-key",
|
||||
"your-langfuse-key",
|
||||
"change-me",
|
||||
"xxx",
|
||||
"dummy-key-here",
|
||||
"<your-key>",
|
||||
"REPLACE_ME",
|
||||
])
|
||||
def test_common_placeholders_detected(self, monkeypatch, caplog, placeholder):
|
||||
"""A grab-bag of values that real-world ``.env.example`` templates
|
||||
use as stand-ins. Any of them in either key must trip the guard."""
|
||||
self._clear_env(monkeypatch)
|
||||
monkeypatch.setenv("HERMES_LANGFUSE_PUBLIC_KEY", placeholder)
|
||||
monkeypatch.setenv("HERMES_LANGFUSE_SECRET_KEY", "sk-lf-real-secret-xyz")
|
||||
plugin = self._fresh_plugin(monkeypatch)
|
||||
with caplog.at_level(logging.WARNING, logger=self.LOGGER_NAME):
|
||||
assert plugin._get_langfuse() is None
|
||||
assert "HERMES_LANGFUSE_PUBLIC_KEY" in caplog.text
|
||||
|
||||
def test_legacy_LANGFUSE_PUBLIC_KEY_also_validated(self, monkeypatch, caplog):
|
||||
"""The plugin reads both the canonical HERMES_-prefixed env var and
|
||||
the legacy bare ``LANGFUSE_PUBLIC_KEY``. The validator must run on
|
||||
whichever value ``_get_langfuse()`` actually consumed."""
|
||||
self._clear_env(monkeypatch)
|
||||
monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "placeholder")
|
||||
monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-real-secret-xyz")
|
||||
plugin = self._fresh_plugin(monkeypatch)
|
||||
with caplog.at_level(logging.WARNING, logger=self.LOGGER_NAME):
|
||||
assert plugin._get_langfuse() is None
|
||||
# Warning names the canonical user-facing env var (the bare
|
||||
# LANGFUSE_PUBLIC_KEY is a backwards-compat alias for the
|
||||
# HERMES_-prefixed one — operators set the HERMES_-prefixed one).
|
||||
assert "HERMES_LANGFUSE_PUBLIC_KEY" in caplog.text
|
||||
assert "'placeholder'" in caplog.text
|
||||
|
||||
def test_missing_credentials_still_skip_silently(self, monkeypatch, caplog):
|
||||
"""Missing-creds is the documented opt-out path (operator hasn't
|
||||
configured the plugin yet) — it must remain SILENT. Regression
|
||||
guard against the placeholder validator accidentally running on
|
||||
empty values and re-introducing log noise for unconfigured
|
||||
installs."""
|
||||
self._clear_env(monkeypatch)
|
||||
plugin = self._fresh_plugin(monkeypatch)
|
||||
with caplog.at_level(logging.WARNING, logger=self.LOGGER_NAME):
|
||||
assert plugin._get_langfuse() is None
|
||||
warnings = [r for r in caplog.records if r.levelname == "WARNING"
|
||||
and r.name == self.LOGGER_NAME]
|
||||
assert warnings == []
|
||||
|
||||
def test_sdk_not_installed_still_skips_silently(self, monkeypatch, caplog):
|
||||
"""If the langfuse SDK isn't installed at all, the placeholder
|
||||
check should never run — there's nothing the operator can do
|
||||
about a credential mismatch when the package is missing, and
|
||||
re-warning here would dilute the actually-actionable SDK-missing
|
||||
signal upstream. The ``Langfuse is None`` guard at the top of
|
||||
``_get_langfuse`` already handles this; this test pins that
|
||||
behaviour."""
|
||||
self._clear_env(monkeypatch)
|
||||
monkeypatch.setenv("HERMES_LANGFUSE_PUBLIC_KEY", "placeholder")
|
||||
monkeypatch.setenv("HERMES_LANGFUSE_SECRET_KEY", "placeholder")
|
||||
# NO monkeypatch on Langfuse here — falls back to whatever the
|
||||
# plugin imported at module load (None if SDK absent).
|
||||
plugin = self._fresh_plugin()
|
||||
monkeypatch.setattr(plugin, "Langfuse", None, raising=False)
|
||||
with caplog.at_level(logging.WARNING, logger=self.LOGGER_NAME):
|
||||
assert plugin._get_langfuse() is None
|
||||
warnings = [r for r in caplog.records if r.levelname == "WARNING"
|
||||
and r.name == self.LOGGER_NAME]
|
||||
assert warnings == []
|
||||
|
||||
def test_valid_prefixes_do_not_trigger_placeholder_warning(self, monkeypatch, caplog):
|
||||
"""Real Langfuse keys (``pk-lf-…`` / ``sk-lf-…``) must pass the
|
||||
guard and proceed to SDK init. We stub the SDK constructor with
|
||||
a recording fake so the assertion can confirm BOTH that the
|
||||
placeholder warning didn't fire AND that the client was actually
|
||||
constructed — the latter is the success signal the bug report
|
||||
wanted."""
|
||||
self._clear_env(monkeypatch)
|
||||
monkeypatch.setenv("HERMES_LANGFUSE_PUBLIC_KEY", "pk-lf-real-public-xyz")
|
||||
monkeypatch.setenv("HERMES_LANGFUSE_SECRET_KEY", "sk-lf-real-secret-xyz")
|
||||
plugin = self._fresh_plugin(monkeypatch)
|
||||
with caplog.at_level(logging.WARNING, logger=self.LOGGER_NAME):
|
||||
client = plugin._get_langfuse()
|
||||
assert isinstance(client, _FakeLangfuse)
|
||||
assert client.kwargs["public_key"] == "pk-lf-real-public-xyz"
|
||||
assert client.kwargs["secret_key"] == "sk-lf-real-secret-xyz"
|
||||
assert "placeholders" not in caplog.text.lower(), (
|
||||
f"Valid Langfuse keys tripped the placeholder guard: {caplog.text!r}"
|
||||
)
|
||||
|
||||
|
||||
class TestRequestMessageCoercion:
|
||||
def test_prefers_request_messages_then_messages_then_history_then_user_message(self):
|
||||
sys.modules.pop("plugins.observability.langfuse", None)
|
||||
mod = importlib.import_module("plugins.observability.langfuse")
|
||||
|
||||
assert mod._coerce_request_messages(
|
||||
request_messages=[{"role": "system", "content": "s"}],
|
||||
messages=[{"role": "user", "content": "m"}],
|
||||
conversation_history=[{"role": "user", "content": "h"}],
|
||||
user_message="u",
|
||||
) == [{"role": "system", "content": "s"}]
|
||||
assert mod._coerce_request_messages(
|
||||
messages=[{"role": "user", "content": "m"}],
|
||||
conversation_history=[{"role": "user", "content": "h"}],
|
||||
user_message="u",
|
||||
) == [{"role": "user", "content": "m"}]
|
||||
assert mod._coerce_request_messages(
|
||||
conversation_history=[{"role": "user", "content": "h"}],
|
||||
user_message="u",
|
||||
) == [{"role": "user", "content": "h"}]
|
||||
assert mod._coerce_request_messages(user_message="u") == [{"role": "user", "content": "u"}]
|
||||
|
||||
|
||||
class TestToolCallOutputBackfill:
|
||||
def test_post_tool_call_backfills_matching_turn_tool_call_output(self, monkeypatch):
|
||||
sys.modules.pop("plugins.observability.langfuse", None)
|
||||
mod = importlib.import_module("plugins.observability.langfuse")
|
||||
|
||||
observation = object()
|
||||
state = mod.TraceState(trace_id="trace-1", root_ctx=None, root_span=None)
|
||||
state.tools["call-1"] = observation
|
||||
state.turn_tool_calls.append({
|
||||
"id": "call-1",
|
||||
"type": "function",
|
||||
"name": "web_extract",
|
||||
"arguments": '{"urls": ["https://example.com"]}',
|
||||
"function": {
|
||||
"name": "web_extract",
|
||||
"arguments": '{"urls": ["https://example.com"]}',
|
||||
},
|
||||
})
|
||||
|
||||
task_key = mod._trace_key("task-1", "session-1")
|
||||
monkeypatch.setitem(mod._TRACE_STATE, task_key, state)
|
||||
|
||||
ended = {}
|
||||
|
||||
def fake_end_observation(obs, *, output=None, metadata=None, usage_details=None, cost_details=None):
|
||||
ended["observation"] = obs
|
||||
ended["output"] = output
|
||||
ended["metadata"] = metadata
|
||||
|
||||
monkeypatch.setattr(mod, "_end_observation", fake_end_observation)
|
||||
|
||||
mod.on_post_tool_call(
|
||||
tool_name="web_extract",
|
||||
args={"urls": ["https://example.com"]},
|
||||
result='{"results": [{"url": "https://example.com", "content": "Example Domain"}]}',
|
||||
task_id="task-1",
|
||||
session_id="session-1",
|
||||
tool_call_id="call-1",
|
||||
)
|
||||
|
||||
assert ended["observation"] is observation
|
||||
assert state.turn_tool_calls[0]["output"] == ended["output"]
|
||||
assert state.turn_tool_calls[0]["function"]["output"] == ended["output"]
|
||||
assert state.turn_tool_calls[0]["output"] == {
|
||||
"results": [{"url": "https://example.com", "content": "Example Domain"}]
|
||||
}
|
||||
|
||||
def test_serialize_messages_keeps_tool_name_and_call_id(self):
|
||||
sys.modules.pop("plugins.observability.langfuse", None)
|
||||
mod = importlib.import_module("plugins.observability.langfuse")
|
||||
|
||||
messages = [{
|
||||
"role": "tool",
|
||||
"name": "web_extract",
|
||||
"tool_call_id": "call-1",
|
||||
"content": '{"ok": true}',
|
||||
}]
|
||||
|
||||
assert mod._serialize_messages(messages) == [{
|
||||
"role": "tool",
|
||||
"name": "web_extract",
|
||||
"tool_call_id": "call-1",
|
||||
"content": {"ok": True},
|
||||
}]
|
||||
|
||||
def test_serialize_tool_calls_emits_openai_style_function_shape(self):
|
||||
sys.modules.pop("plugins.observability.langfuse", None)
|
||||
mod = importlib.import_module("plugins.observability.langfuse")
|
||||
|
||||
class _Fn:
|
||||
name = "web_extract"
|
||||
arguments = '{"urls": ["https://example.com"]}'
|
||||
|
||||
class _ToolCall:
|
||||
id = "call-1"
|
||||
type = "function"
|
||||
function = _Fn()
|
||||
|
||||
assert mod._serialize_tool_calls([_ToolCall()]) == [{
|
||||
"id": "call-1",
|
||||
"type": "function",
|
||||
"name": "web_extract",
|
||||
"arguments": '{"urls": ["https://example.com"]}',
|
||||
"function": {
|
||||
"name": "web_extract",
|
||||
"arguments": '{"urls": ["https://example.com"]}',
|
||||
},
|
||||
}]
|
||||
|
||||
|
||||
class TestToolObservationKeying:
|
||||
"""Tests for pre/post tool_call observation matching when tool_call_id is absent."""
|
||||
|
||||
def _make_mod(self):
|
||||
sys.modules.pop("plugins.observability.langfuse", None)
|
||||
return importlib.import_module("plugins.observability.langfuse")
|
||||
|
||||
def test_empty_tool_call_id_single_tool_sets_output(self, monkeypatch):
|
||||
mod = self._make_mod()
|
||||
obs = object()
|
||||
state = mod.TraceState(trace_id="t", root_ctx=None, root_span=None)
|
||||
state.pending_tools_by_name.setdefault("my_tool", []).append(obs)
|
||||
|
||||
task_key = mod._trace_key("task-1", "sess-1")
|
||||
monkeypatch.setitem(mod._TRACE_STATE, task_key, state)
|
||||
|
||||
ended = {}
|
||||
|
||||
def fake_end(o, *, output=None, metadata=None, **kw):
|
||||
ended["obs"] = o
|
||||
ended["output"] = output
|
||||
|
||||
monkeypatch.setattr(mod, "_end_observation", fake_end)
|
||||
|
||||
mod.on_post_tool_call(
|
||||
tool_name="my_tool",
|
||||
args={},
|
||||
result='{"ok": true}',
|
||||
task_id="task-1",
|
||||
session_id="sess-1",
|
||||
tool_call_id="",
|
||||
)
|
||||
|
||||
assert ended["obs"] is obs
|
||||
assert ended["output"] == {"ok": True}
|
||||
assert state.pending_tools_by_name.get("my_tool") is None
|
||||
|
||||
def test_empty_tool_call_id_observations_are_fifo_within_tool_name(self, monkeypatch):
|
||||
"""Two queued observations are consumed in FIFO order so the first
|
||||
post hook gets the first observation's output, not the second.
|
||||
|
||||
Sequential-on-one-thread coverage; the real concurrent case is
|
||||
guarded by ``_STATE_LOCK`` around every read-modify-write on
|
||||
``pending_tools_by_name`` and is exercised in
|
||||
``test_threaded_post_calls_preserve_fifo_under_lock`` below.
|
||||
"""
|
||||
mod = self._make_mod()
|
||||
obs_a, obs_b = object(), object()
|
||||
state = mod.TraceState(trace_id="t", root_ctx=None, root_span=None)
|
||||
state.pending_tools_by_name["web_extract"] = [obs_a, obs_b]
|
||||
|
||||
task_key = mod._trace_key("task-1", "sess-1")
|
||||
monkeypatch.setitem(mod._TRACE_STATE, task_key, state)
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_end(o, *, output=None, metadata=None, **kw):
|
||||
calls.append((o, output))
|
||||
|
||||
monkeypatch.setattr(mod, "_end_observation", fake_end)
|
||||
|
||||
mod.on_post_tool_call(
|
||||
tool_name="web_extract", args={}, result='{"val": "a"}',
|
||||
task_id="task-1", session_id="sess-1", tool_call_id="",
|
||||
)
|
||||
mod.on_post_tool_call(
|
||||
tool_name="web_extract", args={}, result='{"val": "b"}',
|
||||
task_id="task-1", session_id="sess-1", tool_call_id="",
|
||||
)
|
||||
|
||||
assert calls[0] == (obs_a, {"val": "a"})
|
||||
assert calls[1] == (obs_b, {"val": "b"})
|
||||
assert state.pending_tools_by_name.get("web_extract") is None
|
||||
|
||||
def test_threaded_post_calls_preserve_fifo_under_lock(self, monkeypatch):
|
||||
"""The actual concurrency contract: when 8 threads race to drain
|
||||
the pending queue, no observation is consumed twice and none is
|
||||
lost. Validates ``_STATE_LOCK`` discipline, not Python list
|
||||
semantics."""
|
||||
import threading
|
||||
|
||||
mod = self._make_mod()
|
||||
n = 8
|
||||
observations = [object() for _ in range(n)]
|
||||
state = mod.TraceState(trace_id="t", root_ctx=None, root_span=None)
|
||||
state.pending_tools_by_name["web_extract"] = list(observations)
|
||||
|
||||
task_key = mod._trace_key("task-thr", "sess-thr")
|
||||
monkeypatch.setitem(mod._TRACE_STATE, task_key, state)
|
||||
|
||||
recorded: list = []
|
||||
lock = threading.Lock()
|
||||
|
||||
def fake_end(o, *, output=None, metadata=None, **kw):
|
||||
with lock:
|
||||
recorded.append(o)
|
||||
|
||||
monkeypatch.setattr(mod, "_end_observation", fake_end)
|
||||
|
||||
barrier = threading.Barrier(n)
|
||||
|
||||
def worker():
|
||||
barrier.wait()
|
||||
mod.on_post_tool_call(
|
||||
tool_name="web_extract", args={}, result='{"ok": true}',
|
||||
task_id="task-thr", session_id="sess-thr", tool_call_id="",
|
||||
)
|
||||
|
||||
threads = [threading.Thread(target=worker) for _ in range(n)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Every observation was consumed exactly once; queue is empty.
|
||||
assert len(recorded) == n
|
||||
assert set(map(id, recorded)) == set(map(id, observations))
|
||||
assert state.pending_tools_by_name.get("web_extract") is None
|
||||
|
||||
def test_explicit_tool_call_id_uses_tools_dict(self, monkeypatch):
|
||||
"""When tool_call_id is present, pending_tools_by_name is not touched."""
|
||||
mod = self._make_mod()
|
||||
obs = object()
|
||||
state = mod.TraceState(trace_id="t", root_ctx=None, root_span=None)
|
||||
state.tools["call-99"] = obs
|
||||
|
||||
task_key = mod._trace_key("task-1", "sess-1")
|
||||
monkeypatch.setitem(mod._TRACE_STATE, task_key, state)
|
||||
|
||||
ended = {}
|
||||
|
||||
def fake_end(o, *, output=None, metadata=None, **kw):
|
||||
ended["obs"] = o
|
||||
ended["output"] = output
|
||||
|
||||
monkeypatch.setattr(mod, "_end_observation", fake_end)
|
||||
|
||||
mod.on_post_tool_call(
|
||||
tool_name="my_tool", args={}, result='{"status": "done"}',
|
||||
task_id="task-1", session_id="sess-1", tool_call_id="call-99",
|
||||
)
|
||||
|
||||
assert ended["obs"] is obs
|
||||
assert ended["output"] == {"status": "done"}
|
||||
assert not state.tools
|
||||
|
||||
|
|
|
|||
|
|
@ -54,6 +54,50 @@ def test_xai_generate_requires_xai_key(monkeypatch):
|
|||
assert result["error_type"] == "auth_required"
|
||||
|
||||
|
||||
def test_xai_available_with_oauth_only(monkeypatch):
|
||||
"""The plugin must honour xAI Grok OAuth credentials, not just
|
||||
XAI_API_KEY. Otherwise the agent's tool-availability check filters
|
||||
``video_generate`` out of the toolbelt and the agent silently falls
|
||||
back to whatever skill advertises video generation (e.g. comfyui).
|
||||
"""
|
||||
import plugins.video_gen.xai as xai_plugin
|
||||
|
||||
monkeypatch.delenv("XAI_API_KEY", raising=False)
|
||||
monkeypatch.setattr(
|
||||
"tools.xai_http.resolve_xai_http_credentials",
|
||||
lambda: {
|
||||
"provider": "xai-oauth",
|
||||
"api_key": "oauth-bearer-token",
|
||||
"base_url": "https://api.x.ai/v1",
|
||||
},
|
||||
)
|
||||
|
||||
assert xai_plugin.XAIVideoGenProvider().is_available() is True
|
||||
|
||||
|
||||
def test_xai_resolved_credentials_threaded_through_request(monkeypatch):
|
||||
"""OAuth-resolved creds must reach the HTTP layer — bug class where
|
||||
``is_available()`` says yes but the request still hits with no key.
|
||||
"""
|
||||
import plugins.video_gen.xai as xai_plugin
|
||||
|
||||
monkeypatch.delenv("XAI_API_KEY", raising=False)
|
||||
monkeypatch.setattr(
|
||||
"tools.xai_http.resolve_xai_http_credentials",
|
||||
lambda: {
|
||||
"provider": "xai-oauth",
|
||||
"api_key": "oauth-bearer-token",
|
||||
"base_url": "https://api.x.ai/v1",
|
||||
},
|
||||
)
|
||||
|
||||
api_key, base_url = xai_plugin._resolve_xai_credentials()
|
||||
assert api_key == "oauth-bearer-token"
|
||||
assert base_url == "https://api.x.ai/v1"
|
||||
headers = xai_plugin._xai_headers(api_key)
|
||||
assert headers["Authorization"] == "Bearer oauth-bearer-token"
|
||||
|
||||
|
||||
def test_xai_no_operation_kwarg():
|
||||
"""The ABC's generate() signature no longer accepts 'operation'.
|
||||
Passing it through **kwargs should be ignored (forward-compat)."""
|
||||
|
|
|
|||
|
|
@ -42,6 +42,10 @@ class TestNvidiaProfile:
|
|||
p = get_provider_profile("nvidia")
|
||||
assert "nvidia.com" in p.base_url
|
||||
|
||||
def test_billing_header_not_profile_wide(self):
|
||||
p = get_provider_profile("nvidia")
|
||||
assert p.default_headers == {}
|
||||
|
||||
|
||||
class TestKimiProfile:
|
||||
def test_temperature_omit(self):
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class TestTruncatedAnthropicResponseNormalization:
|
|||
nr = get_transport("anthropic_messages").normalize_response(response)
|
||||
|
||||
# The continuation block checks these two attributes:
|
||||
# assistant_message.content → appended to truncated_response_prefix
|
||||
# assistant_message.content → appended to truncated_response_parts
|
||||
# assistant_message.tool_calls → guards the text-retry branch
|
||||
assert nr.content is not None
|
||||
assert "partial response" in nr.content
|
||||
|
|
|
|||
|
|
@ -193,3 +193,51 @@ def test_background_review_summary_is_attributed_to_self_improvement_loop(monkey
|
|||
assert captured_bg_callback[0].startswith("💾 Self-improvement review:"), (
|
||||
captured_bg_callback[0]
|
||||
)
|
||||
|
||||
|
||||
def test_background_review_fork_skips_external_memory_plugins(monkeypatch):
|
||||
"""The background review fork must NOT touch external memory plugins.
|
||||
|
||||
Without skip_memory=True on the fork constructor, AIAgent.__init__
|
||||
rebuilds its own _memory_manager from config, scoped to the parent's
|
||||
session_id. The review fork's run_conversation() then leaks the
|
||||
harness prompt into the user's real memory namespace via three
|
||||
ingestion sites: on_turn_start (cadence + turn message),
|
||||
prefetch_all (recall query), and sync_all (harness prompt + review
|
||||
output recorded as a (user, assistant) turn pair). The fix is a
|
||||
single kwarg on the fork constructor — this test guards it.
|
||||
"""
|
||||
captured_kwargs: dict = {}
|
||||
|
||||
class FakeReviewAgent:
|
||||
def __init__(self, **kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
self._session_messages = []
|
||||
|
||||
def run_conversation(self, **kwargs):
|
||||
pass
|
||||
|
||||
def shutdown_memory_provider(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(run_agent_module, "AIAgent", FakeReviewAgent)
|
||||
monkeypatch.setattr(run_agent_module.threading, "Thread", ImmediateThread)
|
||||
|
||||
agent = _bare_agent()
|
||||
|
||||
AIAgent._spawn_background_review(
|
||||
agent,
|
||||
messages_snapshot=[{"role": "user", "content": "hello"}],
|
||||
review_memory=True,
|
||||
)
|
||||
|
||||
assert captured_kwargs.get("skip_memory") is True, (
|
||||
"Background review fork must be constructed with skip_memory=True "
|
||||
"so AIAgent.__init__ does not rebuild a _memory_manager wired to "
|
||||
"external plugins (honcho, mem0, supermemory, ...). Without this "
|
||||
"the fork leaks harness prompts into the user's real memory "
|
||||
"namespace via on_turn_start / prefetch_all / sync_all."
|
||||
)
|
||||
|
|
|
|||
544
tests/run_agent/test_codex_xai_oauth_recovery.py
Normal file
544
tests/run_agent/test_codex_xai_oauth_recovery.py
Normal file
|
|
@ -0,0 +1,544 @@
|
|||
"""Regression tests for the May 2026 xAI OAuth (SuperGrok / X Premium) bugs.
|
||||
|
||||
Three distinct failure modes the user community hit during rollout:
|
||||
|
||||
1. ``RuntimeError("Expected to have received `response.created` before
|
||||
`error`")`` on multi-turn xAI OAuth conversations. The OpenAI SDK's
|
||||
Responses streaming state machine collapses an upstream ``error`` SSE
|
||||
frame into a generic stream-ordering error. ``_run_codex_stream``
|
||||
now treats this the same way it already treats the missing
|
||||
``response.completed`` postlude — fall back to a non-stream
|
||||
``responses.create(stream=True)`` which surfaces the real provider
|
||||
error. Also closes #8133 (``response.in_progress`` prelude on custom
|
||||
relays) and #14634 (``codex.rate_limits`` prelude on codex-lb).
|
||||
|
||||
2. The HTTP 403 entitlement error xAI returns when an OAuth token lacks
|
||||
SuperGrok / X Premium ("You have either run out of available
|
||||
resources or do not have an active Grok subscription") used to read
|
||||
as a confusing wall of JSON. ``_summarize_api_error`` now appends a
|
||||
one-line hint pointing the user at https://grok.com and ``/model``.
|
||||
|
||||
3. Multi-turn replay of ``codex_reasoning_items`` (with
|
||||
``encrypted_content``) is now suppressed for ``is_xai_responses=True``
|
||||
in ``_chat_messages_to_responses_input``. xAI's OAuth/SuperGrok
|
||||
surface rejects replayed encrypted reasoning items; Grok still
|
||||
reasons natively each turn, so coherence rides on visible message
|
||||
text.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix A: prelude error fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_codex_agent():
|
||||
"""Build a minimal AIAgent wired for codex_responses streaming tests."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://api.x.ai/v1",
|
||||
model="grok-4.3",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "codex_responses"
|
||||
agent.provider = "xai-oauth"
|
||||
agent._interrupt_requested = False
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prelude_event_type",
|
||||
[
|
||||
"error", # xAI OAuth multi-turn
|
||||
"codex.rate_limits", # codex-lb relays (#14634)
|
||||
"response.in_progress", # custom Responses relays (#8133)
|
||||
],
|
||||
)
|
||||
def test_codex_stream_prelude_error_falls_back_to_create_stream(prelude_event_type):
|
||||
"""The SDK's prelude RuntimeError must trigger the non-stream fallback.
|
||||
|
||||
When the first SSE event isn't ``response.created``, openai-python
|
||||
raises RuntimeError before our event loop sees anything. We must
|
||||
detect that, retry once, then fall back to ``create(stream=True)``
|
||||
which surfaces the real provider error or a real response.
|
||||
"""
|
||||
agent = _make_codex_agent()
|
||||
|
||||
prelude_error = RuntimeError(
|
||||
f"Expected to have received `response.created` before `{prelude_event_type}`"
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.responses.stream.side_effect = prelude_error
|
||||
|
||||
fallback_response = SimpleNamespace(
|
||||
output=[SimpleNamespace(
|
||||
type="message",
|
||||
content=[SimpleNamespace(type="output_text", text="fallback ok")],
|
||||
)],
|
||||
status="completed",
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
agent, "_run_codex_create_stream_fallback", return_value=fallback_response
|
||||
) as mock_fallback:
|
||||
result = agent._run_codex_stream({}, client=mock_client)
|
||||
|
||||
assert result is fallback_response
|
||||
mock_fallback.assert_called_once_with({}, client=mock_client)
|
||||
|
||||
|
||||
def test_codex_stream_prelude_error_retries_once_before_fallback():
|
||||
"""The retry path must fire one extra stream attempt before falling back."""
|
||||
agent = _make_codex_agent()
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def stream_side_effect(**kwargs):
|
||||
call_count["n"] += 1
|
||||
raise RuntimeError(
|
||||
"Expected to have received `response.created` before `error`"
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.responses.stream.side_effect = stream_side_effect
|
||||
|
||||
fallback_response = SimpleNamespace(output=[], status="completed")
|
||||
with patch.object(
|
||||
agent, "_run_codex_create_stream_fallback", return_value=fallback_response
|
||||
) as mock_fallback:
|
||||
agent._run_codex_stream({}, client=mock_client)
|
||||
|
||||
# max_stream_retries=1 → one retry + final attempt → 2 stream calls,
|
||||
# THEN the fallback path runs.
|
||||
assert call_count["n"] == 2
|
||||
mock_fallback.assert_called_once()
|
||||
|
||||
|
||||
def test_codex_stream_unrelated_runtimeerror_still_raises():
|
||||
"""RuntimeErrors that aren't prelude/postlude shape must propagate."""
|
||||
agent = _make_codex_agent()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.responses.stream.side_effect = RuntimeError("something else broke")
|
||||
|
||||
with patch.object(agent, "_run_codex_create_stream_fallback") as mock_fallback:
|
||||
with pytest.raises(RuntimeError, match="something else broke"):
|
||||
agent._run_codex_stream({}, client=mock_client)
|
||||
|
||||
mock_fallback.assert_not_called()
|
||||
|
||||
|
||||
def test_codex_stream_postlude_error_still_falls_back():
|
||||
"""Existing ``response.completed`` fallback must not regress."""
|
||||
agent = _make_codex_agent()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.responses.stream.side_effect = RuntimeError(
|
||||
"Didn't receive a `response.completed` event."
|
||||
)
|
||||
|
||||
fallback_response = SimpleNamespace(output=[], status="completed")
|
||||
with patch.object(
|
||||
agent, "_run_codex_create_stream_fallback", return_value=fallback_response
|
||||
) as mock_fallback:
|
||||
result = agent._run_codex_stream({}, client=mock_client)
|
||||
|
||||
assert result is fallback_response
|
||||
mock_fallback.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix B: surface xAI's entitlement body verbatim (no editorializing)
|
||||
#
|
||||
# The original PR #26644 appended a hint that led with "X Premium+ does NOT
|
||||
# include xAI API access — only standalone SuperGrok subscribers can use this
|
||||
# provider." xAI announced on 2026-05-16 that X Premium subs now work in
|
||||
# Hermes (https://x.ai/news/grok-hermes), making that hint actively wrong:
|
||||
# a Premium+ user hitting a real entitlement issue (no Grok sub, wrong tier,
|
||||
# exhausted quota) would be misdirected to switch subscriptions when their
|
||||
# Premium sub is in fact valid. We now surface xAI's own body text verbatim
|
||||
# (which already says "Manage subscriptions at https://grok.com/?_s=usage")
|
||||
# and leave the diagnosis to xAI's wording.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_summarize_api_error_surfaces_xai_entitlement_body_verbatim():
|
||||
"""xAI's OAuth 403 body must surface as-is, with no Hermes-side hint."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
error = RuntimeError(
|
||||
"HTTP 403: Error code: 403 - {'code': 'The caller does not have permission "
|
||||
"to execute the specified operation', 'error': 'You have either run out of "
|
||||
"available resources or do not have an active Grok subscription. Manage "
|
||||
"subscriptions at https://grok.com'}"
|
||||
)
|
||||
summary = AIAgent._summarize_api_error(error)
|
||||
# xAI's own body text must reach the user — they need it to diagnose.
|
||||
assert "do not have an active Grok subscription" in summary
|
||||
# No stale claim that X Premium is incompatible with Hermes.
|
||||
assert "X Premium+ does NOT include" not in summary
|
||||
assert "standalone SuperGrok subscribers" not in summary
|
||||
|
||||
|
||||
def test_summarize_api_error_xai_body_message_unwrapped():
|
||||
"""SDK-style error with structured body surfaces the message cleanly."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
class _XaiErr(Exception):
|
||||
status_code = 403
|
||||
body = {
|
||||
"error": {
|
||||
"message": (
|
||||
"You have either run out of available resources or do "
|
||||
"not have an active Grok subscription. Manage at "
|
||||
"https://grok.com"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
summary = AIAgent._summarize_api_error(_XaiErr("403"))
|
||||
assert "HTTP 403" in summary
|
||||
assert "do not have an active Grok subscription" in summary
|
||||
# No editorializing on top of xAI's own wording.
|
||||
assert "X Premium+ does NOT include" not in summary
|
||||
|
||||
|
||||
def test_summarize_api_error_passes_through_unrelated_errors():
|
||||
"""Non-xAI / non-entitlement errors must not be touched."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
error = RuntimeError("HTTP 500: upstream is sad")
|
||||
summary = AIAgent._summarize_api_error(error)
|
||||
assert "SuperGrok" not in summary
|
||||
assert "grok.com" not in summary
|
||||
assert "upstream is sad" in summary
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix C: reasoning replay gating for xai-oauth
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _assistant_msg_with_encrypted_reasoning(text="hi from grok", encrypted="enc_blob"):
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
"codex_reasoning_items": [
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": "rs_xai_001",
|
||||
"encrypted_content": encrypted,
|
||||
"summary": [],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_codex_reasoning_replay_default_includes_encrypted_content():
|
||||
"""Native Codex backend (default) must still replay encrypted reasoning."""
|
||||
from agent.codex_responses_adapter import _chat_messages_to_responses_input
|
||||
|
||||
msgs = [
|
||||
{"role": "user", "content": "hi"},
|
||||
_assistant_msg_with_encrypted_reasoning(),
|
||||
{"role": "user", "content": "what's your name?"},
|
||||
]
|
||||
|
||||
items = _chat_messages_to_responses_input(msgs)
|
||||
reasoning = [it for it in items if it.get("type") == "reasoning"]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0]["encrypted_content"] == "enc_blob"
|
||||
|
||||
|
||||
def test_codex_reasoning_replay_stripped_for_xai_oauth():
|
||||
"""xAI OAuth surface must NOT receive replayed encrypted reasoning."""
|
||||
from agent.codex_responses_adapter import _chat_messages_to_responses_input
|
||||
|
||||
msgs = [
|
||||
{"role": "user", "content": "hi"},
|
||||
_assistant_msg_with_encrypted_reasoning(),
|
||||
{"role": "user", "content": "what's your name?"},
|
||||
]
|
||||
|
||||
items = _chat_messages_to_responses_input(msgs, is_xai_responses=True)
|
||||
reasoning = [it for it in items if it.get("type") == "reasoning"]
|
||||
assert reasoning == []
|
||||
|
||||
# The assistant's visible text must still survive — coherence across
|
||||
# turns rides on the message text alone.
|
||||
assistant_items = [
|
||||
it for it in items
|
||||
if it.get("role") == "assistant" or it.get("type") == "message"
|
||||
]
|
||||
assert assistant_items, "assistant message must still be present"
|
||||
|
||||
|
||||
def test_codex_transport_xai_request_omits_encrypted_content_include():
|
||||
"""Verify the xAI ``include`` array no longer requests encrypted reasoning."""
|
||||
from agent.transports.codex import ResponsesApiTransport
|
||||
|
||||
transport = ResponsesApiTransport()
|
||||
kwargs = transport.build_kwargs(
|
||||
model="grok-4.3",
|
||||
messages=[
|
||||
{"role": "system", "content": "you are a helpful assistant"},
|
||||
{"role": "user", "content": "hi"},
|
||||
],
|
||||
tools=None,
|
||||
instructions="you are a helpful assistant",
|
||||
reasoning_config={"enabled": True, "effort": "medium"},
|
||||
is_xai_responses=True,
|
||||
)
|
||||
# Without this gate, xAI would echo back encrypted_content blobs we'd
|
||||
# then store in codex_reasoning_items and replay next turn — which is
|
||||
# exactly the multi-turn failure mode we're closing.
|
||||
assert kwargs["include"] == []
|
||||
|
||||
|
||||
def test_codex_transport_xai_strips_replayed_reasoning_in_input():
|
||||
"""End-to-end: build_kwargs on xai-oauth must strip prior reasoning."""
|
||||
from agent.transports.codex import ResponsesApiTransport
|
||||
|
||||
transport = ResponsesApiTransport()
|
||||
kwargs = transport.build_kwargs(
|
||||
model="grok-4.3",
|
||||
messages=[
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "hi"},
|
||||
_assistant_msg_with_encrypted_reasoning(text="hi from grok"),
|
||||
{"role": "user", "content": "what's your name?"},
|
||||
],
|
||||
tools=None,
|
||||
instructions="sys",
|
||||
reasoning_config={"enabled": True, "effort": "medium"},
|
||||
is_xai_responses=True,
|
||||
)
|
||||
input_items = kwargs["input"]
|
||||
reasoning_items = [it for it in input_items if it.get("type") == "reasoning"]
|
||||
assert reasoning_items == []
|
||||
|
||||
|
||||
def test_codex_transport_native_codex_still_replays_reasoning_in_input():
|
||||
"""Regression guard: openai-codex must keep the existing replay path."""
|
||||
from agent.transports.codex import ResponsesApiTransport
|
||||
|
||||
transport = ResponsesApiTransport()
|
||||
kwargs = transport.build_kwargs(
|
||||
model="gpt-5-codex",
|
||||
messages=[
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "hi"},
|
||||
_assistant_msg_with_encrypted_reasoning(text="hi from codex"),
|
||||
{"role": "user", "content": "next"},
|
||||
],
|
||||
tools=None,
|
||||
instructions="sys",
|
||||
reasoning_config={"enabled": True, "effort": "medium"},
|
||||
is_xai_responses=False,
|
||||
)
|
||||
input_items = kwargs["input"]
|
||||
reasoning_items = [it for it in input_items if it.get("type") == "reasoning"]
|
||||
assert len(reasoning_items) == 1
|
||||
assert reasoning_items[0]["encrypted_content"] == "enc_blob"
|
||||
# Native Codex still asks for encrypted_content back.
|
||||
assert "reasoning.encrypted_content" in kwargs.get("include", [])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix D: entitlement 403 must NOT trigger credential-pool refresh loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message",
|
||||
[
|
||||
# The exact wire text RaidenTyler and Don Piedro captured.
|
||||
"You have either run out of available resources or do not have an "
|
||||
"active Grok subscription. Manage at https://grok.com",
|
||||
# Permission-style variant from the same 403 body.
|
||||
"The caller does not have permission to execute the specified "
|
||||
"operation for grok-4.3",
|
||||
],
|
||||
)
|
||||
def test_is_entitlement_failure_matches_real_xai_bodies(message):
|
||||
from run_agent import AIAgent
|
||||
|
||||
assert AIAgent._is_entitlement_failure(
|
||||
{"message": message, "reason": "permission_denied"},
|
||||
403,
|
||||
)
|
||||
|
||||
|
||||
def test_is_entitlement_failure_false_for_status_other_than_401_403():
|
||||
"""200/429/500 must never be classified as entitlement, even if body matches."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
body = {
|
||||
"message": "do not have an active Grok subscription",
|
||||
}
|
||||
assert not AIAgent._is_entitlement_failure(body, 500)
|
||||
assert not AIAgent._is_entitlement_failure(body, 429)
|
||||
assert not AIAgent._is_entitlement_failure(body, 200)
|
||||
|
||||
|
||||
def test_is_entitlement_failure_false_for_unrelated_auth_errors():
|
||||
"""A real auth failure (expired token, wrong key) must keep refreshing."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
# Generic Anthropic-style auth failure
|
||||
assert not AIAgent._is_entitlement_failure(
|
||||
{"message": "Invalid API key", "reason": "authentication_error"},
|
||||
401,
|
||||
)
|
||||
# OAuth token expired
|
||||
assert not AIAgent._is_entitlement_failure(
|
||||
{"message": "Token has expired", "reason": "unauthorized"},
|
||||
401,
|
||||
)
|
||||
# Empty context
|
||||
assert not AIAgent._is_entitlement_failure({}, 401)
|
||||
assert not AIAgent._is_entitlement_failure(None, 401)
|
||||
|
||||
|
||||
def test_recover_with_credential_pool_skips_refresh_on_entitlement_403():
|
||||
"""The recovery path must NOT call pool.try_refresh_current() on entitlement 403.
|
||||
|
||||
Before the fix, an unsubscribed xAI OAuth account would burn the agent
|
||||
loop indefinitely: refresh → 403 → refresh → 403, infinitely. With
|
||||
the entitlement guard, recovery returns False so the error surfaces
|
||||
normally with the friendly hint from _summarize_api_error.
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
from agent.error_classifier import FailoverReason
|
||||
|
||||
agent = _make_codex_agent()
|
||||
|
||||
# Wire a fake credential pool that records refresh attempts.
|
||||
refresh_calls = {"n": 0}
|
||||
|
||||
class _FakePool:
|
||||
def try_refresh_current(self):
|
||||
refresh_calls["n"] += 1
|
||||
return MagicMock(id="should_not_be_called")
|
||||
|
||||
def mark_exhausted_and_rotate(self, **_kwargs):
|
||||
return None
|
||||
|
||||
def has_available(self):
|
||||
return False
|
||||
|
||||
agent._credential_pool = _FakePool()
|
||||
|
||||
error_context = {
|
||||
"reason": "The caller does not have permission to execute the specified operation",
|
||||
"message": "You have either run out of available resources or do not have an "
|
||||
"active Grok subscription. Manage at https://grok.com",
|
||||
}
|
||||
|
||||
recovered, _retried_429 = agent._recover_with_credential_pool(
|
||||
status_code=403,
|
||||
has_retried_429=False,
|
||||
classified_reason=FailoverReason.auth,
|
||||
error_context=error_context,
|
||||
)
|
||||
|
||||
assert recovered is False, "Entitlement 403 must surface, not silently recover"
|
||||
assert refresh_calls["n"] == 0, "try_refresh_current must NOT be called on entitlement 403"
|
||||
|
||||
|
||||
def test_recover_with_credential_pool_still_refreshes_genuine_auth_failure():
|
||||
"""Regression guard: legitimate auth errors must still trigger refresh."""
|
||||
from run_agent import AIAgent
|
||||
from agent.error_classifier import FailoverReason
|
||||
|
||||
agent = _make_codex_agent()
|
||||
|
||||
refresh_calls = {"n": 0}
|
||||
|
||||
class _FakePool:
|
||||
def try_refresh_current(self):
|
||||
refresh_calls["n"] += 1
|
||||
# Return a fake refreshed entry — semantically "refresh worked"
|
||||
entry = MagicMock()
|
||||
entry.id = "entry_refreshed"
|
||||
return entry
|
||||
|
||||
def mark_exhausted_and_rotate(self, **_kwargs):
|
||||
return None
|
||||
|
||||
def has_available(self):
|
||||
return False
|
||||
|
||||
agent._credential_pool = _FakePool()
|
||||
# _swap_credential is called by the recovery path — stub it out
|
||||
agent._swap_credential = MagicMock()
|
||||
|
||||
error_context = {
|
||||
"reason": "authentication_error",
|
||||
"message": "Invalid API key",
|
||||
}
|
||||
|
||||
recovered, _retried_429 = agent._recover_with_credential_pool(
|
||||
status_code=401,
|
||||
has_retried_429=False,
|
||||
classified_reason=FailoverReason.auth,
|
||||
error_context=error_context,
|
||||
)
|
||||
|
||||
assert recovered is True, "Genuine auth failure must still recover via refresh"
|
||||
assert refresh_calls["n"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix E: grok-4.3 context length must be 1M, not 256K
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_grok_4_3_context_length_is_1m():
|
||||
"""grok-4.3 ships with 1M context per docs.x.ai/developers/models/grok-4.3.
|
||||
|
||||
Hermes' substring-match fallback used to return 256k (from the
|
||||
"grok-4" catch-all) which under-reported the model's real capacity.
|
||||
"""
|
||||
from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS
|
||||
|
||||
# The entry exists with the expected value.
|
||||
assert DEFAULT_CONTEXT_LENGTHS["grok-4.3"] == 1_000_000
|
||||
|
||||
# And longest-first substring matching resolves grok-4.3 and
|
||||
# grok-4.3-latest to the new value, NOT the grok-4 catch-all.
|
||||
for slug in ("grok-4.3", "grok-4.3-latest"):
|
||||
matched_key = max(
|
||||
(k for k in DEFAULT_CONTEXT_LENGTHS if k in slug.lower()),
|
||||
key=len,
|
||||
)
|
||||
assert matched_key == "grok-4.3", (
|
||||
f"Expected longest-first match to land on grok-4.3 for {slug}, "
|
||||
f"got {matched_key}"
|
||||
)
|
||||
assert DEFAULT_CONTEXT_LENGTHS[matched_key] == 1_000_000
|
||||
|
||||
|
||||
def test_grok_4_still_resolves_to_256k():
|
||||
"""Regression guard: grok-4 (non-.3) must still resolve to 256k."""
|
||||
from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS
|
||||
|
||||
for slug in ("grok-4", "grok-4-0709"):
|
||||
matched_key = max(
|
||||
(k for k in DEFAULT_CONTEXT_LENGTHS if k in slug.lower()),
|
||||
key=len,
|
||||
)
|
||||
# grok-4-0709 contains "grok-4" but not "grok-4.3"; matched key
|
||||
# must be "grok-4" (or a more specific variant family if one is
|
||||
# ever added). The 256k contract must hold.
|
||||
assert DEFAULT_CONTEXT_LENGTHS[matched_key] == 256_000
|
||||
|
|
@ -123,6 +123,26 @@ class TestRestorePrimaryRuntime:
|
|||
assert agent._fallback_activated is False
|
||||
assert agent._restore_primary_runtime() is False
|
||||
|
||||
def test_resets_index_when_fallback_not_activated(self):
|
||||
"""Regression for #20465: failed activation leaves _fallback_index advanced
|
||||
with _fallback_activated=False; the next turn's restore must reset the index."""
|
||||
fbs = [{"provider": "custom", "model": "gpt-oss:20b",
|
||||
"base_url": "http://host.docker.internal:11434/v1", "api_key": "ollama"}]
|
||||
agent = _make_agent(fallback_model=fbs)
|
||||
|
||||
# resolve_provider_client returns None → _try_activate_fallback returns False
|
||||
# but _fallback_index has already been incremented to 1
|
||||
with patch("agent.auxiliary_client.resolve_provider_client", return_value=(None, None)):
|
||||
assert agent._try_activate_fallback() is False
|
||||
|
||||
assert agent._fallback_activated is False
|
||||
assert agent._fallback_index == 1 # advanced past the only entry
|
||||
|
||||
# _restore_primary_runtime must reset the index so the next turn can retry
|
||||
result = agent._restore_primary_runtime()
|
||||
assert result is False # still no-op (primary was never left)
|
||||
assert agent._fallback_index == 0 # chain available again
|
||||
|
||||
def test_restores_model_and_provider(self):
|
||||
agent = _make_agent(
|
||||
fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
Mirrors the OpenRouter pattern for the Vercel AI Gateway so that
|
||||
referrerUrl / appName / User-Agent flow into gateway analytics.
|
||||
"""
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
|
@ -65,6 +66,73 @@ def test_routermint_base_url_applies_user_agent_header(mock_openai):
|
|||
assert headers["User-Agent"].startswith("HermesAgent/")
|
||||
|
||||
|
||||
@patch("run_agent.OpenAI")
|
||||
def test_nvidia_cloud_base_url_applies_billing_origin_header(mock_openai):
|
||||
mock_openai.return_value = MagicMock()
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://integrate.api.nvidia.com/v1",
|
||||
model="nvidia/test-model",
|
||||
provider="nvidia",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
|
||||
assert agent._client_kwargs["default_headers"]["X-BILLING-INVOKE-ORIGIN"] == "HermesAgent"
|
||||
|
||||
agent._apply_client_headers_for_base_url("https://integrate.api.nvidia.com/v1")
|
||||
|
||||
headers = agent._client_kwargs["default_headers"]
|
||||
assert headers["X-BILLING-INVOKE-ORIGIN"] == "HermesAgent"
|
||||
|
||||
|
||||
@patch("run_agent.OpenAI")
|
||||
def test_nvidia_local_base_url_does_not_apply_billing_origin_header(mock_openai):
|
||||
mock_openai.return_value = MagicMock()
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://integrate.api.nvidia.com/v1",
|
||||
model="nvidia/test-model",
|
||||
provider="nvidia",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent._client_kwargs["default_headers"] = {
|
||||
"X-BILLING-INVOKE-ORIGIN": "HermesAgent",
|
||||
}
|
||||
|
||||
agent._apply_client_headers_for_base_url("http://localhost:8000/v1")
|
||||
|
||||
assert "default_headers" not in agent._client_kwargs
|
||||
|
||||
|
||||
@patch("run_agent.OpenAI")
|
||||
def test_routed_client_preserves_openai_sdk_custom_headers(mock_openai):
|
||||
mock_openai.return_value = MagicMock()
|
||||
routed_client = SimpleNamespace(
|
||||
api_key="test-key",
|
||||
base_url="https://integrate.api.nvidia.com/v1",
|
||||
_custom_headers={"X-BILLING-INVOKE-ORIGIN": "HermesAgent"},
|
||||
)
|
||||
|
||||
with patch("agent.auxiliary_client.resolve_provider_client", return_value=(
|
||||
routed_client,
|
||||
"nvidia/test-model",
|
||||
)):
|
||||
agent = AIAgent(
|
||||
provider="nvidia",
|
||||
model="nvidia/test-model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
|
||||
headers = agent._client_kwargs["default_headers"]
|
||||
assert headers["X-BILLING-INVOKE-ORIGIN"] == "HermesAgent"
|
||||
|
||||
|
||||
@patch("run_agent.OpenAI")
|
||||
def test_gmi_base_url_picks_up_profile_user_agent(mock_openai):
|
||||
"""GMI declares User-Agent on its ProviderProfile.default_headers.
|
||||
|
|
|
|||
|
|
@ -61,6 +61,8 @@ def _make_agent(monkeypatch, provider, api_mode="chat_completions", base_url="ht
|
|||
)
|
||||
if model:
|
||||
kwargs["model"] = model
|
||||
elif provider == "nous":
|
||||
kwargs["model"] = "gpt-5"
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
|
|
|
|||
|
|
@ -2269,6 +2269,60 @@ class TestParallelScopePathNormalization:
|
|||
assert not _should_parallelize_tool_batch([tc1, tc2])
|
||||
|
||||
|
||||
class TestMcpParallelToolBatch:
|
||||
"""Integration test: _should_parallelize_tool_batch respects MCP parallel flag."""
|
||||
|
||||
def test_mcp_tools_default_sequential(self):
|
||||
"""MCP tools without supports_parallel_tool_calls are sequential."""
|
||||
from run_agent import _should_parallelize_tool_batch
|
||||
tc1 = _mock_tool_call(name="mcp_github_list_repos", arguments='{"org":"openai"}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="mcp_github_search_code", arguments='{"q":"test"}', call_id="c2")
|
||||
assert not _should_parallelize_tool_batch([tc1, tc2])
|
||||
|
||||
def test_mcp_tools_parallel_when_server_opted_in(self):
|
||||
"""MCP tools from a parallel-safe server can run concurrently."""
|
||||
from run_agent import _should_parallelize_tool_batch
|
||||
from tools.mcp_tool import _parallel_safe_servers, _lock
|
||||
with _lock:
|
||||
_parallel_safe_servers.add("github")
|
||||
try:
|
||||
tc1 = _mock_tool_call(name="mcp_github_list_repos", arguments='{"org":"openai"}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="mcp_github_search_code", arguments='{"q":"test"}', call_id="c2")
|
||||
assert _should_parallelize_tool_batch([tc1, tc2])
|
||||
finally:
|
||||
with _lock:
|
||||
_parallel_safe_servers.discard("github")
|
||||
|
||||
def test_mixed_mcp_and_builtin_parallel(self):
|
||||
"""MCP parallel tools mixed with built-in parallel-safe tools."""
|
||||
from run_agent import _should_parallelize_tool_batch
|
||||
from tools.mcp_tool import _parallel_safe_servers, _lock
|
||||
with _lock:
|
||||
_parallel_safe_servers.add("docs")
|
||||
try:
|
||||
tc1 = _mock_tool_call(name="mcp_docs_search", arguments='{"query":"api"}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="web_search", arguments='{"query":"test"}', call_id="c2")
|
||||
assert _should_parallelize_tool_batch([tc1, tc2])
|
||||
finally:
|
||||
with _lock:
|
||||
_parallel_safe_servers.discard("docs")
|
||||
|
||||
def test_mixed_parallel_and_serial_mcp_servers(self):
|
||||
"""One parallel MCP server + one non-parallel MCP server = sequential."""
|
||||
from run_agent import _should_parallelize_tool_batch
|
||||
from tools.mcp_tool import _parallel_safe_servers, _lock
|
||||
with _lock:
|
||||
_parallel_safe_servers.add("docs")
|
||||
# "github" is NOT in _parallel_safe_servers
|
||||
try:
|
||||
tc1 = _mock_tool_call(name="mcp_docs_search", arguments='{"query":"api"}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="mcp_github_list_repos", arguments='{"org":"openai"}', call_id="c2")
|
||||
assert not _should_parallelize_tool_batch([tc1, tc2])
|
||||
finally:
|
||||
with _lock:
|
||||
_parallel_safe_servers.discard("docs")
|
||||
|
||||
|
||||
class TestHandleMaxIterations:
|
||||
def test_returns_summary(self, agent):
|
||||
resp = _mock_response(content="Here is a summary of what I did.")
|
||||
|
|
@ -2524,8 +2578,9 @@ class TestRunConversation:
|
|||
assert [call["api_call_count"] for call in pre_request_calls] == [1, 2]
|
||||
assert [call["api_call_count"] for call in post_request_calls] == [1, 2]
|
||||
assert all(call["session_id"] == agent.session_id for call in pre_request_calls)
|
||||
assert all("message_count" in c and "messages" not in c for c in pre_request_calls)
|
||||
assert all("usage" in c and "response" not in c for c in post_request_calls)
|
||||
assert all("message_count" in c and isinstance(c.get("request_messages"), list) for c in pre_request_calls)
|
||||
assert any(msg.get("role") == "user" and msg.get("content") == "search something" for msg in pre_request_calls[0]["request_messages"])
|
||||
assert all("usage" in c and "response" in c and "assistant_message" in c for c in post_request_calls)
|
||||
|
||||
def test_content_with_tool_calls_stays_silent_for_non_cli_quiet_mode(self, agent):
|
||||
self._setup_agent(agent)
|
||||
|
|
@ -3691,6 +3746,37 @@ class TestCredentialPoolRecovery:
|
|||
assert retry_same is False
|
||||
agent._swap_credential.assert_called_once_with(next_entry)
|
||||
|
||||
def test_recover_with_pool_rotates_usage_limit_429_immediately(self, agent):
|
||||
next_entry = SimpleNamespace(label="secondary")
|
||||
captured = {}
|
||||
|
||||
class _Pool:
|
||||
def current(self):
|
||||
return SimpleNamespace(label="primary")
|
||||
|
||||
def mark_exhausted_and_rotate(self, *, status_code, error_context=None):
|
||||
captured["status_code"] = status_code
|
||||
captured["error_context"] = error_context
|
||||
return next_entry
|
||||
|
||||
agent._credential_pool = _Pool()
|
||||
agent._swap_credential = MagicMock()
|
||||
|
||||
recovered, retry_same = agent._recover_with_credential_pool(
|
||||
status_code=429,
|
||||
has_retried_429=False,
|
||||
error_context={
|
||||
"reason": "usage_limit_reached",
|
||||
"message": "The usage limit has been reached",
|
||||
},
|
||||
)
|
||||
|
||||
assert recovered is True
|
||||
assert retry_same is False
|
||||
assert captured["status_code"] == 429
|
||||
assert captured["error_context"]["reason"] == "usage_limit_reached"
|
||||
agent._swap_credential.assert_called_once_with(next_entry)
|
||||
|
||||
|
||||
def test_recover_with_pool_refreshes_on_401(self, agent):
|
||||
"""401 with successful refresh should swap to refreshed credential."""
|
||||
|
|
@ -3777,6 +3863,22 @@ class TestCredentialPoolRecovery:
|
|||
assert context["message"] == "Weekly credits exhausted."
|
||||
assert context["reset_at"] == "2026-04-12T10:30:00Z"
|
||||
|
||||
def test_extract_api_error_context_uses_type_as_reason(self, agent):
|
||||
error = SimpleNamespace(
|
||||
body={
|
||||
"error": {
|
||||
"type": "usage_limit_reached",
|
||||
"message": "The usage limit has been reached",
|
||||
}
|
||||
},
|
||||
response=SimpleNamespace(headers={}),
|
||||
)
|
||||
|
||||
context = agent._extract_api_error_context(error)
|
||||
|
||||
assert context["reason"] == "usage_limit_reached"
|
||||
assert context["message"] == "The usage limit has been reached"
|
||||
|
||||
def test_recover_with_pool_passes_error_context_on_rotated_429(self, agent):
|
||||
next_entry = SimpleNamespace(label="secondary")
|
||||
captured = {}
|
||||
|
|
|
|||
|
|
@ -578,6 +578,197 @@ def test_run_conversation_codex_refreshes_after_401_and_retries(monkeypatch):
|
|||
assert result["final_response"] == "Recovered after refresh"
|
||||
|
||||
|
||||
def _build_xai_oauth_agent(monkeypatch):
|
||||
_patch_agent_bootstrap(monkeypatch)
|
||||
agent = run_agent.AIAgent(
|
||||
model="grok-4.3",
|
||||
provider="xai-oauth",
|
||||
api_mode="codex_responses",
|
||||
base_url="https://api.x.ai/v1",
|
||||
api_key="xai-oauth-token",
|
||||
quiet_mode=True,
|
||||
max_iterations=4,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent._cleanup_task_resources = lambda task_id: None
|
||||
agent._persist_session = lambda messages, history=None: None
|
||||
agent._save_trajectory = lambda messages, user_message, completed: None
|
||||
agent._save_session_log = lambda messages: None
|
||||
return agent
|
||||
|
||||
|
||||
def test_build_api_kwargs_xai_oauth_sends_cache_key_via_extra_body(monkeypatch):
|
||||
"""xai-oauth + codex_responses must route prompt caching via the
|
||||
``prompt_cache_key`` body field on /v1/responses (xAI's documented
|
||||
Responses-API cache key — see docs.x.ai prompt-caching/maximizing-
|
||||
cache-hits).
|
||||
|
||||
We pass it through ``extra_body`` rather than as a top-level kwarg so
|
||||
the body field is serialized into JSON regardless of whether the
|
||||
installed openai SDK build still accepts ``prompt_cache_key`` on
|
||||
``Responses.stream()``. Older or trimmed SDK builds drop it from the
|
||||
signature and would otherwise raise ``TypeError`` before the request
|
||||
reaches api.x.ai. The ``x-grok-conv-id`` header is retained as a
|
||||
belt-and-braces fallback for clients/proxies that route on headers."""
|
||||
agent = _build_xai_oauth_agent(monkeypatch)
|
||||
kwargs = agent._build_api_kwargs(
|
||||
[
|
||||
{"role": "system", "content": "You are Hermes."},
|
||||
{"role": "user", "content": "Ping"},
|
||||
]
|
||||
)
|
||||
|
||||
assert kwargs.get("model") == "grok-4.3"
|
||||
# Top-level kwarg must NOT be set — that's the openai SDK
|
||||
# incompatibility this whole indirection exists to dodge.
|
||||
assert "prompt_cache_key" not in kwargs
|
||||
extra_body = kwargs.get("extra_body") or {}
|
||||
assert extra_body.get("prompt_cache_key"), (
|
||||
"xAI prompt-cache routing must travel via extra_body.prompt_cache_key "
|
||||
"for /v1/responses — body field is the documented surface."
|
||||
)
|
||||
headers = kwargs.get("extra_headers") or {}
|
||||
assert "x-grok-conv-id" in headers, (
|
||||
"x-grok-conv-id header kept as belt-and-braces fallback for clients "
|
||||
"that route on headers."
|
||||
)
|
||||
|
||||
|
||||
def test_run_conversation_xai_oauth_refreshes_after_401_and_retries(monkeypatch):
|
||||
"""xai-oauth speaks the Responses API just like codex. When the access
|
||||
token is rejected mid-call (401), the same proactive refresh-and-retry
|
||||
handler that fires for openai-codex must also fire for xai-oauth — the
|
||||
bug it caught: the gating condition checked only ``provider == "openai-codex"``,
|
||||
so xai-oauth 401s leaked straight to non-retryable abort path with no
|
||||
chance to swap in a freshly refreshed access token."""
|
||||
agent = _build_xai_oauth_agent(monkeypatch)
|
||||
calls = {"api": 0, "refresh": 0}
|
||||
|
||||
class _UnauthorizedError(RuntimeError):
|
||||
def __init__(self):
|
||||
super().__init__("Error code: 401 - unauthorized")
|
||||
self.status_code = 401
|
||||
|
||||
def _fake_api_call(api_kwargs):
|
||||
calls["api"] += 1
|
||||
if calls["api"] == 1:
|
||||
raise _UnauthorizedError()
|
||||
return _codex_message_response("Recovered after xAI refresh")
|
||||
|
||||
def _fake_refresh(*, force=True):
|
||||
calls["refresh"] += 1
|
||||
assert force is True
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(agent, "_interruptible_api_call", _fake_api_call)
|
||||
monkeypatch.setattr(agent, "_try_refresh_codex_client_credentials", _fake_refresh)
|
||||
|
||||
result = agent.run_conversation("Say OK")
|
||||
|
||||
assert calls["api"] == 2
|
||||
assert calls["refresh"] == 1
|
||||
assert result["completed"] is True
|
||||
assert result["final_response"] == "Recovered after xAI refresh"
|
||||
|
||||
|
||||
def test_try_refresh_codex_client_credentials_handles_xai_oauth(monkeypatch):
|
||||
"""``_try_refresh_codex_client_credentials`` must rebuild the OpenAI
|
||||
client with freshly resolved xAI OAuth credentials when the active
|
||||
provider is xai-oauth. The function name is shared between codex and
|
||||
xai-oauth (both speak codex_responses) — covering both cases prevents
|
||||
silent regressions where the function gets gated to a single provider."""
|
||||
agent = _build_xai_oauth_agent(monkeypatch)
|
||||
closed = {"value": False}
|
||||
rebuilt = {"kwargs": None}
|
||||
|
||||
class _ExistingClient:
|
||||
def close(self):
|
||||
closed["value"] = True
|
||||
|
||||
class _RebuiltClient:
|
||||
pass
|
||||
|
||||
def _fake_openai(**kwargs):
|
||||
rebuilt["kwargs"] = kwargs
|
||||
return _RebuiltClient()
|
||||
|
||||
def _fake_resolve(force_refresh=False, refresh_if_expiring=True, **_):
|
||||
# The pre-refresh guard reads the singleton with refresh_if_expiring=False
|
||||
# to verify that the agent's active key still matches; the actual
|
||||
# refresh later passes force_refresh=True. Both calls must succeed.
|
||||
return {
|
||||
"api_key": "fresh-xai-token" if force_refresh else agent.api_key,
|
||||
"base_url": "https://api.x.ai/v1",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.resolve_xai_oauth_runtime_credentials",
|
||||
_fake_resolve,
|
||||
)
|
||||
monkeypatch.setattr(run_agent, "OpenAI", _fake_openai)
|
||||
|
||||
agent.client = _ExistingClient()
|
||||
ok = agent._try_refresh_codex_client_credentials(force=True)
|
||||
|
||||
assert ok is True
|
||||
assert closed["value"] is True
|
||||
assert rebuilt["kwargs"]["api_key"] == "fresh-xai-token"
|
||||
assert rebuilt["kwargs"]["base_url"] == "https://api.x.ai/v1"
|
||||
assert isinstance(agent.client, _RebuiltClient)
|
||||
assert agent.api_key == "fresh-xai-token"
|
||||
|
||||
|
||||
def test_try_refresh_codex_client_credentials_skips_xai_oauth_when_singleton_differs(monkeypatch):
|
||||
"""An xai-oauth agent constructed with a non-singleton credential
|
||||
(e.g. a manual pool entry whose tokens belong to a different account
|
||||
than the loopback_pkce singleton, or an explicit ``api_key=`` arg)
|
||||
MUST NOT silently adopt the singleton's tokens on a 401 reactive
|
||||
refresh. Otherwise a 401 mid-conversation would re-route the rest
|
||||
of the conversation onto a different account, with no user feedback.
|
||||
|
||||
The credential pool's reactive recovery is the right channel for
|
||||
pool-managed credentials; this fallback path is for the singleton-
|
||||
only case and must short-circuit when the active key differs."""
|
||||
agent = _build_xai_oauth_agent(monkeypatch)
|
||||
# Agent is using "xai-oauth-token" (per the builder); singleton holds
|
||||
# a *different* account's token. No force_refresh should fire.
|
||||
refresh_calls = {"count": 0}
|
||||
|
||||
def _fake_resolve(force_refresh=False, refresh_if_expiring=True, **_):
|
||||
if force_refresh:
|
||||
refresh_calls["count"] += 1
|
||||
return {
|
||||
"api_key": "singleton-account-token",
|
||||
"base_url": "https://api.x.ai/v1",
|
||||
}
|
||||
# The pre-refresh guard read — return the singleton's view of the
|
||||
# singleton's token, which is NOT what the agent is currently using.
|
||||
return {
|
||||
"api_key": "singleton-account-token",
|
||||
"base_url": "https://api.x.ai/v1",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.resolve_xai_oauth_runtime_credentials",
|
||||
_fake_resolve,
|
||||
)
|
||||
|
||||
pre_refresh_key = agent.api_key
|
||||
ok = agent._try_refresh_codex_client_credentials(force=True)
|
||||
|
||||
assert ok is False, (
|
||||
"must not refresh when the active credential isn't the singleton; "
|
||||
"otherwise the conversation silently swaps accounts mid-flight."
|
||||
)
|
||||
assert refresh_calls["count"] == 0, (
|
||||
"force_refresh must not run — that would mutate the singleton's "
|
||||
"tokens on disk and consume its single-use refresh_token for an "
|
||||
"agent that wasn't even using the singleton."
|
||||
)
|
||||
assert agent.api_key == pre_refresh_key
|
||||
|
||||
|
||||
def test_run_conversation_copilot_refreshes_after_401_and_retries(monkeypatch):
|
||||
agent = _build_copilot_agent(monkeypatch)
|
||||
calls = {"api": 0, "refresh": 0}
|
||||
|
|
@ -624,12 +815,18 @@ def test_try_refresh_codex_client_credentials_rebuilds_client(monkeypatch):
|
|||
rebuilt["kwargs"] = kwargs
|
||||
return _RebuiltClient()
|
||||
|
||||
def _fake_resolve(force_refresh=False, refresh_if_expiring=True, **_):
|
||||
# Pre-refresh guard reads the singleton (refresh_if_expiring=False).
|
||||
# It must report the agent's current api_key so the equality check
|
||||
# passes; only then does the actual force_refresh run.
|
||||
return {
|
||||
"api_key": "new-codex-token" if force_refresh else agent.api_key,
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.resolve_codex_runtime_credentials",
|
||||
lambda force_refresh=True: {
|
||||
"api_key": "new-codex-token",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
},
|
||||
_fake_resolve,
|
||||
)
|
||||
monkeypatch.setattr(run_agent, "OpenAI", _fake_openai)
|
||||
|
||||
|
|
|
|||
|
|
@ -999,6 +999,88 @@ class TestAnthropicStreamCallbacks:
|
|||
|
||||
assert touch_calls.count("receiving stream response") == len(events)
|
||||
|
||||
@patch("run_agent.AIAgent._replace_primary_openai_client")
|
||||
def test_anthropic_stream_parser_valueerror_retries_before_delivery(
|
||||
self, mock_replace, monkeypatch,
|
||||
):
|
||||
"""Malformed Anthropic event-stream frames retry instead of surfacing HTTP None."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://api.minimax.io/anthropic",
|
||||
provider="minimax",
|
||||
model="MiniMax-M2.7",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "anthropic_messages"
|
||||
agent._interrupt_requested = False
|
||||
monkeypatch.setenv("HERMES_STREAM_RETRIES", "1")
|
||||
|
||||
class _BadStream:
|
||||
response = None
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_args):
|
||||
return False
|
||||
|
||||
def __iter__(self):
|
||||
raise ValueError("expected ident at line 1 column 149")
|
||||
|
||||
final_message = SimpleNamespace(content=[], stop_reason="end_turn")
|
||||
good_stream = MagicMock()
|
||||
good_stream.__enter__ = MagicMock(return_value=good_stream)
|
||||
good_stream.__exit__ = MagicMock(return_value=False)
|
||||
good_stream.__iter__ = MagicMock(return_value=iter([]))
|
||||
good_stream.get_final_message.return_value = final_message
|
||||
|
||||
agent._anthropic_client = MagicMock()
|
||||
agent._anthropic_client.messages.stream.side_effect = [
|
||||
_BadStream(),
|
||||
good_stream,
|
||||
]
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert response is final_message
|
||||
assert agent._anthropic_client.messages.stream.call_count == 2
|
||||
assert mock_replace.call_count == 1
|
||||
|
||||
@patch("run_agent.AIAgent._replace_primary_openai_client")
|
||||
def test_generic_anthropic_valueerror_still_propagates_without_stream_retry(
|
||||
self, mock_replace, monkeypatch,
|
||||
):
|
||||
"""Only known provider stream parser ValueErrors are treated as transient."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://api.minimax.io/anthropic",
|
||||
provider="minimax",
|
||||
model="MiniMax-M2.7",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "anthropic_messages"
|
||||
agent._interrupt_requested = False
|
||||
monkeypatch.setenv("HERMES_STREAM_RETRIES", "1")
|
||||
|
||||
agent._anthropic_client = MagicMock()
|
||||
agent._anthropic_client.messages.stream.side_effect = ValueError(
|
||||
"invalid local request shape"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="invalid local request shape"):
|
||||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert agent._anthropic_client.messages.stream.call_count == 1
|
||||
assert mock_replace.call_count == 0
|
||||
|
||||
|
||||
class TestPartialToolCallWarning:
|
||||
"""Regression: when a stream dies mid tool-call argument generation after
|
||||
|
|
@ -1505,3 +1587,144 @@ class TestCopilotACPStreamingDecision:
|
|||
|
||||
assert _use_streaming is True
|
||||
|
||||
|
||||
class TestCodexFallbackErrorEvent:
|
||||
"""Provider ``error`` SSE frames must surface the real message,
|
||||
not the generic "did not emit a terminal response" RuntimeError.
|
||||
|
||||
xAI emits ``type=error`` as the FIRST frame on the Responses stream
|
||||
when an OAuth account is unsubscribed/exhausted (May 2026
|
||||
SuperGrok rollout). The SDK helper raises
|
||||
``RuntimeError("Expected to have received response.created before
|
||||
error")`` which the caller catches and routes to
|
||||
``_run_codex_create_stream_fallback``. The fallback then opens a
|
||||
NEW stream that emits the same ``type=error`` frame; before this
|
||||
fix it ignored the event entirely and raised a useless RuntimeError.
|
||||
"""
|
||||
|
||||
def _make_agent(self):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://api.x.ai/v1",
|
||||
provider="xai-oauth",
|
||||
model="grok-4.3",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "codex_responses"
|
||||
agent._touch_activity = lambda desc: None
|
||||
return agent
|
||||
|
||||
def test_fallback_raises_synthesized_error_with_xai_subscription_message(self):
|
||||
from run_agent import _StreamErrorEvent
|
||||
|
||||
agent = self._make_agent()
|
||||
|
||||
error_event = SimpleNamespace(
|
||||
type="error",
|
||||
message=(
|
||||
"Forbidden: The caller does not have permission to execute the specified operation. "
|
||||
"'You have either run out of available resources or do not have an active Grok subscription.'"
|
||||
),
|
||||
code="permission_denied",
|
||||
param=None,
|
||||
sequence_number=1,
|
||||
)
|
||||
|
||||
class _FakeStream:
|
||||
def __iter__(self_inner):
|
||||
return iter([error_event])
|
||||
def close(self_inner):
|
||||
return None
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.responses.create.return_value = _FakeStream()
|
||||
|
||||
with pytest.raises(_StreamErrorEvent) as excinfo:
|
||||
agent._run_codex_create_stream_fallback(
|
||||
{"model": "grok-4.3", "instructions": "hi", "input": []},
|
||||
client=mock_client,
|
||||
)
|
||||
|
||||
exc = excinfo.value
|
||||
assert "active Grok subscription" in str(exc)
|
||||
assert exc.code == "permission_denied"
|
||||
assert isinstance(exc.body, dict)
|
||||
assert exc.body["error"]["message"] == error_event.message
|
||||
# _extract_api_error_context reads .body["error"]["message"] — make sure
|
||||
# the entitlement detector will find the subscription phrase there.
|
||||
assert "active Grok subscription" in exc.body["error"]["message"]
|
||||
|
||||
def test_fallback_dict_event_payload_is_also_handled(self):
|
||||
"""Some relays deliver events as plain dicts instead of model
|
||||
objects; the dict branch in the loop must surface them too."""
|
||||
from run_agent import _StreamErrorEvent
|
||||
|
||||
agent = self._make_agent()
|
||||
|
||||
error_event = {
|
||||
"type": "error",
|
||||
"message": "rate_limited",
|
||||
"code": "rate_limit_exceeded",
|
||||
}
|
||||
|
||||
class _FakeStream:
|
||||
def __iter__(self_inner):
|
||||
return iter([error_event])
|
||||
def close(self_inner):
|
||||
return None
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.responses.create.return_value = _FakeStream()
|
||||
|
||||
with pytest.raises(_StreamErrorEvent) as excinfo:
|
||||
agent._run_codex_create_stream_fallback(
|
||||
{"model": "grok-4.3", "instructions": "hi", "input": []},
|
||||
client=mock_client,
|
||||
)
|
||||
|
||||
assert "rate_limited" in str(excinfo.value)
|
||||
assert excinfo.value.code == "rate_limit_exceeded"
|
||||
|
||||
def test_fallback_surfaces_message_useful_to_summarizer(self):
|
||||
"""The synthesized exception must be readable by
|
||||
``_summarize_api_error`` so the user-facing log line shows the
|
||||
real provider message instead of a generic class name."""
|
||||
from run_agent import AIAgent, _StreamErrorEvent
|
||||
|
||||
agent = self._make_agent()
|
||||
exc = _StreamErrorEvent(
|
||||
"You have either run out of available resources or do not have an active Grok subscription.",
|
||||
code="permission_denied",
|
||||
)
|
||||
|
||||
summary = AIAgent._summarize_api_error(exc)
|
||||
assert "active Grok subscription" in summary
|
||||
|
||||
def test_fallback_still_raises_terminal_error_when_no_error_event(self):
|
||||
"""Streams that simply end without any terminal event (and no
|
||||
``error`` frame) must continue to raise the original
|
||||
``"did not emit a terminal response"`` RuntimeError so callers
|
||||
can distinguish "stream truncated mid-flight" from "provider
|
||||
rejected the call"."""
|
||||
agent = self._make_agent()
|
||||
|
||||
# Empty stream — no events at all
|
||||
class _FakeStream:
|
||||
def __iter__(self_inner):
|
||||
return iter([])
|
||||
def close(self_inner):
|
||||
return None
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.responses.create.return_value = _FakeStream()
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
agent._run_codex_create_stream_fallback(
|
||||
{"model": "grok-4.3", "instructions": "hi", "input": []},
|
||||
client=mock_client,
|
||||
)
|
||||
|
||||
assert "did not emit a terminal response" in str(excinfo.value)
|
||||
|
|
|
|||
102
tests/skills/test_darwinian_evolver_skill.py
Normal file
102
tests/skills/test_darwinian_evolver_skill.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""
|
||||
Smoke tests for the darwinian-evolver optional skill.
|
||||
|
||||
We can't actually run the evolution loop in CI (it needs network + a paid LLM),
|
||||
so these tests verify:
|
||||
- SKILL.md frontmatter conforms to the hardline format
|
||||
- shipped scripts parse as valid Python
|
||||
- the scripts reference the right env var / module paths
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
SKILL_DIR = Path(__file__).resolve().parents[2] / "optional-skills" / "research" / "darwinian-evolver"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def frontmatter() -> dict:
|
||||
src = (SKILL_DIR / "SKILL.md").read_text()
|
||||
m = re.search(r"^---\n(.*?)\n---", src, re.DOTALL)
|
||||
assert m, "SKILL.md missing YAML frontmatter"
|
||||
return yaml.safe_load(m.group(1))
|
||||
|
||||
|
||||
def test_skill_dir_exists() -> None:
|
||||
assert SKILL_DIR.is_dir(), f"missing skill dir: {SKILL_DIR}"
|
||||
|
||||
|
||||
def test_skill_md_present() -> None:
|
||||
assert (SKILL_DIR / "SKILL.md").is_file()
|
||||
|
||||
|
||||
def test_description_under_60_chars(frontmatter) -> None:
|
||||
desc = frontmatter["description"]
|
||||
assert len(desc) <= 60, f"description is {len(desc)} chars (hardline ≤60): {desc!r}"
|
||||
|
||||
|
||||
def test_name_matches_dir(frontmatter) -> None:
|
||||
assert frontmatter["name"] == "darwinian-evolver"
|
||||
|
||||
|
||||
def test_platforms_excludes_windows(frontmatter) -> None:
|
||||
# Upstream uses func_timeout (POSIX signals) and uv subprocess pipelines; the
|
||||
# skill is gated [linux, macos]. If we ever port to Windows, update this test
|
||||
# to assert ["linux", "macos", "windows"].
|
||||
assert "windows" not in frontmatter["platforms"]
|
||||
assert set(frontmatter["platforms"]) >= {"linux", "macos"}
|
||||
|
||||
|
||||
def test_author_credits_contributor(frontmatter) -> None:
|
||||
author = frontmatter["author"]
|
||||
assert "Bihruze" in author, f"author should credit the original contributor: {author!r}"
|
||||
|
||||
|
||||
def test_license_mit(frontmatter) -> None:
|
||||
assert frontmatter["license"] == "MIT"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"scripts/parrot_openrouter.py",
|
||||
"scripts/show_snapshot.py",
|
||||
"templates/custom_problem_template.py",
|
||||
],
|
||||
)
|
||||
def test_shipped_scripts_parse(path: str) -> None:
|
||||
src = (SKILL_DIR / path).read_text()
|
||||
ast.parse(src) # raises SyntaxError on broken Python
|
||||
|
||||
|
||||
def test_parrot_script_uses_openrouter() -> None:
|
||||
src = (SKILL_DIR / "scripts" / "parrot_openrouter.py").read_text()
|
||||
assert "OPENROUTER_API_KEY" in src, "parrot driver should read OPENROUTER_API_KEY"
|
||||
assert "openrouter.ai/api/v1" in src, "parrot driver should target OpenRouter"
|
||||
assert "EVOLVER_MODEL" in src, "model should be overridable via EVOLVER_MODEL"
|
||||
|
||||
|
||||
def test_parrot_script_has_error_swallowing() -> None:
|
||||
"""Provider content-filter / rate-limit must not kill the run — see Pitfall 2."""
|
||||
src = (SKILL_DIR / "scripts" / "parrot_openrouter.py").read_text()
|
||||
assert "LLM_ERROR" in src, "_prompt_llm should swallow provider errors and tag them"
|
||||
|
||||
|
||||
def test_skill_calls_out_agpl(frontmatter) -> None:
|
||||
"""The upstream tool is AGPL-3.0. The skill MUST flag this so users don't
|
||||
import it into MIT-licensed code by accident."""
|
||||
src = (SKILL_DIR / "SKILL.md").read_text()
|
||||
assert "AGPL" in src, "SKILL.md must mention upstream AGPL license"
|
||||
|
||||
|
||||
def test_skill_pitfalls_section_present() -> None:
|
||||
src = (SKILL_DIR / "SKILL.md").read_text()
|
||||
assert "## Pitfalls" in src
|
||||
# Pitfalls we discovered during the spike — keep them in sync with reality.
|
||||
assert "Initial organism must be viable" in src
|
||||
assert "generator" in src # loop.run() pitfall
|
||||
85
tests/test_package_json_lazy_deps.py
Normal file
85
tests/test_package_json_lazy_deps.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
"""Invariants for what is eager vs lazy in the root ``package.json``.
|
||||
|
||||
The root ``package.json`` is installed by ``hermes update`` on every user,
|
||||
including users who never opted into a given browser backend. Anything
|
||||
listed in ``dependencies`` therefore runs its npm postinstall script for
|
||||
everyone — including binary-fetching backends, on every update.
|
||||
|
||||
The contract:
|
||||
|
||||
* ``agent-browser`` IS eager. It is the default Chromium-driving backend
|
||||
used whenever the agent makes a browser call without a cloud provider
|
||||
configured, so it must already be installed before any session starts.
|
||||
Its postinstall is also small.
|
||||
|
||||
* ``@askjo/camofox-browser`` is NOT eager. It is an explicit opt-in
|
||||
alternative browser backend, selected by the user via
|
||||
``hermes tools`` → Browser Automation → Camofox, and only used at
|
||||
runtime when ``CAMOFOX_URL`` is set. Its postinstall fetches a ~300MB
|
||||
Firefox-fork binary, which silently blocked ``hermes update`` for
|
||||
multi-minute stretches on slow / network-restricted connections
|
||||
(notably users in China running through a VPN). The package is
|
||||
installed on demand by ``tools_config.py`` ``post_setup_key ==
|
||||
"camofox"`` when the user actually selects Camofox.
|
||||
|
||||
If a future PR re-adds Camofox (or any other binary-postinstall package)
|
||||
to root ``dependencies``, this test fails — read the lazy-install
|
||||
guidance in the ``hermes-agent-dev`` skill before changing the
|
||||
expectations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
def _root_package_json() -> dict:
|
||||
with (REPO_ROOT / "package.json").open("r", encoding="utf-8") as fh:
|
||||
return json.load(fh)
|
||||
|
||||
|
||||
def test_camofox_is_not_in_root_dependencies() -> None:
|
||||
"""Camofox must be opt-in, installed lazily by its post_setup handler."""
|
||||
deps = _root_package_json().get("dependencies", {})
|
||||
assert "@askjo/camofox-browser" not in deps, (
|
||||
"Camofox is a ~300MB binary-postinstall backend that must stay "
|
||||
"out of root package.json dependencies. It belongs in the "
|
||||
"Camofox post_setup handler in hermes_cli/tools_config.py so it "
|
||||
"only installs when the user explicitly selects Camofox via "
|
||||
"`hermes tools` → Browser Automation → Camofox."
|
||||
)
|
||||
|
||||
|
||||
def test_agent_browser_stays_eager() -> None:
|
||||
"""agent-browser is the default backend; it must remain eager."""
|
||||
deps = _root_package_json().get("dependencies", {})
|
||||
assert "agent-browser" in deps, (
|
||||
"agent-browser is the default browser-tool backend used by every "
|
||||
"session that doesn't have a cloud browser provider configured. "
|
||||
"It must stay in root package.json dependencies so it is present "
|
||||
"after `hermes setup` / `hermes update` without an explicit "
|
||||
"post_setup step."
|
||||
)
|
||||
|
||||
|
||||
def test_root_lockfile_has_no_camofox_entries() -> None:
|
||||
"""Regenerated lockfiles should not contain Camofox tree entries."""
|
||||
lock_path = REPO_ROOT / "package-lock.json"
|
||||
if not lock_path.exists():
|
||||
# Some CI matrix shards skip lockfile materialization.
|
||||
return
|
||||
text = lock_path.read_text(encoding="utf-8")
|
||||
assert "@askjo/camofox-browser" not in text, (
|
||||
"package-lock.json still references @askjo/camofox-browser. "
|
||||
"Regenerate the lockfile after removing the dep: "
|
||||
"`rm package-lock.json && npm install --package-lock-only "
|
||||
"--ignore-scripts --no-fund --no-audit`."
|
||||
)
|
||||
assert "camoufox-js" not in text, (
|
||||
"package-lock.json still references camoufox-js (transitive of "
|
||||
"@askjo/camofox-browser). Regenerate the lockfile."
|
||||
)
|
||||
137
tests/test_sanitize_tool_error.py
Normal file
137
tests/test_sanitize_tool_error.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
"""Tests for `_sanitize_tool_error` in model_tools.
|
||||
|
||||
Ported from ironclaw#1639 — defense-in-depth on tool exception strings before
|
||||
they enter the model's `tool` message content. Note that `json.dumps()` in
|
||||
`handle_function_call` already handles quote/backslash escaping at the wire
|
||||
layer; this helper exists to strip structural framing tokens the model
|
||||
itself might react to (XML role tags, CDATA, markdown code fences) and to
|
||||
cap pathological lengths.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from model_tools import _sanitize_tool_error, _TOOL_ERROR_MAX_LEN
|
||||
|
||||
|
||||
class TestRoleTagStripping:
|
||||
def test_strips_tool_call_tags(self):
|
||||
out = _sanitize_tool_error("bad <tool_call>injected</tool_call> happened")
|
||||
assert "<tool_call>" not in out
|
||||
assert "</tool_call>" not in out
|
||||
assert "bad injected happened" in out
|
||||
|
||||
def test_strips_function_call_tags(self):
|
||||
out = _sanitize_tool_error("<function_call>x</function_call>")
|
||||
assert "<function_call>" not in out
|
||||
assert "</function_call>" not in out
|
||||
|
||||
def test_strips_role_tags(self):
|
||||
# Each of these should be stripped
|
||||
for tag in ("system", "assistant", "user", "result", "response", "output", "input"):
|
||||
raw = f"prefix <{tag}>hi</{tag}> suffix"
|
||||
out = _sanitize_tool_error(raw)
|
||||
assert f"<{tag}>" not in out, f"failed to strip <{tag}>"
|
||||
assert f"</{tag}>" not in out, f"failed to strip </{tag}>"
|
||||
|
||||
def test_role_tag_strip_is_case_insensitive(self):
|
||||
out = _sanitize_tool_error("<TOOL_CALL>x</Tool_Call>")
|
||||
assert "<" not in out.replace("[TOOL_ERROR]", "") # only the prefix bracket survives
|
||||
|
||||
def test_unrelated_xml_kept(self):
|
||||
# We intentionally only strip the role-like tag whitelist, not all XML
|
||||
out = _sanitize_tool_error("Error parsing <ParseError>line 5</ParseError>")
|
||||
assert "<ParseError>" in out
|
||||
|
||||
|
||||
class TestCDATAStripping:
|
||||
def test_strips_cdata(self):
|
||||
out = _sanitize_tool_error("error: <![CDATA[malicious]]> here")
|
||||
assert "<![CDATA[" not in out
|
||||
assert "]]>" not in out
|
||||
|
||||
def test_strips_multiline_cdata(self):
|
||||
out = _sanitize_tool_error("a\n<![CDATA[line1\nline2]]>\nb")
|
||||
assert "CDATA" not in out
|
||||
assert "a" in out and "b" in out
|
||||
|
||||
|
||||
class TestCodeFenceStripping:
|
||||
def test_strips_leading_fence_with_lang(self):
|
||||
out = _sanitize_tool_error("```json\n{\"x\": 1}")
|
||||
assert not out.replace("[TOOL_ERROR] ", "").startswith("```")
|
||||
|
||||
def test_strips_trailing_fence(self):
|
||||
out = _sanitize_tool_error("payload\n```")
|
||||
assert not out.rstrip().endswith("```")
|
||||
|
||||
def test_strips_bare_fence(self):
|
||||
out = _sanitize_tool_error("```\nstuff")
|
||||
assert "```" not in out.split("\n")[0]
|
||||
|
||||
|
||||
class TestTruncation:
|
||||
def test_caps_long_input(self):
|
||||
long = "A" * (_TOOL_ERROR_MAX_LEN * 2)
|
||||
out = _sanitize_tool_error(long)
|
||||
# Total length is prefix + truncated body
|
||||
body = out[len("[TOOL_ERROR] "):]
|
||||
assert len(body) == _TOOL_ERROR_MAX_LEN
|
||||
assert body.endswith("...")
|
||||
|
||||
def test_does_not_truncate_short_input(self):
|
||||
msg = "short error"
|
||||
out = _sanitize_tool_error(msg)
|
||||
assert "..." not in out
|
||||
assert msg in out
|
||||
|
||||
|
||||
class TestEnvelope:
|
||||
def test_wraps_with_prefix(self):
|
||||
out = _sanitize_tool_error("oh no")
|
||||
assert out.startswith("[TOOL_ERROR] ")
|
||||
|
||||
def test_empty_input(self):
|
||||
out = _sanitize_tool_error("")
|
||||
assert out == "[TOOL_ERROR] "
|
||||
|
||||
def test_preserves_normal_error_text(self):
|
||||
msg = "Error executing read_file: FileNotFoundError: /tmp/missing"
|
||||
out = _sanitize_tool_error(msg)
|
||||
assert msg in out
|
||||
|
||||
|
||||
class TestHandleFunctionCallIntegration:
|
||||
"""Verify handle_function_call routes exception-path errors through the sanitizer.
|
||||
|
||||
Note: the "Unknown tool: ..." early-return in tools/registry.py is a
|
||||
*different* code path from `except Exception` in handle_function_call —
|
||||
that one returns directly without sanitization (and there's nothing to
|
||||
sanitize in a hardcoded format string anyway). This test exercises the
|
||||
real exception path by passing args that make a known tool raise.
|
||||
"""
|
||||
|
||||
def test_exception_path_error_is_sanitized(self):
|
||||
import json
|
||||
from model_tools import handle_function_call
|
||||
from tools.registry import registry as _registry
|
||||
|
||||
# Force a known tool to raise with a payload containing role tags.
|
||||
def boom(_args, **_kwargs):
|
||||
raise RuntimeError("<tool_call>injected</tool_call> boom")
|
||||
|
||||
all_tools = _registry.get_all_tool_names()
|
||||
assert all_tools, "no tools registered — test environment broken"
|
||||
target = all_tools[0]
|
||||
original = _registry._tools[target].handler
|
||||
_registry._tools[target].handler = boom
|
||||
try:
|
||||
result_str = handle_function_call(target, {})
|
||||
finally:
|
||||
_registry._tools[target].handler = original
|
||||
|
||||
payload = json.loads(result_str)
|
||||
assert "error" in payload, payload
|
||||
assert payload["error"].startswith("[TOOL_ERROR] "), payload["error"]
|
||||
# Role-tag stripping carried through
|
||||
assert "<tool_call>" not in payload["error"]
|
||||
assert "</tool_call>" not in payload["error"]
|
||||
assert "boom" in payload["error"]
|
||||
|
|
@ -4649,3 +4649,158 @@ def test_config_show_displays_nested_max_turns(monkeypatch):
|
|||
)
|
||||
|
||||
assert ["Max Turns", "120"] in agent_rows
|
||||
|
||||
|
||||
def test_notification_poller_delivers_completion(monkeypatch):
|
||||
"""Poller picks up completion events and triggers agent turns."""
|
||||
from tools.process_registry import process_registry
|
||||
|
||||
turns = []
|
||||
emitted = []
|
||||
|
||||
class _Agent:
|
||||
def run_conversation(self, prompt, conversation_history=None, stream_callback=None):
|
||||
turns.append(prompt)
|
||||
return {
|
||||
"final_response": "ok",
|
||||
"messages": [{"role": "assistant", "content": "ok"}],
|
||||
}
|
||||
|
||||
class _ImmediateThread:
|
||||
def __init__(self, target=None, daemon=None):
|
||||
self._target = target
|
||||
def start(self):
|
||||
self._target()
|
||||
|
||||
sess = _session(agent=_Agent())
|
||||
server._sessions["sid_poll"] = sess
|
||||
monkeypatch.setattr(server.threading, "Thread", _ImmediateThread)
|
||||
monkeypatch.setattr(server, "_emit", lambda *a, **kw: emitted.append(a))
|
||||
monkeypatch.setattr(server, "make_stream_renderer", lambda cols: None)
|
||||
monkeypatch.setattr(server, "render_message", lambda raw, cols: None)
|
||||
|
||||
# Clear queue
|
||||
while not process_registry.completion_queue.empty():
|
||||
process_registry.completion_queue.get_nowait()
|
||||
process_registry._completion_consumed.discard("proc_poller_test")
|
||||
|
||||
stop = threading.Event()
|
||||
|
||||
# Put event on queue, then immediately signal stop so the poller
|
||||
# runs exactly one iteration.
|
||||
process_registry.completion_queue.put({
|
||||
"type": "completion",
|
||||
"session_id": "proc_poller_test",
|
||||
"command": "echo hello",
|
||||
"exit_code": 0,
|
||||
"output": "hello",
|
||||
})
|
||||
stop.set()
|
||||
|
||||
try:
|
||||
server._notification_poller_loop(stop, "sid_poll", sess)
|
||||
|
||||
# Should have emitted a status.update with kind=process
|
||||
status_calls = [a for a in emitted if a[0] == "status.update"]
|
||||
assert len(status_calls) >= 1
|
||||
assert status_calls[0][2]["kind"] == "process"
|
||||
|
||||
# Should have triggered an agent turn
|
||||
assert len(turns) == 1
|
||||
assert "[IMPORTANT: Background process proc_poller_test completed" in turns[0]
|
||||
finally:
|
||||
server._sessions.pop("sid_poll", None)
|
||||
while not process_registry.completion_queue.empty():
|
||||
process_registry.completion_queue.get_nowait()
|
||||
|
||||
|
||||
def test_notification_poller_skips_consumed(monkeypatch):
|
||||
"""Already-consumed completions are not dispatched by the poller."""
|
||||
from tools.process_registry import process_registry
|
||||
|
||||
turns = []
|
||||
|
||||
class _Agent:
|
||||
def run_conversation(self, prompt, conversation_history=None, stream_callback=None):
|
||||
turns.append(prompt)
|
||||
return {"final_response": "ok", "messages": []}
|
||||
|
||||
class _ImmediateThread:
|
||||
def __init__(self, target=None, daemon=None):
|
||||
self._target = target
|
||||
def start(self):
|
||||
self._target()
|
||||
|
||||
sess = _session(agent=_Agent())
|
||||
server._sessions["sid_skip"] = sess
|
||||
monkeypatch.setattr(server.threading, "Thread", _ImmediateThread)
|
||||
monkeypatch.setattr(server, "_emit", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(server, "make_stream_renderer", lambda cols: None)
|
||||
monkeypatch.setattr(server, "render_message", lambda raw, cols: None)
|
||||
|
||||
while not process_registry.completion_queue.empty():
|
||||
process_registry.completion_queue.get_nowait()
|
||||
|
||||
process_registry._completion_consumed.add("proc_already_done")
|
||||
process_registry.completion_queue.put({
|
||||
"type": "completion",
|
||||
"session_id": "proc_already_done",
|
||||
"command": "echo x",
|
||||
"exit_code": 0,
|
||||
"output": "x",
|
||||
})
|
||||
|
||||
stop = threading.Event()
|
||||
stop.set()
|
||||
|
||||
try:
|
||||
server._notification_poller_loop(stop, "sid_skip", sess)
|
||||
assert len(turns) == 0
|
||||
finally:
|
||||
server._sessions.pop("sid_skip", None)
|
||||
process_registry._completion_consumed.discard("proc_already_done")
|
||||
while not process_registry.completion_queue.empty():
|
||||
process_registry.completion_queue.get_nowait()
|
||||
|
||||
|
||||
def test_notification_poller_requeues_when_busy(monkeypatch):
|
||||
"""When the agent is busy, the poller requeues the event."""
|
||||
from tools.process_registry import process_registry
|
||||
|
||||
emitted = []
|
||||
|
||||
sess = _session(running=True) # agent is busy
|
||||
server._sessions["sid_busy"] = sess
|
||||
monkeypatch.setattr(server, "_emit", lambda *a, **kw: emitted.append(a))
|
||||
|
||||
while not process_registry.completion_queue.empty():
|
||||
process_registry.completion_queue.get_nowait()
|
||||
process_registry._completion_consumed.discard("proc_busy_test")
|
||||
|
||||
evt = {
|
||||
"type": "completion",
|
||||
"session_id": "proc_busy_test",
|
||||
"command": "make build",
|
||||
"exit_code": 0,
|
||||
"output": "ok",
|
||||
}
|
||||
process_registry.completion_queue.put(evt)
|
||||
|
||||
stop = threading.Event()
|
||||
stop.set()
|
||||
|
||||
try:
|
||||
server._notification_poller_loop(stop, "sid_busy", sess)
|
||||
|
||||
# Status update was emitted (user sees it)
|
||||
status_calls = [a for a in emitted if a[0] == "status.update"]
|
||||
assert len(status_calls) == 1
|
||||
|
||||
# Event was requeued (agent was busy, no turn triggered)
|
||||
assert not process_registry.completion_queue.empty()
|
||||
requeued = process_registry.completion_queue.get_nowait()
|
||||
assert requeued["session_id"] == "proc_busy_test"
|
||||
finally:
|
||||
server._sessions.pop("sid_busy", None)
|
||||
while not process_registry.completion_queue.empty():
|
||||
process_registry.completion_queue.get_nowait()
|
||||
|
|
|
|||
|
|
@ -1102,3 +1102,206 @@ class TestDetectSudoStdin:
|
|||
"make 2>&1 | tee build.log"
|
||||
)
|
||||
assert is_dangerous is False
|
||||
|
||||
|
||||
class TestMacOSPrivateSystemPaths:
|
||||
"""Inspired by Claude Code 2.1.113 "dangerous path protection".
|
||||
|
||||
On macOS, /etc, /var, /tmp, /home are symlinks to
|
||||
/private/{etc,var,tmp,home}. A command that writes to
|
||||
/private/etc/sudoers works identically to /etc/sudoers but bypasses
|
||||
a plain "/etc/" pattern check. These tests guard the shared
|
||||
_SYSTEM_CONFIG_PATH fragment used across redirect / tee / cp / mv /
|
||||
install / sed -i patterns.
|
||||
"""
|
||||
|
||||
def test_private_etc_redirect(self):
|
||||
dangerous, _, desc = detect_dangerous_command(
|
||||
"echo 'root ALL=NOPASSWD: ALL' > /private/etc/sudoers"
|
||||
)
|
||||
assert dangerous is True
|
||||
assert "system config" in desc.lower()
|
||||
|
||||
def test_private_var_redirect(self):
|
||||
dangerous, _, _ = detect_dangerous_command(
|
||||
"echo payload > /private/var/db/dslocal/nodes/x"
|
||||
)
|
||||
assert dangerous is True
|
||||
|
||||
def test_private_etc_via_tee(self):
|
||||
dangerous, _, desc = detect_dangerous_command(
|
||||
"echo malicious | tee /private/etc/hosts"
|
||||
)
|
||||
assert dangerous is True
|
||||
assert "tee" in desc.lower() or "system" in desc.lower()
|
||||
|
||||
def test_private_etc_cp(self):
|
||||
dangerous, _, desc = detect_dangerous_command(
|
||||
"cp malicious.conf /private/etc/hosts"
|
||||
)
|
||||
assert dangerous is True
|
||||
assert "copy" in desc.lower() or "system config" in desc.lower()
|
||||
|
||||
def test_private_etc_mv(self):
|
||||
dangerous, _, _ = detect_dangerous_command(
|
||||
"mv evil /private/etc/ssh/sshd_config"
|
||||
)
|
||||
assert dangerous is True
|
||||
|
||||
def test_private_etc_install(self):
|
||||
dangerous, _, _ = detect_dangerous_command(
|
||||
"install -m 600 key /private/etc/ssh/keys"
|
||||
)
|
||||
assert dangerous is True
|
||||
|
||||
def test_private_etc_sed_in_place(self):
|
||||
dangerous, _, desc = detect_dangerous_command(
|
||||
"sed -i 's/root/pwned/' /private/etc/passwd"
|
||||
)
|
||||
assert dangerous is True
|
||||
assert "in-place" in desc.lower() or "system config" in desc.lower()
|
||||
|
||||
def test_private_var_sed_long_flag(self):
|
||||
dangerous, _, _ = detect_dangerous_command(
|
||||
"sed --in-place 's/x/y/' /private/var/log/wtmp"
|
||||
)
|
||||
assert dangerous is True
|
||||
|
||||
def test_private_tmp_cp(self):
|
||||
dangerous, _, _ = detect_dangerous_command(
|
||||
"cp rootkit /private/tmp/payload"
|
||||
)
|
||||
assert dangerous is True
|
||||
|
||||
def test_ls_private_is_safe(self):
|
||||
"""Reading under /private/ must not trigger approval."""
|
||||
dangerous, _, _ = detect_dangerous_command("ls /private")
|
||||
assert dangerous is False
|
||||
|
||||
def test_echo_mentioning_private_path_is_safe(self):
|
||||
"""Literal mention of /private/etc in an echo string must not fire."""
|
||||
dangerous, _, _ = detect_dangerous_command(
|
||||
"echo 'the macOS path is /private/etc on disk'"
|
||||
)
|
||||
assert dangerous is False
|
||||
|
||||
|
||||
class TestKillallKillSignals:
|
||||
"""Inspired by Claude Code 2.1.113 expanded deny rules.
|
||||
|
||||
The existing pattern caught `pkill -9` but not the equivalent
|
||||
`killall -9` / `-KILL` / `-s KILL` / `-r <regex>` broad sweeps that
|
||||
can wipe out unrelated processes.
|
||||
"""
|
||||
|
||||
def test_killall_dash_9(self):
|
||||
dangerous, _, desc = detect_dangerous_command("killall -9 firefox")
|
||||
assert dangerous is True
|
||||
assert "kill" in desc.lower()
|
||||
|
||||
def test_killall_dash_kill(self):
|
||||
dangerous, _, _ = detect_dangerous_command("killall -KILL firefox")
|
||||
assert dangerous is True
|
||||
|
||||
def test_killall_dash_sigkill(self):
|
||||
dangerous, _, _ = detect_dangerous_command("killall -SIGKILL firefox")
|
||||
assert dangerous is True
|
||||
|
||||
def test_killall_dash_s_kill(self):
|
||||
dangerous, _, _ = detect_dangerous_command("killall -s KILL firefox")
|
||||
assert dangerous is True
|
||||
|
||||
def test_killall_dash_s_signum(self):
|
||||
dangerous, _, _ = detect_dangerous_command("killall -s 9 firefox")
|
||||
assert dangerous is True
|
||||
|
||||
def test_killall_regex(self):
|
||||
"""killall -r <regex> is a broad sweep; require approval."""
|
||||
dangerous, _, desc = detect_dangerous_command("killall -r 'fire.*'")
|
||||
assert dangerous is True
|
||||
assert "regex" in desc.lower() or "kill" in desc.lower()
|
||||
|
||||
def test_killall_combined_flags(self):
|
||||
dangerous, _, _ = detect_dangerous_command("killall -9 -r 'herm.*'")
|
||||
assert dangerous is True
|
||||
|
||||
def test_killall_list_signals_is_safe(self):
|
||||
"""`killall -l` lists signals and is harmless — must not fire."""
|
||||
dangerous, _, _ = detect_dangerous_command("killall -l")
|
||||
assert dangerous is False
|
||||
|
||||
def test_killall_version_is_safe(self):
|
||||
dangerous, _, _ = detect_dangerous_command("killall -V")
|
||||
assert dangerous is False
|
||||
|
||||
|
||||
class TestFindExecdir:
|
||||
"""Inspired by Claude Code 2.1.113 tightening of find rules.
|
||||
|
||||
`find -execdir rm` has the same destructive effect as `find -exec rm`
|
||||
but ran in each match's directory. Previously missed because the
|
||||
pattern required a literal `-exec ` followed by a space.
|
||||
"""
|
||||
|
||||
def test_find_execdir_rm(self):
|
||||
dangerous, _, desc = detect_dangerous_command(
|
||||
"find . -execdir rm {} \\;"
|
||||
)
|
||||
assert dangerous is True
|
||||
assert "find" in desc.lower() or "rm" in desc.lower()
|
||||
|
||||
def test_find_execdir_with_absolute_rm(self):
|
||||
dangerous, _, _ = detect_dangerous_command(
|
||||
"find /var -execdir /bin/rm -rf {} \\;"
|
||||
)
|
||||
assert dangerous is True
|
||||
|
||||
def test_find_exec_rm_still_caught(self):
|
||||
"""Original -exec pattern must still fire (regression guard)."""
|
||||
dangerous, _, _ = detect_dangerous_command(
|
||||
"find . -exec rm {} \\;"
|
||||
)
|
||||
assert dangerous is True
|
||||
|
||||
def test_find_execdir_ls_is_safe(self):
|
||||
"""-execdir with a read-only command is not dangerous."""
|
||||
dangerous, _, _ = detect_dangerous_command(
|
||||
"find . -execdir ls {} \\;"
|
||||
)
|
||||
assert dangerous is False
|
||||
|
||||
|
||||
class TestEtcPatternsUnaffectedByRefactor:
|
||||
"""Regression guard: the /etc/ patterns were refactored to share the
|
||||
_SYSTEM_CONFIG_PATH fragment with the /private/ mirror. Make sure the
|
||||
existing /etc/ coverage remains identical.
|
||||
"""
|
||||
|
||||
def test_etc_redirect(self):
|
||||
dangerous, _, _ = detect_dangerous_command("echo x > /etc/hosts")
|
||||
assert dangerous is True
|
||||
|
||||
def test_etc_cp(self):
|
||||
dangerous, _, _ = detect_dangerous_command("cp evil /etc/hosts")
|
||||
assert dangerous is True
|
||||
|
||||
def test_etc_sed_inline(self):
|
||||
dangerous, _, _ = detect_dangerous_command(
|
||||
"sed -i 's/a/b/' /etc/hosts"
|
||||
)
|
||||
assert dangerous is True
|
||||
|
||||
def test_etc_tee(self):
|
||||
dangerous, _, _ = detect_dangerous_command(
|
||||
"echo x | tee /etc/hosts"
|
||||
)
|
||||
assert dangerous is True
|
||||
|
||||
def test_cat_etc_hostname_is_safe(self):
|
||||
"""Reading /etc/ files is safe — only writes require approval."""
|
||||
dangerous, _, _ = detect_dangerous_command("cat /etc/hostname")
|
||||
assert dangerous is False
|
||||
|
||||
def test_grep_etc_passwd_is_safe(self):
|
||||
dangerous, _, _ = detect_dangerous_command("grep root /etc/passwd")
|
||||
assert dangerous is False
|
||||
|
|
|
|||
|
|
@ -122,6 +122,27 @@ class TestCronjobRequirements:
|
|||
|
||||
assert check_cronjob_requirements() is False
|
||||
|
||||
@pytest.mark.parametrize("false_like_value", ["0", "false", "no", "off"])
|
||||
def test_rejects_false_like_interactive_env(self, monkeypatch, false_like_value):
|
||||
monkeypatch.setenv("HERMES_INTERACTIVE", false_like_value)
|
||||
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
|
||||
monkeypatch.delenv("HERMES_EXEC_ASK", raising=False)
|
||||
assert check_cronjob_requirements() is False
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"var_name",
|
||||
["HERMES_INTERACTIVE", "HERMES_GATEWAY_SESSION", "HERMES_EXEC_ASK"],
|
||||
)
|
||||
@pytest.mark.parametrize("false_like_value", ["0", "false", "no", "off"])
|
||||
def test_rejects_false_like_any_session_env(
|
||||
self, monkeypatch, var_name, false_like_value
|
||||
):
|
||||
"""All three session env vars share the same truthy semantics."""
|
||||
for v in ("HERMES_INTERACTIVE", "HERMES_GATEWAY_SESSION", "HERMES_EXEC_ASK"):
|
||||
monkeypatch.delenv(v, raising=False)
|
||||
monkeypatch.setenv(var_name, false_like_value)
|
||||
assert check_cronjob_requirements() is False
|
||||
|
||||
|
||||
class TestUnifiedCronjobTool:
|
||||
@pytest.fixture(autouse=True)
|
||||
|
|
|
|||
|
|
@ -890,6 +890,63 @@ class TestDelegationCredentialResolution(unittest.TestCase):
|
|||
self.assertEqual(creds["api_key"], "local-key")
|
||||
self.assertEqual(creds["api_mode"], "chat_completions")
|
||||
|
||||
def test_direct_endpoint_auto_detects_anthropic_messages_suffix(self):
|
||||
# Issue #10213: Azure AI Foundry exposes Anthropic-compatible models at
|
||||
# a /anthropic URL suffix. Subagents must pick anthropic_messages
|
||||
# automatically, matching the main agent's runtime resolver.
|
||||
parent = _make_mock_parent(depth=0)
|
||||
cfg = {
|
||||
"model": "claude-opus-4-6",
|
||||
"provider": "custom",
|
||||
"base_url": "https://myfoundry.services.ai.azure.com/anthropic",
|
||||
"api_key": "foundry-key",
|
||||
}
|
||||
creds = _resolve_delegation_credentials(cfg, parent)
|
||||
self.assertEqual(creds["provider"], "custom")
|
||||
self.assertEqual(creds["base_url"], "https://myfoundry.services.ai.azure.com/anthropic")
|
||||
self.assertEqual(creds["api_key"], "foundry-key")
|
||||
self.assertEqual(creds["api_mode"], "anthropic_messages")
|
||||
|
||||
def test_direct_endpoint_honors_explicit_api_mode(self):
|
||||
# When delegation.api_mode is set explicitly, it overrides URL-based
|
||||
# detection so users can force a transport on non-standard endpoints.
|
||||
parent = _make_mock_parent(depth=0)
|
||||
cfg = {
|
||||
"model": "claude-opus-4-6",
|
||||
"provider": "custom",
|
||||
"base_url": "https://proxy.example.com/v1",
|
||||
"api_key": "proxy-key",
|
||||
"api_mode": "anthropic_messages",
|
||||
}
|
||||
creds = _resolve_delegation_credentials(cfg, parent)
|
||||
self.assertEqual(creds["api_mode"], "anthropic_messages")
|
||||
|
||||
def test_direct_endpoint_explicit_api_mode_overrides_url_detection(self):
|
||||
# Explicit api_mode in config always wins over auto-detection.
|
||||
parent = _make_mock_parent(depth=0)
|
||||
cfg = {
|
||||
"model": "claude-opus-4-6",
|
||||
"provider": "custom",
|
||||
"base_url": "https://myfoundry.services.ai.azure.com/anthropic",
|
||||
"api_key": "foundry-key",
|
||||
"api_mode": "chat_completions",
|
||||
}
|
||||
creds = _resolve_delegation_credentials(cfg, parent)
|
||||
self.assertEqual(creds["api_mode"], "chat_completions")
|
||||
|
||||
def test_direct_endpoint_invalid_api_mode_falls_back_to_detection(self):
|
||||
# An invalid api_mode string must not break detection; fall back to URL heuristic.
|
||||
parent = _make_mock_parent(depth=0)
|
||||
cfg = {
|
||||
"model": "claude-opus-4-6",
|
||||
"provider": "custom",
|
||||
"base_url": "https://myfoundry.services.ai.azure.com/anthropic",
|
||||
"api_key": "foundry-key",
|
||||
"api_mode": "garbage",
|
||||
}
|
||||
creds = _resolve_delegation_credentials(cfg, parent)
|
||||
self.assertEqual(creds["api_mode"], "anthropic_messages")
|
||||
|
||||
def test_direct_endpoint_returns_none_api_key_when_not_configured(self):
|
||||
# When base_url is set without api_key, api_key should be None so
|
||||
# _build_child_agent inherits the parent's key (effective_api_key = override or parent).
|
||||
|
|
|
|||
|
|
@ -37,3 +37,62 @@ def test_fal_key_empty_is_unset(monkeypatch):
|
|||
)
|
||||
|
||||
assert image_generation_tool.check_fal_api_key() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Actionable setup message when no FAL backend is reachable.
|
||||
# Regression for the silent-drop UX gap described in issue #2543.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_no_backend_message_mentions_fal_signup_and_plugins(monkeypatch):
|
||||
from tools import image_generation_tool
|
||||
|
||||
monkeypatch.setattr(
|
||||
image_generation_tool, "managed_nous_tools_enabled", lambda: False
|
||||
)
|
||||
|
||||
msg = image_generation_tool._build_no_backend_setup_message()
|
||||
|
||||
assert "FAL_KEY" in msg
|
||||
assert "https://fal.ai" in msg
|
||||
# Plugin pointer so users on a stale image_gen.provider know where to look.
|
||||
assert "hermes tools" in msg or "hermes plugins" in msg
|
||||
|
||||
|
||||
def test_no_backend_message_mentions_managed_gateway_when_enabled(monkeypatch):
|
||||
from tools import image_generation_tool
|
||||
|
||||
monkeypatch.setattr(
|
||||
image_generation_tool, "managed_nous_tools_enabled", lambda: True
|
||||
)
|
||||
|
||||
msg = image_generation_tool._build_no_backend_setup_message()
|
||||
|
||||
assert "managed FAL gateway" in msg
|
||||
assert "Nous account" in msg or "hermes setup" in msg
|
||||
|
||||
|
||||
def test_image_generate_tool_returns_actionable_error_when_no_backend(monkeypatch):
|
||||
"""End-to-end: handler must surface the actionable message, not a bare string."""
|
||||
import json
|
||||
|
||||
from tools import image_generation_tool
|
||||
|
||||
monkeypatch.setattr(
|
||||
image_generation_tool, "fal_key_is_configured", lambda: False
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
image_generation_tool, "_resolve_managed_fal_gateway", lambda: None
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
image_generation_tool, "managed_nous_tools_enabled", lambda: False
|
||||
)
|
||||
|
||||
result = json.loads(
|
||||
image_generation_tool.image_generate_tool(prompt="a cat")
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "https://fal.ai" in result["error"]
|
||||
assert "FAL_KEY" in result["error"]
|
||||
|
|
|
|||
200
tests/tools/test_local_env_windows_msys.py
Normal file
200
tests/tools/test_local_env_windows_msys.py
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
"""Tests for the Windows / Git Bash MSYS-path normalization in
|
||||
``LocalEnvironment``.
|
||||
|
||||
Background
|
||||
----------
|
||||
On Windows, ``pwd -P`` inside Git Bash emits paths like
|
||||
``/c/Users/NVIDIA``. ``subprocess.Popen(..., cwd=...)`` only accepts
|
||||
native Windows paths (``C:\\Users\\NVIDIA``), and the validation done
|
||||
by ``_resolve_safe_cwd`` was also checking the MSYS form against
|
||||
``os.path.isdir``, which returns ``False`` on Windows. The combined
|
||||
effect was a warning logged on every single terminal call:
|
||||
|
||||
LocalEnvironment cwd '/c/Users/NVIDIA' is missing on disk;
|
||||
falling back to '/' so terminal commands keep working.
|
||||
|
||||
These tests fake the Windows env on Linux CI by patching ``_IS_WINDOWS``
|
||||
and ``os.path.isdir`` so the MSYS path tests as "missing" exactly like
|
||||
on the real OS.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import local as local_mod
|
||||
from tools.environments.local import (
|
||||
LocalEnvironment,
|
||||
_msys_to_windows_path,
|
||||
_resolve_safe_cwd,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _msys_to_windows_path — pure-function unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMsysToWindowsPath:
|
||||
def test_noop_on_non_windows(self, monkeypatch):
|
||||
monkeypatch.setattr(local_mod, "_IS_WINDOWS", False)
|
||||
# On a non-Windows host the function must never rewrite the path
|
||||
# — POSIX-style paths are real paths there.
|
||||
assert _msys_to_windows_path("/c/Users/NVIDIA") == "/c/Users/NVIDIA"
|
||||
assert _msys_to_windows_path("/home/teknium") == "/home/teknium"
|
||||
|
||||
def test_translates_drive_path(self, monkeypatch):
|
||||
monkeypatch.setattr(local_mod, "_IS_WINDOWS", True)
|
||||
assert _msys_to_windows_path("/c/Users/NVIDIA") == r"C:\Users\NVIDIA"
|
||||
assert _msys_to_windows_path("/d/Projects/foo bar") == r"D:\Projects\foo bar"
|
||||
|
||||
def test_translates_bare_drive_root(self, monkeypatch):
|
||||
monkeypatch.setattr(local_mod, "_IS_WINDOWS", True)
|
||||
# Bare "/c" alone should resolve to the drive root.
|
||||
assert _msys_to_windows_path("/c") == "C:\\"
|
||||
# Trailing slash on the drive letter is also a root.
|
||||
assert _msys_to_windows_path("/c/") == "C:\\"
|
||||
|
||||
def test_idempotent_on_already_windows_path(self, monkeypatch):
|
||||
monkeypatch.setattr(local_mod, "_IS_WINDOWS", True)
|
||||
assert _msys_to_windows_path(r"C:\Users\NVIDIA") == r"C:\Users\NVIDIA"
|
||||
|
||||
def test_does_not_translate_multi_char_first_segment(self, monkeypatch):
|
||||
"""``/tmp/foo`` and ``/home/x`` must NOT be misread as drive paths
|
||||
just because they start with ``/`` and a single letter — the regex
|
||||
only matches when the first segment is exactly one character."""
|
||||
monkeypatch.setattr(local_mod, "_IS_WINDOWS", True)
|
||||
assert _msys_to_windows_path("/tmp/foo") == "/tmp/foo"
|
||||
assert _msys_to_windows_path("/home/x") == "/home/x"
|
||||
|
||||
def test_empty_string(self, monkeypatch):
|
||||
monkeypatch.setattr(local_mod, "_IS_WINDOWS", True)
|
||||
assert _msys_to_windows_path("") == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_safe_cwd — Windows fast path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResolveSafeCwdWindows:
|
||||
def test_msys_path_resolves_to_native_when_native_exists(
|
||||
self, monkeypatch, tmp_path,
|
||||
):
|
||||
"""The whole point of this fix: a Git Bash ``/c/Users/x`` value
|
||||
should resolve to its native equivalent if that native dir exists,
|
||||
WITHOUT falling back to the temp dir."""
|
||||
monkeypatch.setattr(local_mod, "_IS_WINDOWS", True)
|
||||
|
||||
# tmp_path is a real native dir on the test host. Build a fake
|
||||
# MSYS form pointing at it and prove the resolver finds it.
|
||||
native = str(tmp_path)
|
||||
# Construct a synthetic MSYS form for whatever tmp_path is.
|
||||
# On Linux CI tmp_path is /tmp/... ; the resolver shouldn't even
|
||||
# try to translate that (regex won't match), so emulate the
|
||||
# mapping by pointing the translator at the real native dir.
|
||||
with patch.object(
|
||||
local_mod, "_msys_to_windows_path", return_value=native
|
||||
):
|
||||
assert _resolve_safe_cwd("/c/whatever") == native
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# End-to-end: _update_cwd via marker file (Windows simulation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateCwdWindowsMsys:
|
||||
def test_marker_file_msys_path_stored_in_native_form(
|
||||
self, monkeypatch, tmp_path,
|
||||
):
|
||||
"""When Git Bash writes ``/c/Users/x`` to the cwd marker file on
|
||||
Windows, ``_update_cwd`` must translate to native form before
|
||||
validating and storing — otherwise ``os.path.isdir`` rejects a
|
||||
perfectly real directory."""
|
||||
original = tmp_path / "starting"
|
||||
original.mkdir()
|
||||
|
||||
# Fake Windows for the test
|
||||
monkeypatch.setattr(local_mod, "_IS_WINDOWS", True)
|
||||
|
||||
with patch.object(
|
||||
LocalEnvironment, "init_session", autospec=True, return_value=None
|
||||
):
|
||||
env = LocalEnvironment(cwd=str(original), timeout=10)
|
||||
|
||||
# Pretend Git Bash wrote an MSYS path that maps to tmp_path/"next"
|
||||
new_dir = tmp_path / "next"
|
||||
new_dir.mkdir()
|
||||
|
||||
with open(env._cwd_file, "w") as f:
|
||||
f.write("/c/whatever/from/bash")
|
||||
|
||||
# Translate the synthetic MSYS string to the real native dir.
|
||||
def fake_translate(p):
|
||||
if p == "/c/whatever/from/bash":
|
||||
return str(new_dir)
|
||||
return p
|
||||
|
||||
with patch.object(local_mod, "_msys_to_windows_path", side_effect=fake_translate):
|
||||
env._update_cwd({"output": "", "returncode": 0})
|
||||
|
||||
assert env.cwd == str(new_dir)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# End-to-end: _extract_cwd_from_output rollback when marker is invalid
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractCwdFromOutputWindowsMsys:
|
||||
def test_stale_msys_marker_does_not_clobber_cwd(self, monkeypatch, tmp_path):
|
||||
"""When the cwd marker in stdout points at a non-existent path,
|
||||
``LocalEnvironment._extract_cwd_from_output`` must roll back to
|
||||
the previous cwd instead of propagating a bad value."""
|
||||
original = tmp_path / "starting"
|
||||
original.mkdir()
|
||||
|
||||
monkeypatch.setattr(local_mod, "_IS_WINDOWS", True)
|
||||
|
||||
with patch.object(
|
||||
LocalEnvironment, "init_session", autospec=True, return_value=None
|
||||
):
|
||||
env = LocalEnvironment(cwd=str(original), timeout=10)
|
||||
|
||||
marker = env._cwd_marker
|
||||
result = {
|
||||
"output": f"some command output\n{marker}/c/no/such/path{marker}\n",
|
||||
"returncode": 0,
|
||||
}
|
||||
|
||||
# Translation produces a path that doesn't exist on disk → rollback.
|
||||
with patch.object(
|
||||
local_mod,
|
||||
"_msys_to_windows_path",
|
||||
return_value=str(tmp_path / "definitely-does-not-exist"),
|
||||
):
|
||||
env._extract_cwd_from_output(result)
|
||||
|
||||
assert env.cwd == str(original)
|
||||
|
||||
def test_valid_msys_marker_normalized_to_native(self, monkeypatch, tmp_path):
|
||||
original = tmp_path / "starting"
|
||||
original.mkdir()
|
||||
new_dir = tmp_path / "next"
|
||||
new_dir.mkdir()
|
||||
|
||||
monkeypatch.setattr(local_mod, "_IS_WINDOWS", True)
|
||||
|
||||
with patch.object(
|
||||
LocalEnvironment, "init_session", autospec=True, return_value=None
|
||||
):
|
||||
env = LocalEnvironment(cwd=str(original), timeout=10)
|
||||
|
||||
marker = env._cwd_marker
|
||||
result = {
|
||||
"output": f"x\n{marker}/c/whatever{marker}\n",
|
||||
"returncode": 0,
|
||||
}
|
||||
|
||||
with patch.object(local_mod, "_msys_to_windows_path", return_value=str(new_dir)):
|
||||
env._extract_cwd_from_output(result)
|
||||
|
||||
assert env.cwd == str(new_dir)
|
||||
125
tests/tools/test_mcp_invalid_url.py
Normal file
125
tests/tools/test_mcp_invalid_url.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
"""Tests for the MCP remote-URL validator.
|
||||
|
||||
Ported from anomalyco/opencode#25019 (``fix: handle invalid mcp urls``).
|
||||
|
||||
Previously, a typo in ``config.yaml`` (missing scheme, wrong scheme, empty
|
||||
string, dict where a URL was expected) caused the MCP server startup code
|
||||
to enter httpx's URL-parsing path and crash inside the transport layer.
|
||||
The reconnect-backoff loop would then retry
|
||||
``_MAX_INITIAL_CONNECT_RETRIES`` times with doubling backoff — a minute or
|
||||
more of pointless retries plus a confusing opaque error message — before
|
||||
eventually giving up.
|
||||
|
||||
The fix validates the URL once, up front, and fails fast with a specific
|
||||
error message identifying the offending server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.mcp_tool import (
|
||||
InvalidMcpUrlError,
|
||||
_validate_remote_mcp_url,
|
||||
)
|
||||
|
||||
|
||||
class TestValidUrlsAccepted:
|
||||
"""Every valid http(s) URL must pass through untouched (stripped of whitespace)."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url",
|
||||
[
|
||||
"http://localhost:3000/mcp",
|
||||
"https://example.com/mcp",
|
||||
"https://context7.liam.com/mcp",
|
||||
"http://127.0.0.1:8080",
|
||||
"https://api.example.com:443/v1/mcp?session=abc",
|
||||
"http://[::1]:9000/mcp", # IPv6
|
||||
"https://host.example.com", # no port, no path
|
||||
],
|
||||
)
|
||||
def test_accepts_valid_http_url(self, url):
|
||||
assert _validate_remote_mcp_url("test", url) == url
|
||||
|
||||
def test_strips_surrounding_whitespace(self):
|
||||
assert (
|
||||
_validate_remote_mcp_url("test", " https://example.com/mcp ")
|
||||
== "https://example.com/mcp"
|
||||
)
|
||||
|
||||
|
||||
class TestInvalidUrlsRejected:
|
||||
"""Every broken shape must raise ``InvalidMcpUrlError`` with a clear message."""
|
||||
|
||||
def test_none_rejected(self):
|
||||
with pytest.raises(InvalidMcpUrlError, match="context7.*expected a string"):
|
||||
_validate_remote_mcp_url("context7", None)
|
||||
|
||||
def test_dict_rejected(self):
|
||||
with pytest.raises(InvalidMcpUrlError, match="expected a string, got dict"):
|
||||
_validate_remote_mcp_url("ctx", {"url": "nested"})
|
||||
|
||||
def test_int_rejected(self):
|
||||
with pytest.raises(InvalidMcpUrlError, match="expected a string, got int"):
|
||||
_validate_remote_mcp_url("ctx", 8080)
|
||||
|
||||
def test_empty_string_rejected(self):
|
||||
with pytest.raises(InvalidMcpUrlError, match="empty url"):
|
||||
_validate_remote_mcp_url("ctx", "")
|
||||
|
||||
def test_whitespace_only_rejected(self):
|
||||
with pytest.raises(InvalidMcpUrlError, match="empty url"):
|
||||
_validate_remote_mcp_url("ctx", " \t\n")
|
||||
|
||||
def test_missing_scheme_rejected(self):
|
||||
# The most common typo — users copy a host from a web page.
|
||||
with pytest.raises(
|
||||
InvalidMcpUrlError, match="scheme must be http or https"
|
||||
):
|
||||
_validate_remote_mcp_url("ctx", "example.com/mcp")
|
||||
|
||||
def test_file_scheme_rejected(self):
|
||||
with pytest.raises(
|
||||
InvalidMcpUrlError, match="scheme must be http or https"
|
||||
):
|
||||
_validate_remote_mcp_url("ctx", "file:///etc/passwd")
|
||||
|
||||
def test_ws_scheme_rejected(self):
|
||||
# WebSocket is not MCP's remote transport.
|
||||
with pytest.raises(
|
||||
InvalidMcpUrlError, match="scheme must be http or https"
|
||||
):
|
||||
_validate_remote_mcp_url("ctx", "ws://example.com/mcp")
|
||||
|
||||
def test_stdio_scheme_rejected(self):
|
||||
# stdio servers use the ``command`` key, not ``url``.
|
||||
with pytest.raises(
|
||||
InvalidMcpUrlError, match="scheme must be http or https"
|
||||
):
|
||||
_validate_remote_mcp_url("ctx", "stdio:///node server.js")
|
||||
|
||||
def test_empty_host_rejected(self):
|
||||
with pytest.raises(InvalidMcpUrlError, match="missing host"):
|
||||
_validate_remote_mcp_url("ctx", "http:///")
|
||||
|
||||
def test_empty_host_with_path_rejected(self):
|
||||
with pytest.raises(InvalidMcpUrlError, match="missing host"):
|
||||
_validate_remote_mcp_url("ctx", "https:///path/only")
|
||||
|
||||
def test_error_mentions_server_name(self):
|
||||
# So users can find the bad entry when there are multiple configured.
|
||||
with pytest.raises(InvalidMcpUrlError, match="my-weird-server"):
|
||||
_validate_remote_mcp_url("my-weird-server", "not a url at all")
|
||||
|
||||
|
||||
class TestErrorIsValueError:
|
||||
"""InvalidMcpUrlError must be a ValueError for broad downstream catch blocks."""
|
||||
|
||||
def test_is_value_error(self):
|
||||
try:
|
||||
_validate_remote_mcp_url("ctx", "garbage")
|
||||
except ValueError:
|
||||
pass # expected
|
||||
else:
|
||||
pytest.fail("expected ValueError")
|
||||
|
|
@ -69,7 +69,8 @@ class TestProbeMcpServerTools:
|
|||
patch("tools.mcp_tool._stop_mcp_loop"):
|
||||
|
||||
# Simulate running the async probe
|
||||
def run_coro(coro, timeout=120):
|
||||
def run_coro(coro_or_factory, timeout=120):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
|
|
@ -110,7 +111,8 @@ class TestProbeMcpServerTools:
|
|||
patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \
|
||||
patch("tools.mcp_tool._stop_mcp_loop"):
|
||||
|
||||
def run_coro(coro, timeout=120):
|
||||
def run_coro(coro_or_factory, timeout=120):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
|
|
@ -144,7 +146,8 @@ class TestProbeMcpServerTools:
|
|||
patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \
|
||||
patch("tools.mcp_tool._stop_mcp_loop"):
|
||||
|
||||
def run_coro(coro, timeout=120):
|
||||
def run_coro(coro_or_factory, timeout=120):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
|
|
@ -198,7 +201,8 @@ class TestProbeMcpServerTools:
|
|||
patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \
|
||||
patch("tools.mcp_tool._stop_mcp_loop"):
|
||||
|
||||
def run_coro(coro, timeout=120):
|
||||
def run_coro(coro_or_factory, timeout=120):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@ class _FakeCallToolResult:
|
|||
self.structuredContent = structuredContent
|
||||
|
||||
|
||||
def _fake_run_on_mcp_loop(coro, timeout=30):
|
||||
def _fake_run_on_mcp_loop(coro_or_factory, timeout=30):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
"""Run an MCP coroutine directly in a fresh event loop."""
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -397,6 +397,77 @@ class TestCheckFunction:
|
|||
_servers.pop("test_server", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCP loop runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunOnMcpLoop:
|
||||
def test_scheduler_failure_closes_factory_coroutine(self):
|
||||
"""If run_coroutine_threadsafe raises, the factory's coroutine is closed."""
|
||||
import gc
|
||||
import warnings
|
||||
import tools.mcp_tool as mcp
|
||||
|
||||
created = {"coro": None}
|
||||
|
||||
async def _sample():
|
||||
return "ok"
|
||||
|
||||
def factory():
|
||||
created["coro"] = _sample()
|
||||
return created["coro"]
|
||||
|
||||
fake_loop = MagicMock()
|
||||
fake_loop.is_running.return_value = True
|
||||
|
||||
with patch.object(mcp, "_mcp_loop", fake_loop):
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
with patch(
|
||||
"agent.async_utils.asyncio.run_coroutine_threadsafe",
|
||||
side_effect=RuntimeError("scheduler down"),
|
||||
):
|
||||
with pytest.raises(RuntimeError):
|
||||
mcp._run_on_mcp_loop(factory)
|
||||
gc.collect()
|
||||
|
||||
assert created["coro"] is not None
|
||||
assert created["coro"].cr_frame is None
|
||||
runtime_warnings = [
|
||||
w for w in caught
|
||||
if issubclass(w.category, RuntimeWarning)
|
||||
and "was never awaited" in str(w.message)
|
||||
and "_sample" in str(w.message)
|
||||
]
|
||||
assert runtime_warnings == []
|
||||
|
||||
def test_dead_loop_closes_passed_coroutine(self):
|
||||
"""If loop is None, a passed coroutine (not factory) is closed."""
|
||||
import gc
|
||||
import warnings
|
||||
import tools.mcp_tool as mcp
|
||||
|
||||
async def _sample():
|
||||
return "ok"
|
||||
|
||||
coro = _sample()
|
||||
with patch.object(mcp, "_mcp_loop", None):
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
with pytest.raises(RuntimeError, match="not running"):
|
||||
mcp._run_on_mcp_loop(coro)
|
||||
gc.collect()
|
||||
|
||||
assert coro.cr_frame is None
|
||||
runtime_warnings = [
|
||||
w for w in caught
|
||||
if issubclass(w.category, RuntimeWarning)
|
||||
and "was never awaited" in str(w.message)
|
||||
and "_sample" in str(w.message)
|
||||
]
|
||||
assert runtime_warnings == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -406,7 +477,8 @@ class TestToolHandler:
|
|||
|
||||
def _patch_mcp_loop(self, coro_side_effect=None):
|
||||
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
|
||||
def fake_run(coro, timeout=30):
|
||||
def fake_run(coro_or_factory, timeout=30):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
return asyncio.run(coro)
|
||||
if coro_side_effect:
|
||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect)
|
||||
|
|
@ -485,7 +557,8 @@ class TestToolHandler:
|
|||
|
||||
try:
|
||||
handler = _make_tool_handler("test_srv", "greet", 120)
|
||||
def _interrupting_run(coro, timeout=30):
|
||||
def _interrupting_run(coro_or_factory, timeout=30):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
coro.close()
|
||||
raise InterruptedError("User sent a new message")
|
||||
with patch(
|
||||
|
|
@ -1792,7 +1865,8 @@ class TestUtilityHandlers:
|
|||
|
||||
def _patch_mcp_loop(self):
|
||||
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
|
||||
def fake_run(coro, timeout=30):
|
||||
def fake_run(coro_or_factory, timeout=30):
|
||||
coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory
|
||||
return asyncio.run(coro)
|
||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
|
||||
|
||||
|
|
@ -3688,3 +3762,135 @@ class TestRegisterMcpServers:
|
|||
)
|
||||
|
||||
_servers.pop("srv", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for parallel tool call support (port from openai/codex#17667)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMcpParallelToolCalls:
|
||||
"""Tests for the supports_parallel_tool_calls config option."""
|
||||
|
||||
def test_is_mcp_tool_parallel_safe_non_mcp_tool(self):
|
||||
"""Non-MCP tool names always return False."""
|
||||
from tools.mcp_tool import is_mcp_tool_parallel_safe
|
||||
assert is_mcp_tool_parallel_safe("web_search") is False
|
||||
assert is_mcp_tool_parallel_safe("read_file") is False
|
||||
assert is_mcp_tool_parallel_safe("terminal") is False
|
||||
assert is_mcp_tool_parallel_safe("") is False
|
||||
|
||||
def test_is_mcp_tool_parallel_safe_no_servers(self):
|
||||
"""MCP tool from unknown server returns False."""
|
||||
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
|
||||
with _lock:
|
||||
_parallel_safe_servers.clear()
|
||||
assert is_mcp_tool_parallel_safe("mcp_docs_search") is False
|
||||
|
||||
def test_is_mcp_tool_parallel_safe_with_flag(self):
|
||||
"""MCP tool from a parallel-safe server returns True."""
|
||||
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
|
||||
with _lock:
|
||||
_parallel_safe_servers.add("docs")
|
||||
try:
|
||||
assert is_mcp_tool_parallel_safe("mcp_docs_search") is True
|
||||
assert is_mcp_tool_parallel_safe("mcp_docs_read_file") is True
|
||||
# Different server should be False
|
||||
assert is_mcp_tool_parallel_safe("mcp_github_list_repos") is False
|
||||
finally:
|
||||
with _lock:
|
||||
_parallel_safe_servers.discard("docs")
|
||||
|
||||
def test_is_mcp_tool_parallel_safe_server_with_underscores(self):
|
||||
"""Server names containing underscores are correctly matched."""
|
||||
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
|
||||
with _lock:
|
||||
_parallel_safe_servers.add("my_server")
|
||||
try:
|
||||
assert is_mcp_tool_parallel_safe("mcp_my_server_query") is True
|
||||
finally:
|
||||
with _lock:
|
||||
_parallel_safe_servers.discard("my_server")
|
||||
|
||||
def test_is_mcp_tool_parallel_safe_no_tool_suffix(self):
|
||||
"""Tool name that is just 'mcp_{server}' without a tool part returns False."""
|
||||
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
|
||||
with _lock:
|
||||
_parallel_safe_servers.add("docs")
|
||||
try:
|
||||
# "mcp_docs" has no tool part after the server name
|
||||
assert is_mcp_tool_parallel_safe("mcp_docs") is False
|
||||
# "mcp_docs_" has empty tool part
|
||||
assert is_mcp_tool_parallel_safe("mcp_docs_") is False
|
||||
finally:
|
||||
with _lock:
|
||||
_parallel_safe_servers.discard("docs")
|
||||
|
||||
def test_register_mcp_servers_tracks_parallel_flag(self):
|
||||
"""register_mcp_servers populates _parallel_safe_servers from config."""
|
||||
from tools.mcp_tool import (
|
||||
register_mcp_servers, _parallel_safe_servers, _lock,
|
||||
sanitize_mcp_name_component,
|
||||
)
|
||||
fake_config = {
|
||||
"parallel_srv": {
|
||||
"command": "echo",
|
||||
"supports_parallel_tool_calls": True,
|
||||
},
|
||||
"serial_srv": {
|
||||
"command": "echo",
|
||||
"supports_parallel_tool_calls": False,
|
||||
},
|
||||
"default_srv": {
|
||||
"command": "echo",
|
||||
# no supports_parallel_tool_calls key
|
||||
},
|
||||
}
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._ensure_mcp_loop"), \
|
||||
patch("tools.mcp_tool._run_on_mcp_loop"), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=[]):
|
||||
register_mcp_servers(fake_config)
|
||||
|
||||
with _lock:
|
||||
assert sanitize_mcp_name_component("parallel_srv") in _parallel_safe_servers
|
||||
assert sanitize_mcp_name_component("serial_srv") not in _parallel_safe_servers
|
||||
assert sanitize_mcp_name_component("default_srv") not in _parallel_safe_servers
|
||||
# Cleanup
|
||||
_parallel_safe_servers.discard(sanitize_mcp_name_component("parallel_srv"))
|
||||
|
||||
def test_register_mcp_servers_removes_parallel_flag_on_toggle(self):
|
||||
"""Toggling supports_parallel_tool_calls to false removes server from the set."""
|
||||
from tools.mcp_tool import (
|
||||
register_mcp_servers, _parallel_safe_servers, _lock,
|
||||
sanitize_mcp_name_component,
|
||||
)
|
||||
|
||||
# First registration: parallel enabled
|
||||
config_on = {
|
||||
"toggle_srv": {
|
||||
"command": "echo",
|
||||
"supports_parallel_tool_calls": True,
|
||||
},
|
||||
}
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._ensure_mcp_loop"), \
|
||||
patch("tools.mcp_tool._run_on_mcp_loop"), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=[]):
|
||||
register_mcp_servers(config_on)
|
||||
with _lock:
|
||||
assert sanitize_mcp_name_component("toggle_srv") in _parallel_safe_servers
|
||||
|
||||
# Second registration: parallel disabled
|
||||
config_off = {
|
||||
"toggle_srv": {
|
||||
"command": "echo",
|
||||
"supports_parallel_tool_calls": False,
|
||||
},
|
||||
}
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._ensure_mcp_loop"), \
|
||||
patch("tools.mcp_tool._run_on_mcp_loop"), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=[]):
|
||||
register_mcp_servers(config_off)
|
||||
with _lock:
|
||||
assert sanitize_mcp_name_component("toggle_srv") not in _parallel_safe_servers
|
||||
|
|
|
|||
|
|
@ -865,3 +865,138 @@ class TestProcessToolHandler:
|
|||
from tools.process_registry import _handle_process
|
||||
result = json.loads(_handle_process({"action": "unknown_action"}))
|
||||
assert "error" in result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# format_process_notification + drain_notifications (shared helpers)
|
||||
# =========================================================================
|
||||
|
||||
from tools.process_registry import format_process_notification
|
||||
|
||||
|
||||
def test_format_completion_event():
|
||||
evt = {
|
||||
"type": "completion",
|
||||
"session_id": "proc_abc",
|
||||
"command": "sleep 5",
|
||||
"exit_code": 0,
|
||||
"output": "done",
|
||||
}
|
||||
result = format_process_notification(evt)
|
||||
assert "[IMPORTANT: Background process proc_abc completed" in result
|
||||
assert "exit code 0" in result
|
||||
assert "Command: sleep 5" in result
|
||||
assert "Output:\ndone]" in result
|
||||
|
||||
|
||||
def test_format_watch_match_event():
|
||||
evt = {
|
||||
"type": "watch_match",
|
||||
"session_id": "proc_xyz",
|
||||
"command": "tail -f log",
|
||||
"pattern": "ERROR",
|
||||
"output": "ERROR: disk full",
|
||||
"suppressed": 0,
|
||||
}
|
||||
result = format_process_notification(evt)
|
||||
assert 'watch pattern "ERROR"' in result
|
||||
assert "Matched output:\nERROR: disk full" in result
|
||||
|
||||
|
||||
def test_format_watch_match_with_suppressed():
|
||||
evt = {
|
||||
"type": "watch_match",
|
||||
"session_id": "proc_xyz",
|
||||
"command": "tail -f log",
|
||||
"pattern": "WARN",
|
||||
"output": "WARN: low mem",
|
||||
"suppressed": 3,
|
||||
}
|
||||
result = format_process_notification(evt)
|
||||
assert "3 earlier matches were suppressed" in result
|
||||
|
||||
|
||||
def test_format_watch_disabled_event():
|
||||
evt = {
|
||||
"type": "watch_disabled",
|
||||
"message": "Watch disabled for proc_xyz: too many matches",
|
||||
}
|
||||
result = format_process_notification(evt)
|
||||
assert "[IMPORTANT: Watch disabled for proc_xyz" in result
|
||||
|
||||
|
||||
def test_format_returns_none_for_empty_event():
|
||||
evt = {}
|
||||
result = format_process_notification(evt)
|
||||
assert result is not None
|
||||
assert "unknown" in result
|
||||
|
||||
|
||||
def test_drain_notifications_returns_pending_events():
|
||||
from tools.process_registry import process_registry
|
||||
|
||||
while not process_registry.completion_queue.empty():
|
||||
process_registry.completion_queue.get_nowait()
|
||||
|
||||
process_registry.completion_queue.put({
|
||||
"type": "completion",
|
||||
"session_id": "proc_drain1",
|
||||
"command": "echo hi",
|
||||
"exit_code": 0,
|
||||
"output": "hi",
|
||||
})
|
||||
process_registry.completion_queue.put({
|
||||
"type": "watch_match",
|
||||
"session_id": "proc_drain2",
|
||||
"command": "tail -f x",
|
||||
"pattern": "ERR",
|
||||
"output": "ERR found",
|
||||
"suppressed": 0,
|
||||
})
|
||||
|
||||
try:
|
||||
results = process_registry.drain_notifications()
|
||||
assert len(results) == 2
|
||||
assert results[0][0]["session_id"] == "proc_drain1"
|
||||
assert "proc_drain1 completed" in results[0][1]
|
||||
assert results[1][0]["session_id"] == "proc_drain2"
|
||||
assert "watch pattern" in results[1][1]
|
||||
finally:
|
||||
while not process_registry.completion_queue.empty():
|
||||
process_registry.completion_queue.get_nowait()
|
||||
process_registry._completion_consumed.discard("proc_drain1")
|
||||
process_registry._completion_consumed.discard("proc_drain2")
|
||||
|
||||
|
||||
def test_drain_notifications_skips_consumed():
|
||||
from tools.process_registry import process_registry
|
||||
|
||||
while not process_registry.completion_queue.empty():
|
||||
process_registry.completion_queue.get_nowait()
|
||||
|
||||
process_registry._completion_consumed.add("proc_consumed")
|
||||
process_registry.completion_queue.put({
|
||||
"type": "completion",
|
||||
"session_id": "proc_consumed",
|
||||
"command": "echo done",
|
||||
"exit_code": 0,
|
||||
"output": "done",
|
||||
})
|
||||
|
||||
try:
|
||||
results = process_registry.drain_notifications()
|
||||
assert len(results) == 0
|
||||
finally:
|
||||
process_registry._completion_consumed.discard("proc_consumed")
|
||||
while not process_registry.completion_queue.empty():
|
||||
process_registry.completion_queue.get_nowait()
|
||||
|
||||
|
||||
def test_drain_notifications_empty_queue():
|
||||
from tools.process_registry import process_registry
|
||||
|
||||
while not process_registry.completion_queue.empty():
|
||||
process_registry.completion_queue.get_nowait()
|
||||
|
||||
results = process_registry.drain_notifications()
|
||||
assert results == []
|
||||
|
|
|
|||
|
|
@ -333,6 +333,103 @@ class TestEnsureInstalled:
|
|||
_tirith_mod._resolved_path = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unsupported platform (Windows etc.) — silent fast-path everywhere
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUnsupportedPlatform:
|
||||
"""When _detect_target() returns None (no tirith binary for this OS+arch),
|
||||
the entire subsystem must stay silent: no PATH probes, no download thread,
|
||||
no disk failure marker, no spawn attempts, no CLI banner. Pattern-matching
|
||||
guards still cover the gap; tirith content scanning is just absent."""
|
||||
|
||||
def test_is_platform_supported_true_on_linux_x86_64(self):
|
||||
with patch("tools.tirith_security.platform.system", return_value="Linux"), \
|
||||
patch("tools.tirith_security.platform.machine", return_value="x86_64"):
|
||||
assert _tirith_mod.is_platform_supported() is True
|
||||
|
||||
def test_is_platform_supported_true_on_darwin_arm64(self):
|
||||
with patch("tools.tirith_security.platform.system", return_value="Darwin"), \
|
||||
patch("tools.tirith_security.platform.machine", return_value="arm64"):
|
||||
assert _tirith_mod.is_platform_supported() is True
|
||||
|
||||
def test_is_platform_supported_false_on_windows(self):
|
||||
with patch("tools.tirith_security.platform.system", return_value="Windows"), \
|
||||
patch("tools.tirith_security.platform.machine", return_value="AMD64"):
|
||||
assert _tirith_mod.is_platform_supported() is False
|
||||
|
||||
def test_is_platform_supported_false_on_unknown_arch(self):
|
||||
with patch("tools.tirith_security.platform.system", return_value="Linux"), \
|
||||
patch("tools.tirith_security.platform.machine", return_value="riscv64"):
|
||||
assert _tirith_mod.is_platform_supported() is False
|
||||
|
||||
@patch("tools.tirith_security._load_security_config")
|
||||
def test_ensure_installed_unsupported_returns_none_no_thread(self, mock_cfg):
|
||||
"""Windows: don't start a background install thread, don't write a
|
||||
failure marker — just cache the verdict and return None."""
|
||||
mock_cfg.return_value = {"tirith_enabled": True, "tirith_path": "tirith",
|
||||
"tirith_timeout": 5, "tirith_fail_open": True}
|
||||
_tirith_mod._resolved_path = None
|
||||
with patch("tools.tirith_security.is_platform_supported", return_value=False), \
|
||||
patch("tools.tirith_security.threading.Thread") as MockThread, \
|
||||
patch("tools.tirith_security._mark_install_failed") as mock_mark, \
|
||||
patch("tools.tirith_security.shutil.which") as mock_which:
|
||||
result = ensure_installed()
|
||||
assert result is None
|
||||
MockThread.assert_not_called()
|
||||
mock_mark.assert_not_called()
|
||||
mock_which.assert_not_called()
|
||||
assert _tirith_mod._resolved_path is _tirith_mod._INSTALL_FAILED
|
||||
assert _tirith_mod._install_failure_reason == "unsupported_platform"
|
||||
|
||||
@patch("tools.tirith_security._load_security_config")
|
||||
def test_check_command_security_unsupported_allows_silently(self, mock_cfg):
|
||||
"""Windows: skip the resolver and spawn entirely — return allow with
|
||||
an empty summary so callers can't accidentally surface 'tirith
|
||||
unavailable' messaging to the user."""
|
||||
mock_cfg.return_value = {"tirith_enabled": True, "tirith_path": "tirith",
|
||||
"tirith_timeout": 5, "tirith_fail_open": True}
|
||||
with patch("tools.tirith_security.is_platform_supported", return_value=False), \
|
||||
patch("tools.tirith_security.subprocess.run") as mock_run, \
|
||||
patch("tools.tirith_security._resolve_tirith_path") as mock_resolve:
|
||||
result = check_command_security("rm -rf /")
|
||||
assert result == {"action": "allow", "findings": [], "summary": ""}
|
||||
mock_run.assert_not_called()
|
||||
mock_resolve.assert_not_called()
|
||||
|
||||
@patch("tools.tirith_security._load_security_config")
|
||||
def test_resolve_path_unsupported_caches_failure_without_probing(self, mock_cfg):
|
||||
"""The per-command resolver must also short-circuit on Windows so
|
||||
long-running gateways don't churn through `shutil.which` and disk
|
||||
I/O for every scanned command."""
|
||||
mock_cfg.return_value = {"tirith_enabled": True, "tirith_path": "tirith",
|
||||
"tirith_timeout": 5, "tirith_fail_open": True}
|
||||
_tirith_mod._resolved_path = None
|
||||
with patch("tools.tirith_security.is_platform_supported", return_value=False), \
|
||||
patch("tools.tirith_security.shutil.which") as mock_which:
|
||||
result = _tirith_mod._resolve_tirith_path("tirith")
|
||||
assert result == "tirith"
|
||||
mock_which.assert_not_called()
|
||||
assert _tirith_mod._resolved_path is _tirith_mod._INSTALL_FAILED
|
||||
assert _tirith_mod._install_failure_reason == "unsupported_platform"
|
||||
|
||||
@patch("tools.tirith_security._load_security_config")
|
||||
def test_explicit_path_still_honored_on_unsupported_platform(self, mock_cfg):
|
||||
"""If a user explicitly configured a tirith_path (e.g. they built it
|
||||
themselves under WSL), the unsupported-platform short-circuit must
|
||||
NOT override that — explicit config wins."""
|
||||
mock_cfg.return_value = {"tirith_enabled": True,
|
||||
"tirith_path": "/opt/custom/tirith",
|
||||
"tirith_timeout": 5, "tirith_fail_open": True}
|
||||
_tirith_mod._resolved_path = None
|
||||
with patch("tools.tirith_security.is_platform_supported", return_value=False), \
|
||||
patch("os.path.isfile", return_value=True), \
|
||||
patch("os.access", return_value=True):
|
||||
result = _tirith_mod._resolve_tirith_path("/opt/custom/tirith")
|
||||
assert result == "/opt/custom/tirith"
|
||||
assert _tirith_mod._resolved_path == "/opt/custom/tirith"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Failed download caches the miss (Finding #1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -1007,3 +1104,120 @@ class TestHermesHomeIsolation:
|
|||
expected = os.path.join(os.path.expanduser("~"), ".hermes")
|
||||
result = _get_hermes_home()
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Warn-once dedupe (issue: tirith spawn failed spamming on Windows)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSpawnWarningDedup:
|
||||
"""When tirith isn't installed yet (background install in flight, or
|
||||
install marked failed), every terminal command spammed an identical
|
||||
``tirith spawn failed: [WinError 2]`` warning to ``errors.log``. The
|
||||
dedupe set in ``_warn_once`` collapses repeats by ``(exc class, errno)``
|
||||
while still surfacing the first occurrence so users see the failure.
|
||||
"""
|
||||
|
||||
@patch("tools.tirith_security.subprocess.run")
|
||||
@patch("tools.tirith_security._load_security_config")
|
||||
def test_repeated_spawn_failure_logs_once(self, mock_cfg, mock_run, caplog):
|
||||
mock_cfg.return_value = {
|
||||
"tirith_enabled": True, "tirith_path": "tirith",
|
||||
"tirith_timeout": 5, "tirith_fail_open": True,
|
||||
}
|
||||
mock_run.side_effect = FileNotFoundError("[WinError 2]")
|
||||
# Fresh dedupe state — clear any keys left by other tests.
|
||||
_tirith_mod._reset_spawn_warning_state()
|
||||
|
||||
with caplog.at_level("WARNING", logger="tools.tirith_security"):
|
||||
for _ in range(15):
|
||||
result = check_command_security("echo hi")
|
||||
# Behavior must remain the same on every call —
|
||||
# fail-open allow, with the exception captured in summary.
|
||||
assert result["action"] == "allow"
|
||||
assert "unavailable" in result["summary"]
|
||||
|
||||
spawn_warnings = [
|
||||
rec for rec in caplog.records
|
||||
if "tirith spawn failed" in rec.message
|
||||
]
|
||||
assert len(spawn_warnings) == 1, (
|
||||
f"expected exactly 1 spawn-failed warning across 15 commands, "
|
||||
f"got {len(spawn_warnings)}: {[r.message for r in spawn_warnings]}"
|
||||
)
|
||||
|
||||
@patch("tools.tirith_security.subprocess.run")
|
||||
@patch("tools.tirith_security._load_security_config")
|
||||
def test_distinct_exception_types_each_log_once(self, mock_cfg, mock_run, caplog):
|
||||
"""``FileNotFoundError`` and ``PermissionError`` are distinct
|
||||
failure modes and each deserves its own first-occurrence log
|
||||
line; the dedupe key includes the exception class."""
|
||||
mock_cfg.return_value = {
|
||||
"tirith_enabled": True, "tirith_path": "tirith",
|
||||
"tirith_timeout": 5, "tirith_fail_open": True,
|
||||
}
|
||||
_tirith_mod._reset_spawn_warning_state()
|
||||
|
||||
with caplog.at_level("WARNING", logger="tools.tirith_security"):
|
||||
mock_run.side_effect = FileNotFoundError("[WinError 2]")
|
||||
for _ in range(3):
|
||||
check_command_security("a")
|
||||
mock_run.side_effect = PermissionError("denied")
|
||||
for _ in range(3):
|
||||
check_command_security("b")
|
||||
|
||||
spawn_warnings = [
|
||||
rec for rec in caplog.records
|
||||
if "tirith spawn failed" in rec.message
|
||||
]
|
||||
assert len(spawn_warnings) == 2, (
|
||||
f"expected 2 distinct first-occurrence warnings, "
|
||||
f"got {len(spawn_warnings)}"
|
||||
)
|
||||
|
||||
@patch("tools.tirith_security.subprocess.run")
|
||||
@patch("tools.tirith_security._load_security_config")
|
||||
def test_repeated_timeout_logs_once(self, mock_cfg, mock_run, caplog):
|
||||
mock_cfg.return_value = {
|
||||
"tirith_enabled": True, "tirith_path": "tirith",
|
||||
"tirith_timeout": 5, "tirith_fail_open": True,
|
||||
}
|
||||
mock_run.side_effect = subprocess.TimeoutExpired(cmd="tirith", timeout=5)
|
||||
_tirith_mod._reset_spawn_warning_state()
|
||||
|
||||
with caplog.at_level("WARNING", logger="tools.tirith_security"):
|
||||
for _ in range(10):
|
||||
result = check_command_security("slow")
|
||||
assert result["action"] == "allow"
|
||||
|
||||
timeout_warnings = [
|
||||
rec for rec in caplog.records
|
||||
if "tirith timed out" in rec.message
|
||||
]
|
||||
assert len(timeout_warnings) == 1
|
||||
|
||||
@patch("tools.tirith_security._load_security_config")
|
||||
def test_path_none_logs_once(self, mock_cfg, caplog):
|
||||
"""``_resolve_tirith_path`` returning ``None`` (explicit path set
|
||||
but resolver returned None — unusual) should not spam the log
|
||||
either."""
|
||||
mock_cfg.return_value = {
|
||||
"tirith_enabled": True, "tirith_path": "tirith",
|
||||
"tirith_timeout": 5, "tirith_fail_open": True,
|
||||
}
|
||||
_tirith_mod._reset_spawn_warning_state()
|
||||
|
||||
with patch(
|
||||
"tools.tirith_security._resolve_tirith_path", return_value=None
|
||||
):
|
||||
with caplog.at_level("WARNING", logger="tools.tirith_security"):
|
||||
for _ in range(10):
|
||||
result = check_command_security("echo")
|
||||
assert result["action"] == "allow"
|
||||
assert "tirith path unavailable" in result["summary"]
|
||||
|
||||
none_warnings = [
|
||||
rec for rec in caplog.records
|
||||
if "tirith path resolved to None" in rec.message
|
||||
]
|
||||
assert len(none_warnings) == 1
|
||||
|
|
|
|||
|
|
@ -170,7 +170,15 @@ class TestTranscribeCallSitesReadDotenv:
|
|||
assert seen_keys == ["mistral-dotenv-key"]
|
||||
|
||||
def test_transcribe_xai_forwards_dotenv_key(self):
|
||||
"""xAI STT now resolves credentials through ``tools.xai_http`` so the
|
||||
OAuth bearer wins when present and ``XAI_API_KEY`` is the fallback.
|
||||
Patch the resolver's ``get_env_value`` to simulate a dotenv-only key
|
||||
and confirm it reaches the HTTP call. The per-call-site
|
||||
``transcription_tools.get_env_value`` is still consulted for the
|
||||
``XAI_STT_BASE_URL`` override (covered by ``test_custom_base_url``).
|
||||
"""
|
||||
from tools import transcription_tools as tt
|
||||
from tools import xai_http
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
|
|
@ -183,15 +191,12 @@ class TestTranscribeCallSitesReadDotenv:
|
|||
response.json.return_value = {"text": "hello"}
|
||||
return response
|
||||
|
||||
# get_env_value is consulted for both XAI_API_KEY and XAI_STT_BASE_URL.
|
||||
# Return the key for the first call, None for base-url override
|
||||
# (so it defaults to the module-level XAI_STT_BASE_URL).
|
||||
def fake_get_env_value(name, default=None):
|
||||
if name == "XAI_API_KEY":
|
||||
return "xai-dotenv-key"
|
||||
return None
|
||||
|
||||
with patch.object(tt, "get_env_value", side_effect=fake_get_env_value), \
|
||||
with patch.object(xai_http, "get_env_value", side_effect=fake_get_env_value), \
|
||||
patch("requests.post", side_effect=fake_post), \
|
||||
patch("builtins.open", MagicMock()):
|
||||
result = tt._transcribe_xai("/tmp/fake.mp3", "grok-stt")
|
||||
|
|
|
|||
|
|
@ -57,7 +57,12 @@ class TestDotenvFallbackPerProvider:
|
|||
mock_import.return_value.assert_called_once_with(api_key="el-dotenv-key")
|
||||
|
||||
def test_xai_reads_dotenv_key(self, tmp_path):
|
||||
"""xAI TTS now resolves credentials through ``tools.xai_http``; the
|
||||
dotenv fallback contract from #17140 is preserved by patching the
|
||||
resolver's ``get_env_value`` rather than ``tts_tool.get_env_value``.
|
||||
"""
|
||||
from tools import tts_tool
|
||||
from tools import xai_http
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
|
|
@ -69,7 +74,7 @@ class TestDotenvFallbackPerProvider:
|
|||
response.raise_for_status = MagicMock()
|
||||
return response
|
||||
|
||||
with patch.object(tts_tool, "get_env_value", return_value="xai-dotenv-key"), \
|
||||
with patch.object(xai_http, "get_env_value", return_value="xai-dotenv-key"), \
|
||||
patch("requests.post", side_effect=fake_post):
|
||||
tts_tool._generate_xai_tts("hi", str(tmp_path / "out.mp3"), {})
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,14 @@ class TestIsSafeUrl:
|
|||
]):
|
||||
assert is_safe_url("https://example.com/image.png") is True
|
||||
|
||||
def test_ftp_scheme_blocked(self):
|
||||
"""Only http/https should be allowed for fetch tools."""
|
||||
assert is_safe_url("ftp://example.com/file.txt") is False
|
||||
|
||||
def test_missing_scheme_blocked(self):
|
||||
"""Bare host/path should be rejected to avoid ambiguous handling."""
|
||||
assert is_safe_url("example.com/path") is False
|
||||
|
||||
def test_localhost_blocked(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("127.0.0.1", 0)),
|
||||
|
|
|
|||
438
tests/tools/test_x_search_tool.py
Normal file
438
tests/tools/test_x_search_tool.py
Normal file
|
|
@ -0,0 +1,438 @@
|
|||
"""Tests for the X (Twitter) Search tool backed by xAI Responses API.
|
||||
|
||||
Covers:
|
||||
- HTTP request shape (URL, headers, payload, model from config)
|
||||
- Handle filter validation (allowed vs excluded mutual exclusion)
|
||||
- Inline url_citation extraction from message annotations
|
||||
- Structured error handling (4xx with code, 5xx retry, ReadTimeout retry)
|
||||
- Credential resolution: API key path, OAuth path, both-set preference, none-set
|
||||
- check_x_search_requirements gating in registry
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, payload, *, status_code=200, text=None):
|
||||
self._payload = payload
|
||||
self.status_code = status_code
|
||||
self.text = text if text is not None else json.dumps(payload)
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code >= 400:
|
||||
err = requests.HTTPError(f"{self.status_code} Client Error")
|
||||
err.response = self
|
||||
raise err
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Original PR #10786 test coverage (HTTP shape, handle validation, citations,
|
||||
# retry behavior) — preserved verbatim. Uses XAI_API_KEY env var via the
|
||||
# default resolver path.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_x_search_posts_responses_request(monkeypatch):
|
||||
from tools.x_search_tool import x_search_tool
|
||||
from hermes_cli import __version__
|
||||
|
||||
captured = {}
|
||||
|
||||
def _fake_post(url, headers=None, json=None, timeout=None):
|
||||
captured["url"] = url
|
||||
captured["headers"] = headers
|
||||
captured["json"] = json
|
||||
captured["timeout"] = timeout
|
||||
return _FakeResponse(
|
||||
{
|
||||
"output_text": "People on X are discussing xAI's latest launch.",
|
||||
"citations": [{"url": "https://x.com/example/status/1", "title": "Example post"}],
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
monkeypatch.setattr("requests.post", _fake_post)
|
||||
|
||||
result = json.loads(
|
||||
x_search_tool(
|
||||
query="What are people saying about xAI on X?",
|
||||
allowed_x_handles=["xai", "@grok"],
|
||||
from_date="2026-04-01",
|
||||
to_date="2026-04-10",
|
||||
enable_image_understanding=True,
|
||||
)
|
||||
)
|
||||
|
||||
tool_def = captured["json"]["tools"][0]
|
||||
assert captured["url"] == "https://api.x.ai/v1/responses"
|
||||
assert captured["headers"]["User-Agent"] == f"Hermes-Agent/{__version__}"
|
||||
assert captured["json"]["model"] == "grok-4.20-reasoning"
|
||||
assert captured["json"]["store"] is False
|
||||
assert tool_def["type"] == "x_search"
|
||||
assert tool_def["allowed_x_handles"] == ["xai", "grok"]
|
||||
assert tool_def["from_date"] == "2026-04-01"
|
||||
assert tool_def["to_date"] == "2026-04-10"
|
||||
assert tool_def["enable_image_understanding"] is True
|
||||
assert result["success"] is True
|
||||
assert result["answer"] == "People on X are discussing xAI's latest launch."
|
||||
|
||||
|
||||
def test_x_search_rejects_conflicting_handle_filters(monkeypatch):
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
|
||||
result = json.loads(
|
||||
x_search_tool(
|
||||
query="latest xAI discussion",
|
||||
allowed_x_handles=["xai"],
|
||||
excluded_x_handles=["grok"],
|
||||
)
|
||||
)
|
||||
|
||||
assert result["error"] == "allowed_x_handles and excluded_x_handles cannot be used together"
|
||||
|
||||
|
||||
def test_x_search_extracts_inline_url_citations(monkeypatch):
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
def _fake_post(url, headers=None, json=None, timeout=None):
|
||||
return _FakeResponse(
|
||||
{
|
||||
"output": [
|
||||
{
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": "xAI posted an update on X.",
|
||||
"annotations": [
|
||||
{
|
||||
"type": "url_citation",
|
||||
"url": "https://x.com/xai/status/123",
|
||||
"title": "xAI update",
|
||||
"start_index": 0,
|
||||
"end_index": 3,
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
monkeypatch.setattr("requests.post", _fake_post)
|
||||
|
||||
result = json.loads(x_search_tool(query="latest post from xai"))
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["answer"] == "xAI posted an update on X."
|
||||
assert result["inline_citations"] == [
|
||||
{
|
||||
"url": "https://x.com/xai/status/123",
|
||||
"title": "xAI update",
|
||||
"start_index": 0,
|
||||
"end_index": 3,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_x_search_returns_structured_http_error(monkeypatch):
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
class _FailingResponse:
|
||||
status_code = 403
|
||||
text = '{"code":"forbidden","error":"x_search is not enabled for this model"}'
|
||||
|
||||
def json(self):
|
||||
return {
|
||||
"code": "forbidden",
|
||||
"error": "x_search is not enabled for this model",
|
||||
}
|
||||
|
||||
def raise_for_status(self):
|
||||
err = requests.HTTPError("403 Client Error: Forbidden")
|
||||
err.response = self
|
||||
raise err
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
monkeypatch.setattr("requests.post", lambda *a, **k: _FailingResponse())
|
||||
|
||||
result = json.loads(x_search_tool(query="latest xai discussion"))
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["provider"] == "xai"
|
||||
assert result["tool"] == "x_search"
|
||||
assert result["error_type"] == "HTTPError"
|
||||
assert result["error"] == "forbidden: x_search is not enabled for this model"
|
||||
|
||||
|
||||
def test_x_search_retries_read_timeout_then_succeeds(monkeypatch):
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
calls = {"count": 0}
|
||||
|
||||
def _fake_post(url, headers=None, json=None, timeout=None):
|
||||
calls["count"] += 1
|
||||
if calls["count"] == 1:
|
||||
raise requests.ReadTimeout("timed out")
|
||||
return _FakeResponse(
|
||||
{
|
||||
"output_text": "Recovered after retry.",
|
||||
"citations": [],
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
monkeypatch.setattr("requests.post", _fake_post)
|
||||
monkeypatch.setattr("tools.x_search_tool.time.sleep", lambda *_: None)
|
||||
|
||||
result = json.loads(x_search_tool(query="grok xai"))
|
||||
|
||||
assert calls["count"] == 2
|
||||
assert result["success"] is True
|
||||
assert result["answer"] == "Recovered after retry."
|
||||
|
||||
|
||||
def test_x_search_retries_5xx_then_succeeds(monkeypatch):
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
calls = {"count": 0}
|
||||
|
||||
def _fake_post(url, headers=None, json=None, timeout=None):
|
||||
calls["count"] += 1
|
||||
if calls["count"] == 1:
|
||||
return _FakeResponse(
|
||||
{"code": "Internal error", "error": "Service temporarily unavailable."},
|
||||
status_code=500,
|
||||
)
|
||||
return _FakeResponse({"output_text": "Recovered after 5xx retry."})
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
monkeypatch.setattr("requests.post", _fake_post)
|
||||
monkeypatch.setattr("tools.x_search_tool.time.sleep", lambda *_: None)
|
||||
|
||||
result = json.loads(x_search_tool(query="grok xai"))
|
||||
|
||||
assert calls["count"] == 2
|
||||
assert result["success"] is True
|
||||
assert result["answer"] == "Recovered after 5xx retry."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Credential-resolution coverage — the OAuth-or-API-key gating contract.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _no_xai_env(monkeypatch):
|
||||
"""Strip any XAI_* env vars so the resolver doesn't see a leaked dev key."""
|
||||
for var in ("XAI_API_KEY", "XAI_BASE_URL", "HERMES_XAI_BASE_URL"):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
def test_x_search_uses_xai_oauth_when_only_oauth_available(monkeypatch):
|
||||
"""OAuth-only user: credential_source should be ``xai-oauth``."""
|
||||
from tools.registry import invalidate_check_fn_cache
|
||||
from tools.x_search_tool import check_x_search_requirements, x_search_tool
|
||||
|
||||
_no_xai_env(monkeypatch)
|
||||
|
||||
def _fake_resolve():
|
||||
return {
|
||||
"provider": "xai-oauth",
|
||||
"api_key": "oauth-bearer-token",
|
||||
"base_url": "https://api.x.ai/v1",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tools.x_search_tool.resolve_xai_http_credentials", _fake_resolve
|
||||
)
|
||||
invalidate_check_fn_cache()
|
||||
|
||||
assert check_x_search_requirements() is True
|
||||
|
||||
captured = {}
|
||||
|
||||
def _fake_post(url, headers=None, json=None, timeout=None):
|
||||
captured["headers"] = headers
|
||||
return _FakeResponse({"output_text": "Found posts via OAuth."})
|
||||
|
||||
monkeypatch.setattr("requests.post", _fake_post)
|
||||
|
||||
result = json.loads(x_search_tool(query="anything about xai"))
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["credential_source"] == "xai-oauth"
|
||||
assert captured["headers"]["Authorization"] == "Bearer oauth-bearer-token"
|
||||
|
||||
|
||||
def test_x_search_uses_api_key_when_only_xai_api_key_set(monkeypatch):
|
||||
"""API-key-only user: credential_source should be ``xai``."""
|
||||
from tools.registry import invalidate_check_fn_cache
|
||||
from tools.x_search_tool import check_x_search_requirements, x_search_tool
|
||||
|
||||
_no_xai_env(monkeypatch)
|
||||
|
||||
def _fake_resolve():
|
||||
# Real ``resolve_xai_http_credentials`` returns ``"xai"`` when it
|
||||
# falls through to the XAI_API_KEY env var path.
|
||||
return {
|
||||
"provider": "xai",
|
||||
"api_key": "raw-api-key",
|
||||
"base_url": "https://api.x.ai/v1",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tools.x_search_tool.resolve_xai_http_credentials", _fake_resolve
|
||||
)
|
||||
invalidate_check_fn_cache()
|
||||
|
||||
assert check_x_search_requirements() is True
|
||||
|
||||
captured = {}
|
||||
|
||||
def _fake_post(url, headers=None, json=None, timeout=None):
|
||||
captured["headers"] = headers
|
||||
return _FakeResponse({"output_text": "Found posts via API key."})
|
||||
|
||||
monkeypatch.setattr("requests.post", _fake_post)
|
||||
|
||||
result = json.loads(x_search_tool(query="anything"))
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["credential_source"] == "xai"
|
||||
assert captured["headers"]["Authorization"] == "Bearer raw-api-key"
|
||||
|
||||
|
||||
def test_x_search_prefers_oauth_when_both_available(monkeypatch):
|
||||
"""Both credentials present: OAuth wins (matches Teknium's billing preference).
|
||||
|
||||
The real ordering is implemented in ``tools.xai_http.resolve_xai_http_credentials``
|
||||
— OAuth runtime first, fallback OAuth resolver second, ``XAI_API_KEY`` third.
|
||||
This test exercises the contract by having the resolver return the OAuth
|
||||
bearer (the ``xai-oauth`` ``provider`` tag is the marker).
|
||||
"""
|
||||
from tools.registry import invalidate_check_fn_cache
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "raw-api-key")
|
||||
|
||||
# Mimic xai_http's preference: OAuth wins, so we return the OAuth tuple
|
||||
# even though XAI_API_KEY is also set.
|
||||
def _fake_resolve():
|
||||
return {
|
||||
"provider": "xai-oauth",
|
||||
"api_key": "oauth-bearer-token",
|
||||
"base_url": "https://api.x.ai/v1",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tools.x_search_tool.resolve_xai_http_credentials", _fake_resolve
|
||||
)
|
||||
invalidate_check_fn_cache()
|
||||
|
||||
captured = {}
|
||||
|
||||
def _fake_post(url, headers=None, json=None, timeout=None):
|
||||
captured["headers"] = headers
|
||||
return _FakeResponse({"output_text": "OAuth preferred."})
|
||||
|
||||
monkeypatch.setattr("requests.post", _fake_post)
|
||||
|
||||
result = json.loads(x_search_tool(query="anything"))
|
||||
|
||||
assert result["credential_source"] == "xai-oauth"
|
||||
assert captured["headers"]["Authorization"] == "Bearer oauth-bearer-token"
|
||||
|
||||
|
||||
def test_x_search_returns_tool_error_when_no_credentials(monkeypatch):
|
||||
"""No credentials anywhere: tool returns a clear error, not a 401 from xAI."""
|
||||
from tools.registry import invalidate_check_fn_cache
|
||||
from tools.x_search_tool import check_x_search_requirements, x_search_tool
|
||||
|
||||
_no_xai_env(monkeypatch)
|
||||
|
||||
def _fake_resolve():
|
||||
return {
|
||||
"provider": "xai",
|
||||
"api_key": "",
|
||||
"base_url": "https://api.x.ai/v1",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tools.x_search_tool.resolve_xai_http_credentials", _fake_resolve
|
||||
)
|
||||
invalidate_check_fn_cache()
|
||||
|
||||
assert check_x_search_requirements() is False
|
||||
|
||||
# If a model somehow invokes the tool despite a False check_fn, the call
|
||||
# surfaces a friendly error rather than an HTTP exception.
|
||||
result = x_search_tool(query="anything")
|
||||
assert "No xAI credentials available" in result
|
||||
assert "hermes auth add xai-oauth" in result
|
||||
|
||||
|
||||
def test_x_search_check_fn_false_when_resolver_raises(monkeypatch):
|
||||
"""Resolver exceptions (e.g. expired token + failed refresh) gate the tool out."""
|
||||
from tools.registry import invalidate_check_fn_cache
|
||||
from tools.x_search_tool import check_x_search_requirements
|
||||
|
||||
_no_xai_env(monkeypatch)
|
||||
|
||||
def _boom():
|
||||
raise RuntimeError("token revoked and refresh failed")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tools.x_search_tool.resolve_xai_http_credentials", _boom
|
||||
)
|
||||
invalidate_check_fn_cache()
|
||||
|
||||
assert check_x_search_requirements() is False
|
||||
|
||||
|
||||
def test_x_search_honors_config_model_and_timeout(monkeypatch, tmp_path):
|
||||
"""``x_search.model`` and ``x_search.timeout_seconds`` override the defaults."""
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
|
||||
# Patch the in-module config loader so tests don't touch ~/.hermes/config.yaml.
|
||||
monkeypatch.setattr(
|
||||
"tools.x_search_tool._load_x_search_config",
|
||||
lambda: {"model": "grok-custom-test", "timeout_seconds": 45, "retries": 0},
|
||||
)
|
||||
|
||||
captured = {}
|
||||
|
||||
def _fake_post(url, headers=None, json=None, timeout=None):
|
||||
captured["model"] = json["model"]
|
||||
captured["timeout"] = timeout
|
||||
return _FakeResponse({"output_text": "Custom model OK."})
|
||||
|
||||
monkeypatch.setattr("requests.post", _fake_post)
|
||||
|
||||
result = json.loads(x_search_tool(query="anything"))
|
||||
|
||||
assert result["success"] is True
|
||||
assert captured["model"] == "grok-custom-test"
|
||||
assert captured["timeout"] == 45
|
||||
|
||||
|
||||
def test_x_search_registered_in_registry_with_check_fn():
|
||||
"""The tool is registered under the x_search toolset with the gating check_fn."""
|
||||
import tools.x_search_tool # noqa: F401 — ensures registration runs
|
||||
from tools.registry import registry
|
||||
|
||||
entry = registry.get_entry("x_search")
|
||||
assert entry is not None
|
||||
assert entry.toolset == "x_search"
|
||||
assert entry.check_fn is not None
|
||||
assert entry.check_fn.__name__ == "check_x_search_requirements"
|
||||
assert "XAI_API_KEY" in entry.requires_env
|
||||
assert entry.emoji == "🐦"
|
||||
Loading…
Add table
Add a link
Reference in a new issue