Merge branch 'main' into fix/tui-provider-resolution

This commit is contained in:
Kaio 2026-04-22 11:47:49 -07:00 committed by GitHub
commit ec374c0599
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
625 changed files with 68938 additions and 11055 deletions

View file

@ -0,0 +1,170 @@
"""Tests for GHSA-96vc-wcxf-jjff and GHSA-qg5c-hvr5-hjgr.
Two related ACP approval-flow issues:
- 96vc: ACP didn't set HERMES_EXEC_ASK, so `check_all_command_guards`
took the non-interactive auto-approve path and never consulted the
ACP-supplied callback.
- qg5c: `_approval_callback` was a module-global in terminal_tool;
overlapping ACP sessions overwrote each other's callback slot.
Both fixed together by:
1. Setting HERMES_EXEC_ASK inside _run_agent (wraps the agent call).
2. Storing the callback in thread-local state so concurrent executor
threads don't collide.
"""
import os
import threading
from unittest.mock import MagicMock
import pytest
class TestThreadLocalApprovalCallback:
    """GHSA-qg5c-hvr5-hjgr: approval callbacks must live in thread-local
    storage so overlapping ACP sessions cannot clobber each other."""

    def test_set_and_get_in_same_thread(self):
        from tools.terminal_tool import (
            set_approval_callback,
            _get_approval_callback,
        )

        def cb1(cmd, desc):
            return "once"

        set_approval_callback(cb1)
        # Round-trip within a single thread returns the exact same object.
        assert _get_approval_callback() is cb1

    def test_callback_not_visible_in_different_thread(self):
        """A callback registered in thread A must never surface in thread B."""
        import time

        from tools.terminal_tool import (
            set_approval_callback,
            _get_approval_callback,
        )

        def cb_a(cmd, desc):
            return "thread_a"

        def cb_b(cmd, desc):
            return "thread_b"

        observed = {"a": [], "b": []}

        def register_then_read(cb, bucket):
            set_approval_callback(cb)
            # Sleep so the sibling thread gets a chance to register its own
            # callback before we read ours back.
            time.sleep(0.05)
            observed[bucket].append(_get_approval_callback())

        workers = [
            threading.Thread(target=register_then_read, args=(cb_a, "a")),
            threading.Thread(target=register_then_read, args=(cb_b, "b")),
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        # Each thread reads back only what it registered itself.
        assert observed["a"] == [cb_a]
        assert observed["b"] == [cb_b]

    def test_main_thread_callback_not_leaked_to_worker(self):
        """Registering on the main thread leaves a fresh worker's slot empty."""
        from tools.terminal_tool import (
            set_approval_callback,
            _get_approval_callback,
        )

        def cb_main(cmd, desc):
            return "main"

        set_approval_callback(cb_main)

        worker_saw = []
        worker = threading.Thread(
            target=lambda: worker_saw.append(_get_approval_callback())
        )
        worker.start()
        worker.join()

        # The worker's thread-local slot was never populated.
        assert worker_saw == [None]
        # ... and the main thread's registration survives untouched.
        assert _get_approval_callback() is cb_main

    def test_sudo_password_callback_also_thread_local(self):
        """The sudo-password callback gets the same per-thread isolation."""
        from tools.terminal_tool import (
            set_sudo_password_callback,
            _get_sudo_password_callback,
        )

        def cb_main():
            return "main-password"

        set_sudo_password_callback(cb_main)

        worker_saw = []
        worker = threading.Thread(
            target=lambda: worker_saw.append(_get_sudo_password_callback())
        )
        worker.start()
        worker.join()

        assert worker_saw == [None]
        assert _get_sudo_password_callback() is cb_main
class TestAcpExecAskGate:
    """GHSA-96vc-wcxf-jjff: ACP's _run_agent must set HERMES_INTERACTIVE so
    that tools.approval.check_all_command_guards takes the CLI-interactive
    path (consulting the registered callback via prompt_dangerous_approval)
    rather than the non-interactive auto-approve shortcut.

    (HERMES_EXEC_ASK instead takes the gateway-queue path, which needs a
    notify_cb registered in _gateway_notify_cbs — not applicable to ACP's
    direct-callback shape.)
    """

    def test_interactive_env_var_routes_to_callback(self, monkeypatch):
        """With HERMES_INTERACTIVE set and a callback registered, a dangerous
        command must be routed through that callback."""
        # Start from a clean slate: clear every gating env var.
        for env_var in (
            "HERMES_INTERACTIVE",
            "HERMES_GATEWAY_SESSION",
            "HERMES_EXEC_ASK",
            "HERMES_YOLO_MODE",
        ):
            monkeypatch.delenv(env_var, raising=False)

        from tools.approval import check_all_command_guards

        called_with = []

        def fake_cb(command, description, *, allow_permanent=True):
            called_with.append((command, description))
            return "once"

        # Phase 1 — no HERMES_INTERACTIVE: auto-approve, callback untouched.
        result = check_all_command_guards(
            "rm -rf /tmp/test-exec-ask", "local", approval_callback=fake_cb,
        )
        assert result["approved"] is True
        assert called_with == [], (
            "without HERMES_INTERACTIVE the non-interactive auto-approve "
            "path should fire without consulting the callback"
        )

        # Phase 2 — HERMES_INTERACTIVE set: approval flows through the callback.
        monkeypatch.setenv("HERMES_INTERACTIVE", "1")
        called_with.clear()
        result = check_all_command_guards(
            "rm -rf /tmp/test-exec-ask", "local", approval_callback=fake_cb,
        )
        assert called_with, (
            "with HERMES_INTERACTIVE the approval path should consult the "
            "registered callback — this was the ACP bypass in "
            "GHSA-96vc-wcxf-jjff"
        )
        assert result["approved"] is True

View file

@ -73,3 +73,17 @@ class TestApprovalMapping:
result = cb("rm -rf /", "dangerous")
assert result == "deny"
def test_approval_none_response_returns_deny(self):
    """A request_permission future that resolves to None maps to 'deny'."""
    loop = MagicMock(spec=asyncio.AbstractEventLoop)
    mock_rp = MagicMock(name="request_permission")
    pending = MagicMock(spec=Future)
    pending.result.return_value = None
    patch_target = "acp_adapter.permissions.asyncio.run_coroutine_threadsafe"
    with patch(patch_target, return_value=pending):
        cb = make_approval_callback(mock_rp, loop, session_id="s1", timeout=1.0)
        outcome = cb("echo hi", "demo")
    # None is treated as "no decision" and fails closed.
    assert outcome == "deny"

View file

@ -0,0 +1,210 @@
"""Tests for acp_adapter.entry._BenignProbeMethodFilter.
Covers both the isolated filter logic and the full end-to-end path where a
client sends a bare JSON-RPC ``ping`` request over stdio and the acp runtime
surfaces the resulting ``RequestError`` via ``logging.exception("Background
task failed", ...)``.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
from io import StringIO
import pytest
from acp.exceptions import RequestError
from acp_adapter.entry import _BenignProbeMethodFilter
# -- Unit tests on the filter itself ----------------------------------------
def _make_record(msg: str, exc: BaseException | None) -> logging.LogRecord:
record = logging.LogRecord(
name="root",
level=logging.ERROR,
pathname=__file__,
lineno=0,
msg=msg,
args=(),
exc_info=(type(exc), exc, exc.__traceback__) if exc else None,
)
return record
def _bake_tb(exc: BaseException) -> BaseException:
try:
raise exc
except BaseException as e: # noqa: BLE001
return e
@pytest.mark.parametrize("method", ["ping", "health", "healthcheck"])
def test_filter_suppresses_benign_probe(method: str) -> None:
    """Known health-probe method names must be muted by the filter."""
    probe_error = _bake_tb(RequestError.method_not_found(method))
    record = _make_record("Background task failed", probe_error)
    assert _BenignProbeMethodFilter().filter(record) is False
def test_filter_allows_real_method_not_found() -> None:
    """A genuinely unknown method is not treated as a benign probe."""
    err = _bake_tb(RequestError.method_not_found("session/custom"))
    record = _make_record("Background task failed", err)
    assert _BenignProbeMethodFilter().filter(record) is True
def test_filter_allows_non_request_error() -> None:
    """Exceptions other than RequestError are never suppressed."""
    err = _bake_tb(RuntimeError("boom"))
    record = _make_record("Background task failed", err)
    assert _BenignProbeMethodFilter().filter(record) is True
def test_filter_allows_different_message_even_for_ping() -> None:
    """The filter keys on 'Background task failed'; any other log message
    passes through even when the exception is a ping probe."""
    err = _bake_tb(RequestError.method_not_found("ping"))
    record = _make_record("Some other context", err)
    assert _BenignProbeMethodFilter().filter(record) is True
def test_filter_allows_request_error_with_different_code() -> None:
    """Only method-not-found qualifies; other RequestError codes pass."""
    err = _bake_tb(RequestError.invalid_params({"method": "ping"}))
    record = _make_record("Background task failed", err)
    assert _BenignProbeMethodFilter().filter(record) is True
def test_filter_allows_log_without_exc_info() -> None:
    """Records with no attached exception are left alone."""
    record = _make_record("Background task failed", None)
    assert _BenignProbeMethodFilter().filter(record) is True
# -- End-to-end: drive a real JSON-RPC `ping` through acp.run_agent ---------
class _FakeAgent:
    """Skeleton acp.Agent implementation — only enough surface to let the
    router build; every handler returns a canned response or nothing."""

    def on_connect(self, conn):  # noqa: ANN001
        pass

    async def initialize(self, **kwargs):  # noqa: ANN003
        from acp.schema import AgentCapabilities, InitializeResponse

        return InitializeResponse(
            protocol_version=1,
            agent_capabilities=AgentCapabilities(),
        )

    async def authenticate(self, **kwargs):  # noqa: ANN003
        pass

    async def new_session(self, cwd, mcp_servers=None, **kwargs):  # noqa: ANN001, ANN003
        from acp.schema import NewSessionResponse

        return NewSessionResponse(session_id="test")

    async def prompt(self, session_id, prompt, **kwargs):  # noqa: ANN001, ANN003
        from acp.schema import PromptResponse

        return PromptResponse(stop_reason="end_turn")

    async def cancel(self, session_id, **kwargs):  # noqa: ANN001, ANN003
        pass
@pytest.mark.asyncio
async def test_bare_ping_request_produces_proper_response_and_no_stderr_noise(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """A bare ``ping`` must get a JSON-RPC -32601 back AND leave stderr clean
    when the filter is installed on the handler.

    Wires a real acp.run_agent instance to OS pipes, sends one raw JSON-RPC
    ``ping`` line, and checks both the wire response and that the benign-probe
    filter suppressed the 'Background task failed' log noise.
    """
    import acp

    # Attach the filter to a fresh stream handler that mirrors entry._setup_logging.
    stream = StringIO()
    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter("%(name)s|%(levelname)s|%(message)s"))
    handler.addFilter(_BenignProbeMethodFilter())
    root = logging.getLogger()
    # Save root logger state so the finally block can restore it exactly.
    prior_handlers = root.handlers[:]
    prior_level = root.level
    root.handlers = [handler]
    root.setLevel(logging.INFO)
    # Also suppress propagation of caplog's default handler interfering with
    # our stream (caplog still captures via its own propagation hook).
    try:
        loop = asyncio.get_running_loop()
        # Pipe client -> agent
        client_to_agent_r, client_to_agent_w = os.pipe()
        # Pipe agent -> client
        agent_to_client_r, agent_to_client_w = os.pipe()
        # buffering=0 so each write is flushed straight to the pipe.
        in_read_file = os.fdopen(client_to_agent_r, "rb", buffering=0)
        in_write_file = os.fdopen(client_to_agent_w, "wb", buffering=0)
        out_read_file = os.fdopen(agent_to_client_r, "rb", buffering=0)
        out_write_file = os.fdopen(agent_to_client_w, "wb", buffering=0)
        # Agent reads its input from this StreamReader:
        agent_input = asyncio.StreamReader(limit=1024 * 1024, loop=loop)
        agent_input_proto = asyncio.StreamReaderProtocol(agent_input, loop=loop)
        await loop.connect_read_pipe(lambda: agent_input_proto, in_read_file)
        # Agent writes its output via this StreamWriter:
        out_transport, out_protocol = await loop.connect_write_pipe(
            asyncio.streams.FlowControlMixin, out_write_file
        )
        agent_output = asyncio.StreamWriter(out_transport, out_protocol, None, loop)
        # Test harness reads agent output via this StreamReader:
        client_input = asyncio.StreamReader(limit=1024 * 1024, loop=loop)
        client_input_proto = asyncio.StreamReaderProtocol(client_input, loop=loop)
        await loop.connect_read_pipe(lambda: client_input_proto, out_read_file)
        agent_task = asyncio.create_task(
            acp.run_agent(
                _FakeAgent(),
                input_stream=agent_output,
                output_stream=agent_input,
                use_unstable_protocol=True,
            )
        )
        # Send a bare `ping`
        request = {"jsonrpc": "2.0", "id": 1, "method": "ping", "params": {}}
        in_write_file.write((json.dumps(request) + "\n").encode())
        in_write_file.flush()
        response_line = await asyncio.wait_for(client_input.readline(), timeout=5.0)
        # Give the supervisor task a tick to fire (filter should eat it)
        await asyncio.sleep(0.2)
        response = json.loads(response_line.decode())
        assert response["error"]["code"] == -32601, response
        assert response["error"]["data"] == {"method": "ping"}, response
        logs = stream.getvalue()
        assert "Background task failed" not in logs, (
            f"ping noise leaked to stderr:\n{logs}"
        )
        # Clean shutdown
        in_write_file.close()
        try:
            await asyncio.wait_for(agent_task, timeout=2.0)
        except Exception:
            # FIX: the original caught `(asyncio.TimeoutError, Exception)` —
            # asyncio.TimeoutError already subclasses Exception, so the tuple
            # was redundant. Behavior is unchanged.
            agent_task.cancel()
            try:
                await agent_task
            except BaseException:  # noqa: BLE001
                pass
    finally:
        # Restore the root logger even if the assertions above fail.
        root.handlers = prior_handlers
        root.setLevel(prior_level)

View file

@ -95,19 +95,37 @@ class TestInitialize:
class TestAuthenticate:
@pytest.mark.asyncio
async def test_authenticate_with_provider_configured(self, agent, monkeypatch):
async def test_authenticate_with_matching_method_id(self, agent, monkeypatch):
monkeypatch.setattr(
"acp_adapter.server.has_provider",
lambda: True,
"acp_adapter.server.detect_provider",
lambda: "openrouter",
)
resp = await agent.authenticate(method_id="openrouter")
assert isinstance(resp, AuthenticateResponse)
@pytest.mark.asyncio
async def test_authenticate_is_case_insensitive(self, agent, monkeypatch):
monkeypatch.setattr(
"acp_adapter.server.detect_provider",
lambda: "openrouter",
)
resp = await agent.authenticate(method_id="OpenRouter")
assert isinstance(resp, AuthenticateResponse)
@pytest.mark.asyncio
async def test_authenticate_rejects_mismatched_method_id(self, agent, monkeypatch):
monkeypatch.setattr(
"acp_adapter.server.detect_provider",
lambda: "openrouter",
)
resp = await agent.authenticate(method_id="totally-invalid-method")
assert resp is None
@pytest.mark.asyncio
async def test_authenticate_without_provider(self, agent, monkeypatch):
monkeypatch.setattr(
"acp_adapter.server.has_provider",
lambda: False,
"acp_adapter.server.detect_provider",
lambda: None,
)
resp = await agent.authenticate(method_id="openrouter")
assert resp is None
@ -252,6 +270,57 @@ class TestListAndFork:
mock_list.assert_called_once_with(cwd="/mnt/e/Projects/AI/browser-link-3")
@pytest.mark.asyncio
async def test_list_sessions_pagination_first_page(self, agent):
    """More sessions than one page: page is full and a cursor is returned."""
    from acp_adapter import server as acp_server

    page_size = acp_server._LIST_SESSIONS_PAGE_SIZE
    infos = [
        {"session_id": f"s{i}", "cwd": "/tmp", "title": None, "updated_at": 0.0}
        for i in range(page_size + 5)
    ]
    with patch.object(agent.session_manager, "list_sessions", return_value=infos):
        resp = await agent.list_sessions()
    assert len(resp.sessions) == page_size
    # Cursor points at the last item of the page so the next call resumes there.
    assert resp.next_cursor == resp.sessions[-1].session_id
@pytest.mark.asyncio
async def test_list_sessions_pagination_no_more(self, agent):
    """Fewer sessions than a page: everything returned, no cursor."""
    infos = [
        {"session_id": f"s{i}", "cwd": "/tmp", "title": None, "updated_at": 0.0}
        for i in range(3)
    ]
    with patch.object(agent.session_manager, "list_sessions", return_value=infos):
        resp = await agent.list_sessions()
    assert len(resp.sessions) == 3
    assert resp.next_cursor is None
@pytest.mark.asyncio
async def test_list_sessions_cursor_resumes_after_match(self, agent):
    """A known cursor resumes listing from the entry right after it."""
    infos = [
        {"session_id": sid, "cwd": "/tmp", "title": None, "updated_at": 0.0}
        for sid in ("s1", "s2", "s3")
    ]
    with patch.object(agent.session_manager, "list_sessions", return_value=infos):
        resp = await agent.list_sessions(cursor="s1")
    assert [s.session_id for s in resp.sessions] == ["s2", "s3"]
    assert resp.next_cursor is None
@pytest.mark.asyncio
async def test_list_sessions_unknown_cursor_returns_empty(self, agent):
    """An unrecognized cursor yields an empty page rather than an error."""
    infos = [
        {"session_id": sid, "cwd": "/tmp", "title": None, "updated_at": 0.0}
        for sid in ("s1", "s2")
    ]
    with patch.object(agent.session_manager, "list_sessions", return_value=infos):
        resp = await agent.list_sessions(cursor="does-not-exist")
    assert resp.sessions == []
    assert resp.next_cursor is None
# ---------------------------------------------------------------------------
# session configuration / model routing
# ---------------------------------------------------------------------------

View file

@ -414,7 +414,11 @@ class TestRunOauthSetupToken:
token = run_oauth_setup_token()
assert token == "from-cred-file"
mock_run.assert_called_once()
# Don't assert exact call count — the contract is "credentials flow
# through", not "exactly one subprocess call". xdist cross-test
# pollution (other tests shimming subprocess via plugins) has flaked
# assert_called_once() in CI.
assert mock_run.called
def test_returns_token_from_env_var(self, monkeypatch, tmp_path):
"""Falls back to CLAUDE_CODE_OAUTH_TOKEN env var when no cred files."""

View file

@ -0,0 +1,238 @@
"""Regression tests: normalize_anthropic_response_v2 vs v1.
Constructs mock Anthropic responses and asserts that the v2 function
(returning NormalizedResponse) produces identical field values to the
original v1 function (returning SimpleNamespace + finish_reason).
"""
import json
import pytest
from types import SimpleNamespace
from agent.anthropic_adapter import (
normalize_anthropic_response,
normalize_anthropic_response_v2,
)
from agent.transports.types import NormalizedResponse, ToolCall
# ---------------------------------------------------------------------------
# Helpers to build mock Anthropic SDK responses
# ---------------------------------------------------------------------------
def _text_block(text: str):
return SimpleNamespace(type="text", text=text)
def _thinking_block(thinking: str, signature: str = "sig_abc"):
return SimpleNamespace(type="thinking", thinking=thinking, signature=signature)
def _tool_use_block(id: str, name: str, input: dict):
return SimpleNamespace(type="tool_use", id=id, name=name, input=input)
def _response(content_blocks, stop_reason="end_turn"):
return SimpleNamespace(
content=content_blocks,
stop_reason=stop_reason,
usage=SimpleNamespace(
input_tokens=10,
output_tokens=5,
),
)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestTextOnly:
    """Parity check for a plain text response: no tools, no thinking."""

    def setup_method(self):
        self.resp = _response([_text_block("Hello world")])
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def test_type(self):
        # v2 returns the typed dataclass.
        assert isinstance(self.v2, NormalizedResponse)

    def test_content_matches(self):
        assert self.v1_msg.content == self.v2.content

    def test_finish_reason_matches(self):
        assert self.v1_finish == self.v2.finish_reason

    def test_no_tool_calls(self):
        assert self.v1_msg.tool_calls is None
        assert self.v2.tool_calls is None

    def test_no_reasoning(self):
        assert self.v1_msg.reasoning is None
        assert self.v2.reasoning is None
class TestWithToolCalls:
    """Parity check when the response carries tool_use blocks."""

    def setup_method(self):
        blocks = [
            _text_block("I'll check that"),
            _tool_use_block("toolu_abc", "terminal", {"command": "ls"}),
            _tool_use_block("toolu_def", "read_file", {"path": "/tmp"}),
        ]
        self.resp = _response(blocks, stop_reason="tool_use")
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def test_finish_reason(self):
        assert self.v1_finish == "tool_calls"
        assert self.v2.finish_reason == "tool_calls"

    def test_tool_call_count(self):
        assert len(self.v1_msg.tool_calls) == 2
        assert len(self.v2.tool_calls) == 2

    def test_tool_call_ids_match(self):
        for v2_call, v1_call in zip(self.v2.tool_calls, self.v1_msg.tool_calls):
            assert v2_call.id == v1_call.id

    def test_tool_call_names_match(self):
        assert [c.name for c in self.v2.tool_calls] == ["terminal", "read_file"]
        for v2_call, v1_call in zip(self.v2.tool_calls, self.v1_msg.tool_calls):
            assert v2_call.name == v1_call.function.name

    def test_tool_call_arguments_match(self):
        for v2_call, v1_call in zip(self.v2.tool_calls, self.v1_msg.tool_calls):
            assert v2_call.arguments == v1_call.function.arguments

    def test_content_preserved(self):
        assert self.v2.content == self.v1_msg.content
        assert "check that" in self.v2.content
class TestWithThinking:
    """Parity check for extended-thinking responses."""

    def setup_method(self):
        blocks = [
            _thinking_block("Let me think about this carefully..."),
            _text_block("The answer is 42."),
        ]
        self.resp = _response(blocks)
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def test_reasoning_matches(self):
        assert self.v1_msg.reasoning == self.v2.reasoning
        assert "think about this" in self.v2.reasoning

    def test_reasoning_details_in_provider_data(self):
        v1_details = self.v1_msg.reasoning_details
        provider_data = self.v2.provider_data
        v2_details = provider_data.get("reasoning_details") if provider_data else None
        assert v1_details is not None
        assert v2_details is not None
        assert len(v1_details) == len(v2_details)

    def test_content_excludes_thinking(self):
        # Thinking text must not leak into user-visible content.
        assert self.v2.content == "The answer is 42."
class TestMixed:
    """Parity check when thinking, text, and tool_use all appear together."""

    def setup_method(self):
        blocks = [
            _thinking_block("Planning my approach..."),
            _text_block("I'll run the command"),
            _tool_use_block("toolu_xyz", "terminal", {"command": "pwd"}),
        ]
        self.resp = _response(blocks, stop_reason="tool_use")
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def test_all_fields_present(self):
        assert self.v2.content is not None
        assert self.v2.tool_calls is not None
        assert self.v2.reasoning is not None
        assert self.v2.finish_reason == "tool_calls"

    def test_content_matches(self):
        assert self.v1_msg.content == self.v2.content

    def test_reasoning_matches(self):
        assert self.v1_msg.reasoning == self.v2.reasoning

    def test_tool_call_matches(self):
        v1_call = self.v1_msg.tool_calls[0]
        v2_call = self.v2.tool_calls[0]
        assert v2_call.id == v1_call.id
        assert v2_call.name == v1_call.function.name
class TestStopReasons:
    """v1 and v2 must map Anthropic stop_reason values identically."""

    @pytest.mark.parametrize(
        "stop_reason,expected",
        [
            ("end_turn", "stop"),
            ("tool_use", "tool_calls"),
            ("max_tokens", "length"),
            ("stop_sequence", "stop"),
            ("refusal", "content_filter"),
            ("model_context_window_exceeded", "length"),
            ("unknown_future_reason", "stop"),
        ],
    )
    def test_stop_reason_mapping(self, stop_reason, expected):
        resp = _response([_text_block("x")], stop_reason=stop_reason)
        _, v1_finish = normalize_anthropic_response(resp)
        v2 = normalize_anthropic_response_v2(resp)
        # Both normalizers agree, and both match the expected mapping.
        assert v1_finish == expected
        assert v2.finish_reason == expected
class TestStripToolPrefix:
    """mcp_ prefix stripping must behave the same in v1 and v2."""

    @staticmethod
    def _mcp_response():
        # One tool_use block whose name carries the mcp_ prefix.
        return _response(
            [_tool_use_block("toolu_1", "mcp_terminal", {"cmd": "ls"})],
            stop_reason="tool_use",
        )

    def test_prefix_stripped(self):
        resp = self._mcp_response()
        v1_msg, _ = normalize_anthropic_response(resp, strip_tool_prefix=True)
        v2 = normalize_anthropic_response_v2(resp, strip_tool_prefix=True)
        assert v1_msg.tool_calls[0].function.name == "terminal"
        assert v2.tool_calls[0].name == "terminal"

    def test_prefix_kept(self):
        resp = self._mcp_response()
        v1_msg, _ = normalize_anthropic_response(resp, strip_tool_prefix=False)
        v2 = normalize_anthropic_response_v2(resp, strip_tool_prefix=False)
        assert v1_msg.tool_calls[0].function.name == "mcp_terminal"
        assert v2.tool_calls[0].name == "mcp_terminal"
class TestEdgeCases:
    """Degenerate inputs: empty block list, absent reasoning, return type."""

    def test_empty_content_blocks(self):
        resp = _response([])
        v1_msg, _ = normalize_anthropic_response(resp)
        v2 = normalize_anthropic_response_v2(resp)
        assert v1_msg.content == v2.content
        assert v2.content is None

    def test_no_reasoning_details_means_none_provider_data(self):
        v2 = normalize_anthropic_response_v2(_response([_text_block("hi")]))
        assert v2.provider_data is None

    def test_v2_returns_dataclass_not_namespace(self):
        v2 = normalize_anthropic_response_v2(_response([_text_block("hi")]))
        assert isinstance(v2, NormalizedResponse)
        assert not isinstance(v2, SimpleNamespace)

View file

@ -476,6 +476,133 @@ class TestGetTextAuxiliaryClient:
assert isinstance(client, CodexAuxiliaryClient)
assert model == "gpt-5.2-codex"
class TestNousAuxiliaryRefresh:
    """Nous auxiliary client: credential refresh and aux-model routing.

    Covers four behaviors: (1) runtime credentials win over cached auth,
    (2) Portal model recommendations are honored for text and vision,
    (3) a failing Portal lookup falls back to a default model, and
    (4) a 401 from a stale cached client triggers one fresh-client retry
    in both the sync and async call paths.
    """

    def test_try_nous_prefers_runtime_credentials(self):
        # Cached auth holds a stale token; the resolved runtime API supplies a
        # fresh key. The constructed OpenAI client must use the fresh pair.
        fresh_base = "https://inference-api.nousresearch.com/v1"
        with (
            patch("agent.auxiliary_client._read_nous_auth", return_value={"access_token": "stale-token"}),
            patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", fresh_base)),
            patch("hermes_cli.models.get_nous_recommended_aux_model", return_value=None),
            patch("agent.auxiliary_client.OpenAI") as mock_openai,
        ):
            from agent.auxiliary_client import _try_nous
            mock_openai.return_value = MagicMock()
            client, model = _try_nous()
            assert client is not None
            # No Portal recommendation → falls back to the hardcoded default.
            assert model == "google/gemini-3-flash-preview"
            assert mock_openai.call_args.kwargs["api_key"] == "fresh-agent-key"
            assert mock_openai.call_args.kwargs["base_url"] == fresh_base

    def test_try_nous_uses_portal_recommendation_for_text(self):
        """When the Portal recommends a compaction model, _try_nous honors it."""
        fresh_base = "https://inference-api.nousresearch.com/v1"
        with (
            patch("agent.auxiliary_client._read_nous_auth", return_value={"access_token": "***"}),
            patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", fresh_base)),
            patch("hermes_cli.models.get_nous_recommended_aux_model", return_value="minimax/minimax-m2.7") as mock_rec,
            patch("agent.auxiliary_client.OpenAI") as mock_openai,
        ):
            from agent.auxiliary_client import _try_nous
            mock_openai.return_value = MagicMock()
            client, model = _try_nous(vision=False)
            assert client is not None
            assert model == "minimax/minimax-m2.7"
            # Text path must request the non-vision recommendation.
            assert mock_rec.call_args.kwargs["vision"] is False

    def test_try_nous_uses_portal_recommendation_for_vision(self):
        """Vision tasks should ask for the vision-specific recommendation."""
        fresh_base = "https://inference-api.nousresearch.com/v1"
        with (
            patch("agent.auxiliary_client._read_nous_auth", return_value={"access_token": "***"}),
            patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", fresh_base)),
            patch("hermes_cli.models.get_nous_recommended_aux_model", return_value="google/gemini-3-flash-preview") as mock_rec,
            patch("agent.auxiliary_client.OpenAI"),
        ):
            from agent.auxiliary_client import _try_nous
            client, model = _try_nous(vision=True)
            assert client is not None
            assert model == "google/gemini-3-flash-preview"
            # Vision path must request the vision recommendation.
            assert mock_rec.call_args.kwargs["vision"] is True

    def test_try_nous_falls_back_when_recommendation_lookup_raises(self):
        """If the Portal lookup throws, we must still return a usable model."""
        fresh_base = "https://inference-api.nousresearch.com/v1"
        with (
            patch("agent.auxiliary_client._read_nous_auth", return_value={"access_token": "***"}),
            patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", fresh_base)),
            patch("hermes_cli.models.get_nous_recommended_aux_model", side_effect=RuntimeError("portal down")),
            patch("agent.auxiliary_client.OpenAI"),
        ):
            from agent.auxiliary_client import _try_nous
            client, model = _try_nous()
            assert client is not None
            # Lookup error swallowed; default model returned instead of raising.
            assert model == "google/gemini-3-flash-preview"

    def test_call_llm_retries_nous_after_401(self):
        # The cached (stale) client raises a 401-shaped error; call_llm must
        # mint a fresh client and retry exactly once — one call on each client.
        class _Auth401(Exception):
            # Duck-typed HTTP status attribute the retry logic inspects.
            status_code = 401
        stale_client = MagicMock()
        stale_client.base_url = "https://inference-api.nousresearch.com/v1"
        stale_client.chat.completions.create.side_effect = _Auth401("stale nous key")
        fresh_client = MagicMock()
        fresh_client.base_url = "https://inference-api.nousresearch.com/v1"
        fresh_client.chat.completions.create.return_value = {"ok": True}
        with (
            patch("agent.auxiliary_client._resolve_task_provider_model", return_value=("nous", "nous-model", None, None, None)),
            patch("agent.auxiliary_client._get_cached_client", return_value=(stale_client, "nous-model")),
            patch("agent.auxiliary_client.OpenAI", return_value=fresh_client),
            patch("agent.auxiliary_client._validate_llm_response", side_effect=lambda resp, _task: resp),
            patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", "https://inference-api.nousresearch.com/v1")),
        ):
            result = call_llm(
                task="compression",
                messages=[{"role": "user", "content": "hi"}],
            )
            assert result == {"ok": True}
            assert stale_client.chat.completions.create.call_count == 1
            assert fresh_client.chat.completions.create.call_count == 1

    @pytest.mark.asyncio
    async def test_async_call_llm_retries_nous_after_401(self):
        # Async mirror of the 401-retry test: stale async client fails once,
        # the fresh async client succeeds once.
        class _Auth401(Exception):
            status_code = 401
        stale_client = MagicMock()
        stale_client.base_url = "https://inference-api.nousresearch.com/v1"
        stale_client.chat.completions.create = AsyncMock(side_effect=_Auth401("stale nous key"))
        fresh_async_client = MagicMock()
        fresh_async_client.base_url = "https://inference-api.nousresearch.com/v1"
        fresh_async_client.chat.completions.create = AsyncMock(return_value={"ok": True})
        with (
            patch("agent.auxiliary_client._resolve_task_provider_model", return_value=("nous", "nous-model", None, None, None)),
            patch("agent.auxiliary_client._get_cached_client", return_value=(stale_client, "nous-model")),
            patch("agent.auxiliary_client._to_async_client", return_value=(fresh_async_client, "nous-model")),
            patch("agent.auxiliary_client._validate_llm_response", side_effect=lambda resp, _task: resp),
            patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", "https://inference-api.nousresearch.com/v1")),
        ):
            result = await async_call_llm(
                task="session_search",
                messages=[{"role": "user", "content": "hi"}],
            )
            assert result == {"ok": True}
            assert stale_client.chat.completions.create.await_count == 1
            assert fresh_async_client.chat.completions.create.await_count == 1
# ── Payment / credit exhaustion fallback ─────────────────────────────────
@ -696,27 +823,46 @@ class TestIsConnectionError:
assert _is_connection_error(err) is False
class TestKimiForCodingTemperature:
"""Moonshot kimi-for-coding models require fixed temperatures.
class TestKimiTemperatureOmitted:
"""Kimi/Moonshot models should have temperature OMITTED from API kwargs.
k2.5 / k2-turbo-preview / k2-0905-preview 0.6 (non-thinking lock).
k2-thinking / k2-thinking-turbo 1.0 (thinking lock).
kimi-k2-instruct* and every other model preserve the caller's temperature.
The Kimi gateway selects the correct temperature server-side based on the
active mode (thinking 1.0, non-thinking 0.6). Sending any temperature
value conflicts with gateway-managed defaults.
"""
def test_build_call_kwargs_forces_fixed_temperature(self):
@pytest.mark.parametrize(
"model",
[
"kimi-for-coding",
"kimi-k2.5",
"kimi-k2.6",
"kimi-k2-turbo-preview",
"kimi-k2-0905-preview",
"kimi-k2-thinking",
"kimi-k2-thinking-turbo",
"kimi-k2-instruct",
"kimi-k2-instruct-0905",
"moonshotai/kimi-k2.5",
"moonshotai/Kimi-K2-Thinking",
"moonshotai/Kimi-K2-Instruct",
],
)
def test_kimi_models_omit_temperature(self, model):
"""No kimi model should have a temperature key in kwargs."""
from agent.auxiliary_client import _build_call_kwargs
kwargs = _build_call_kwargs(
provider="kimi-coding",
model="kimi-for-coding",
model=model,
messages=[{"role": "user", "content": "hello"}],
temperature=0.3,
)
assert kwargs["temperature"] == 0.6
assert "temperature" not in kwargs
def test_build_call_kwargs_injects_temperature_when_missing(self):
def test_kimi_for_coding_no_temperature_when_none(self):
"""When caller passes temperature=None, still no temperature key."""
from agent.auxiliary_client import _build_call_kwargs
kwargs = _build_call_kwargs(
@ -726,9 +872,9 @@ class TestKimiForCodingTemperature:
temperature=None,
)
assert kwargs["temperature"] == 0.6
assert "temperature" not in kwargs
def test_auto_routed_kimi_for_coding_sync_call_uses_fixed_temperature(self):
def test_sync_call_omits_temperature(self):
client = MagicMock()
client.base_url = "https://api.kimi.com/coding/v1"
response = MagicMock()
@ -750,10 +896,10 @@ class TestKimiForCodingTemperature:
assert result is response
kwargs = client.chat.completions.create.call_args.kwargs
assert kwargs["model"] == "kimi-for-coding"
assert kwargs["temperature"] == 0.6
assert "temperature" not in kwargs
@pytest.mark.asyncio
async def test_auto_routed_kimi_for_coding_async_call_uses_fixed_temperature(self):
async def test_async_call_omits_temperature(self):
client = MagicMock()
client.base_url = "https://api.kimi.com/coding/v1"
response = MagicMock()
@ -775,52 +921,17 @@ class TestKimiForCodingTemperature:
assert result is response
kwargs = client.chat.completions.create.call_args.kwargs
assert kwargs["model"] == "kimi-for-coding"
assert kwargs["temperature"] == 0.6
@pytest.mark.parametrize(
    "model",
    [
        "kimi-k2.5",
        "kimi-k2-turbo-preview",
        "kimi-k2-0905-preview",
        "kimi-k2-thinking",
        "kimi-k2-thinking-turbo",
        "moonshotai/kimi-k2.5",
        "moonshotai/Kimi-K2-Thinking",
    ],
)
def test_kimi_k2_family_temperature_override(self, model):
    """Moonshot kimi-k2.* models have gateway-managed temperatures.

    The gateway applies the fixed value server-side (non-thinking 0.6,
    thinking-mode 1.0), so the client must omit the key entirely.
    """
    from agent.auxiliary_client import _build_call_kwargs

    kwargs = _build_call_kwargs(
        provider="kimi-coding",
        model=model,
        messages=[{"role": "user", "content": "hello"}],
        temperature=0.3,
    )
    assert "temperature" not in kwargs
@pytest.mark.parametrize(
"model",
[
"anthropic/claude-sonnet-4-6",
"gpt-5.4",
# kimi-k2-instruct is the non-coding K2 family — temperature is
# variable (recommended 0.6 but not enforced). Must not clamp.
"kimi-k2-instruct",
"moonshotai/Kimi-K2-Instruct",
"moonshotai/Kimi-K2-Instruct-0905",
"kimi-k2-instruct-0905",
# Hypothetical future kimi name not in the whitelist.
"kimi-k2-experimental",
"deepseek-chat",
],
)
def test_non_restricted_model_preserves_temperature(self, model):
def test_non_kimi_models_preserve_temperature(self, model):
from agent.auxiliary_client import _build_call_kwargs
kwargs = _build_call_kwargs(
@ -832,6 +943,28 @@ class TestKimiForCodingTemperature:
assert kwargs["temperature"] == 0.3
@pytest.mark.parametrize(
    "base_url",
    [
        "https://api.moonshot.ai/v1",
        "https://api.moonshot.cn/v1",
        "https://api.kimi.com/coding/v1",
    ],
)
def test_kimi_k2_5_omits_temperature_regardless_of_endpoint(self, base_url):
    """Temperature is omitted regardless of which Kimi endpoint is used."""
    from agent.auxiliary_client import _build_call_kwargs

    # Same model, same caller temperature — only the endpoint varies.
    call_kwargs = _build_call_kwargs(
        provider="kimi-coding",
        model="kimi-k2.5",
        messages=[{"role": "user", "content": "hello"}],
        temperature=0.1,
        base_url=base_url,
    )
    assert "temperature" not in call_kwargs
# ---------------------------------------------------------------------------
# async_call_llm payment / connection fallback (#7512 bug 2)
@ -858,6 +991,70 @@ class TestStaleBaseUrlWarning:
"Expected a warning about stale OPENAI_BASE_URL"
assert mod._stale_base_url_warned is True
class TestAuxiliaryTaskExtraBody:
    """Per-task ``extra_body`` from config must be merged into call kwargs.

    ``auxiliary.<task>.extra_body`` in the config supplies provider hints
    (e.g. ``enable_thinking``); caller-supplied ``extra_body`` entries must
    win on key conflicts.
    """

    def test_sync_call_merges_task_extra_body_from_config(self):
        # Fake OpenAI-style client whose create() returns a sentinel.
        client = MagicMock()
        client.base_url = "https://api.example.com/v1"
        response = MagicMock()
        client.chat.completions.create.return_value = response
        # Task-level extra_body configured for session_search.
        config = {
            "auxiliary": {
                "session_search": {
                    "extra_body": {
                        "enable_thinking": False,
                        "reasoning": {"effort": "none"},
                    }
                }
            }
        }
        with patch("hermes_cli.config.load_config", return_value=config), patch(
            "agent.auxiliary_client._get_cached_client",
            return_value=(client, "glm-4.5-air"),
        ):
            result = call_llm(
                task="session_search",
                messages=[{"role": "user", "content": "hello"}],
                extra_body={"metadata": {"source": "test"}},
            )
            assert result is response
            kwargs = client.chat.completions.create.call_args.kwargs
            # Config-sourced and caller-sourced keys are both present.
            assert kwargs["extra_body"]["enable_thinking"] is False
            assert kwargs["extra_body"]["reasoning"] == {"effort": "none"}
            assert kwargs["extra_body"]["metadata"] == {"source": "test"}

    @pytest.mark.asyncio
    async def test_async_call_explicit_extra_body_overrides_task_config(self):
        # Async variant: create() is awaited, so it must be an AsyncMock.
        client = MagicMock()
        client.base_url = "https://api.example.com/v1"
        response = MagicMock()
        client.chat.completions.create = AsyncMock(return_value=response)
        config = {
            "auxiliary": {
                "session_search": {
                    "extra_body": {"enable_thinking": False}
                }
            }
        }
        with patch("hermes_cli.config.load_config", return_value=config), patch(
            "agent.auxiliary_client._get_cached_client",
            return_value=(client, "glm-4.5-air"),
        ):
            result = await async_call_llm(
                task="session_search",
                messages=[{"role": "user", "content": "hello"}],
                extra_body={"enable_thinking": True},
            )
            assert result is response
            kwargs = client.chat.completions.create.call_args.kwargs
            # Caller's explicit value overrides the config-supplied one.
            assert kwargs["extra_body"]["enable_thinking"] is True
def test_no_warning_when_provider_is_custom(self, monkeypatch, caplog):
"""No warning when the provider is 'custom' — OPENAI_BASE_URL is expected."""
import agent.auxiliary_client as mod

View file

@ -0,0 +1,107 @@
"""Tests for agent.auxiliary_client._try_custom_endpoint's anthropic_messages branch.
When a user configures a custom endpoint with ``api_mode: anthropic_messages``
(e.g. MiniMax, Zhipu GLM, LiteLLM in Anthropic-proxy mode), auxiliary tasks
(compression, web_extract, session_search, title generation) must use the
native Anthropic transport rather than being silently downgraded to an
OpenAI-wire client that speaks the wrong protocol.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture(autouse=True)
def _clean_env(monkeypatch):
    """Strip provider credentials from the environment for every test."""
    credential_vars = (
        "OPENAI_API_KEY",
        "OPENAI_BASE_URL",
        "ANTHROPIC_API_KEY",
        "ANTHROPIC_TOKEN",
    )
    for name in credential_vars:
        monkeypatch.delenv(name, raising=False)
def _install_anthropic_adapter_mocks():
"""Patch build_anthropic_client so the test doesn't need the SDK."""
fake_client = MagicMock(name="anthropic_client")
return patch(
"agent.anthropic_adapter.build_anthropic_client",
return_value=fake_client,
), fake_client
def test_custom_endpoint_anthropic_messages_builds_anthropic_wrapper():
    """api_mode=anthropic_messages → returns AnthropicAuxiliaryClient, not OpenAI."""
    from agent.auxiliary_client import AnthropicAuxiliaryClient, _try_custom_endpoint

    runtime_patch = patch(
        "agent.auxiliary_client._resolve_custom_runtime",
        return_value=(
            "https://api.minimax.io/anthropic",
            "minimax-key",
            "anthropic_messages",
        ),
    )
    model_patch = patch(
        "agent.auxiliary_client._read_main_model",
        return_value="claude-sonnet-4-6",
    )
    adapter_patch, fake_client = _install_anthropic_adapter_mocks()
    with runtime_patch, model_patch, adapter_patch:
        client, model = _try_custom_endpoint()

    assert isinstance(client, AnthropicAuxiliaryClient), (
        "Custom endpoint with api_mode=anthropic_messages must return the "
        f"native Anthropic wrapper, got {type(client).__name__}"
    )
    assert model == "claude-sonnet-4-6"
    # Wrapper should NOT be marked as OAuth — third-party endpoints are
    # always API-key authenticated.
    assert client.api_key == "minimax-key"
    assert client.base_url == "https://api.minimax.io/anthropic"
def test_custom_endpoint_anthropic_messages_falls_back_when_sdk_missing():
    """Graceful degradation when anthropic SDK is unavailable."""
    from agent.auxiliary_client import AnthropicAuxiliaryClient, _try_custom_endpoint

    with patch(
        "agent.auxiliary_client._resolve_custom_runtime",
        return_value=("https://api.minimax.io/anthropic", "k", "anthropic_messages"),
    ), patch(
        "agent.auxiliary_client._read_main_model",
        return_value="claude-sonnet-4-6",
    ), patch(
        "agent.anthropic_adapter.build_anthropic_client",
        side_effect=ImportError("anthropic package not installed"),
    ):
        client, model = _try_custom_endpoint()

    # Should fall back to an OpenAI-wire client rather than returning
    # (None, None) — the tool still needs to do *something*.
    assert client is not None
    assert model == "claude-sonnet-4-6"
    # OpenAI client, not AnthropicAuxiliaryClient.
    assert not isinstance(client, AnthropicAuxiliaryClient)
def test_custom_endpoint_chat_completions_still_uses_openai_wire():
    """Regression: default path (no api_mode) must remain OpenAI client."""
    from agent.auxiliary_client import AnthropicAuxiliaryClient, _try_custom_endpoint

    # api_mode of None selects the default chat-completions transport.
    plain_runtime = ("https://api.example.com/v1", "key", None)
    with patch(
        "agent.auxiliary_client._resolve_custom_runtime",
        return_value=plain_runtime,
    ), patch(
        "agent.auxiliary_client._read_main_model",
        return_value="my-model",
    ):
        client, model = _try_custom_endpoint()

    assert client is not None
    assert model == "my-model"
    assert not isinstance(client, AnthropicAuxiliaryClient)

View file

@ -167,7 +167,7 @@ class TestResolveAutoMainFirst:
class TestResolveVisionMainFirst:
"""Vision auto-detection prefers main provider + main model first."""
"""Vision auto-detection prefers the main provider first."""
def test_openrouter_main_vision_uses_main_model(self, monkeypatch):
"""OpenRouter main with vision-capable model → aux vision uses main model."""
@ -200,28 +200,49 @@ class TestResolveVisionMainFirst:
assert mock_resolve.call_args.args[0] == "openrouter"
assert mock_resolve.call_args.args[1] == "anthropic/claude-sonnet-4.6"
def test_nous_main_vision_uses_paid_nous_vision_backend(self):
    """Paid Nous main → aux vision uses the dedicated Nous vision backend.

    The main model (gpt-5) is text-oriented; vision requests must be
    routed to the strict vision backend instead of the main model.
    """
    with patch(
        "agent.auxiliary_client._read_main_provider", return_value="nous",
    ), patch(
        "agent.auxiliary_client._read_main_model",
        return_value="openai/gpt-5",
    ), patch(
        "agent.auxiliary_client._resolve_task_provider_model",
        return_value=("auto", None, None, None, None),
    ), patch(
        "agent.auxiliary_client._resolve_strict_vision_backend",
        return_value=(MagicMock(), "google/gemini-3-flash-preview"),
    ):
        from agent.auxiliary_client import resolve_vision_provider_client

        provider, client, model = resolve_vision_provider_client()

    assert provider == "nous"
    assert client is not None
    assert model == "google/gemini-3-flash-preview"
def test_nous_main_vision_uses_free_tier_nous_vision_backend(self):
    """Free-tier Nous main → aux vision uses MiMo omni, not the text main model."""
    from agent.auxiliary_client import resolve_vision_provider_client

    # Free-tier text model as main; strict vision backend resolves to omni.
    with patch(
        "agent.auxiliary_client._read_main_provider", return_value="nous",
    ), patch(
        "agent.auxiliary_client._read_main_model",
        return_value="xiaomi/mimo-v2-pro",
    ), patch(
        "agent.auxiliary_client._resolve_task_provider_model",
        return_value=("auto", None, None, None, None),
    ), patch(
        "agent.auxiliary_client._resolve_strict_vision_backend",
        return_value=(MagicMock(), "xiaomi/mimo-v2-omni"),
    ):
        provider, client, model = resolve_vision_provider_client()

    assert provider == "nous"
    assert client is not None
    assert model == "xiaomi/mimo-v2-omni"
def test_exotic_provider_with_vision_override_preserved(self):
"""xiaomi → mimo-v2-omni override still wins over main_model."""

View file

@ -267,3 +267,174 @@ class TestPackaging:
from pathlib import Path
content = (Path(__file__).parent.parent.parent / "pyproject.toml").read_text()
assert '"hermes-agent[bedrock]"' in content
# ---------------------------------------------------------------------------
# Model ID dot preservation — regression for #11976
# ---------------------------------------------------------------------------
# AWS Bedrock inference-profile model IDs embed structural dots:
#
# global.anthropic.claude-opus-4-7
# us.anthropic.claude-sonnet-4-5-20250929-v1:0
# apac.anthropic.claude-haiku-4-5
#
# ``agent.anthropic_adapter.normalize_model_name`` converts dots to hyphens
# unless the caller opts in via ``preserve_dots=True``. Before this fix,
# ``AIAgent._anthropic_preserve_dots`` returned False for the ``bedrock``
# provider, so Claude-on-Bedrock requests went out with
# ``global-anthropic-claude-opus-4-7`` (all dots mangled to hyphens) and
# Bedrock rejected them with:
#
# HTTP 400: The provided model identifier is invalid.
#
# The fix adds ``bedrock`` to the preserve-dots provider allowlist and
# ``bedrock-runtime.`` to the base-URL heuristic, mirroring the shape of
# the opencode-go fix for #5211 (commit f77be22c), which extended this
# same allowlist.
class TestBedrockPreserveDotsFlag:
    """``AIAgent._anthropic_preserve_dots`` must return True on Bedrock so
    inference-profile IDs survive the normalize step intact."""

    @staticmethod
    def _preserve_dots(provider: str, base_url: str) -> bool:
        """Evaluate the flag on a minimal agent stand-in."""
        from types import SimpleNamespace

        from run_agent import AIAgent

        stub = SimpleNamespace(provider=provider, base_url=base_url)
        return AIAgent._anthropic_preserve_dots(stub)

    def test_bedrock_provider_preserves_dots(self):
        assert self._preserve_dots("bedrock", "") is True

    def test_bedrock_runtime_us_east_1_url_preserves_dots(self):
        """Defense-in-depth: even without an explicit ``provider="bedrock"``,
        a ``bedrock-runtime.us-east-1.amazonaws.com`` base URL must not
        mangle dots."""
        assert self._preserve_dots(
            "custom", "https://bedrock-runtime.us-east-1.amazonaws.com"
        ) is True

    def test_bedrock_runtime_ap_northeast_2_url_preserves_dots(self):
        """Reporter-reported region (ap-northeast-2) exercises the same
        base-URL heuristic."""
        assert self._preserve_dots(
            "custom", "https://bedrock-runtime.ap-northeast-2.amazonaws.com"
        ) is True

    def test_non_bedrock_aws_url_does_not_preserve_dots(self):
        """Unrelated AWS endpoints (e.g. ``s3.us-east-1.amazonaws.com``)
        must not accidentally activate the dot-preservation heuristic —
        it is scoped to the ``bedrock-runtime.`` substring specifically."""
        assert self._preserve_dots(
            "custom", "https://s3.us-east-1.amazonaws.com"
        ) is False

    def test_anthropic_native_still_does_not_preserve_dots(self):
        """Canary: adding Bedrock to the allowlist must not weaken the
        existing Anthropic native behaviour — ``claude-sonnet-4.6`` still
        becomes ``claude-sonnet-4-6`` for the Anthropic API."""
        assert self._preserve_dots(
            "anthropic", "https://api.anthropic.com"
        ) is False
class TestBedrockModelNameNormalization:
    """End-to-end: ``normalize_model_name`` + the preserve-dots flag
    reproduce the exact production request shape for each Bedrock model
    family, confirming the fix resolves the reporter's HTTP 400."""

    @staticmethod
    def _normalized(model_id: str, *, preserve: bool) -> str:
        from agent.anthropic_adapter import normalize_model_name

        return normalize_model_name(model_id, preserve_dots=preserve)

    def test_global_anthropic_inference_profile_preserved(self):
        """The reporter's exact model ID."""
        profile = "global.anthropic.claude-opus-4-7"
        assert self._normalized(profile, preserve=True) == profile

    def test_us_anthropic_dated_inference_profile_preserved(self):
        """Regional + dated Sonnet inference profile."""
        profile = "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
        assert self._normalized(profile, preserve=True) == profile

    def test_apac_anthropic_haiku_inference_profile_preserved(self):
        """APAC inference profile — same structural-dot shape."""
        profile = "apac.anthropic.claude-haiku-4-5"
        assert self._normalized(profile, preserve=True) == profile

    def test_preserve_false_mangles_as_documented(self):
        """Canary: with ``preserve_dots=False`` the function still produces
        the broken all-hyphen form — the shape Bedrock rejected and that the
        fix avoids. Keeping this test locks in the existing behaviour of
        ``normalize_model_name`` so a future refactor doesn't accidentally
        decouple the knob from its effect."""
        assert (
            self._normalized("global.anthropic.claude-opus-4-7", preserve=False)
            == "global-anthropic-claude-opus-4-7"
        )

    def test_bare_foundation_model_id_preserved(self):
        """Non-inference-profile Bedrock IDs
        (e.g. ``anthropic.claude-3-5-sonnet-20241022-v2:0``) use dots as
        vendor separators and must also survive intact under
        ``preserve_dots=True``."""
        profile = "anthropic.claude-3-5-sonnet-20241022-v2:0"
        assert self._normalized(profile, preserve=True) == profile
class TestBedrockBuildAnthropicKwargsEndToEnd:
    """Integration: ``build_anthropic_kwargs`` with a Bedrock-shaped model ID
    and ``preserve_dots=True`` produces the unmangled model string in the
    outgoing kwargs — the exact body sent to the ``bedrock-runtime.``
    endpoint. Integration-level regression for the reporter's HTTP 400."""

    @staticmethod
    def _build(preserve: bool) -> dict:
        """Build kwargs for the reporter's model with the given flag."""
        from agent.anthropic_adapter import build_anthropic_kwargs

        return build_anthropic_kwargs(
            model="global.anthropic.claude-opus-4-7",
            messages=[{"role": "user", "content": "hi"}],
            tools=None,
            max_tokens=1024,
            reasoning_config=None,
            preserve_dots=preserve,
        )

    def test_bedrock_inference_profile_survives_build_kwargs(self):
        kwargs = self._build(True)
        assert kwargs["model"] == "global.anthropic.claude-opus-4-7", (
            "Bedrock inference-profile ID was mangled in build_anthropic_kwargs: "
            f"{kwargs['model']!r}"
        )

    def test_bedrock_model_mangled_without_preserve_dots(self):
        """Inverse canary: without the flag, ``build_anthropic_kwargs`` still
        produces the broken form — so the fix in ``_anthropic_preserve_dots``
        is the load-bearing piece that wires ``preserve_dots=True`` through
        to this builder for the Bedrock case."""
        assert self._build(False)["model"] == "global-anthropic-claude-opus-4-7"

View file

@ -0,0 +1,253 @@
"""Regression guard: Codex Cloudflare 403 mitigation headers.
The ``chatgpt.com/backend-api/codex`` endpoint sits behind a Cloudflare layer
that whitelists a small set of first-party originators (``codex_cli_rs``,
``codex_vscode``, ``codex_sdk_ts``, ``Codex*``). Requests from non-residential
IPs (VPS, always-on servers, some corporate egress) that don't advertise an
allowed originator are served 403 with ``cf-mitigated: challenge`` regardless
of auth correctness.
``_codex_cloudflare_headers`` in ``agent.auxiliary_client`` centralizes the
header set so the primary chat client (``run_agent.AIAgent.__init__`` +
``_apply_client_headers_for_base_url``) and the auxiliary client paths
(``_try_codex`` and the ``raw_codex`` branch of ``resolve_provider_client``)
all emit the same headers.
These tests pin:
- the originator value (must be ``codex_cli_rs`` — the whitelisted one)
- the User-Agent shape (codex_cli_rs-prefixed)
- ``ChatGPT-Account-ID`` extraction from the OAuth JWT (canonical casing,
from codex-rs ``auth.rs``)
- graceful handling of malformed tokens (drop the account-ID header, don't
raise)
- primary-client wiring at both entry points in ``run_agent.py``
- aux-client wiring at both entry points in ``agent/auxiliary_client.py``
"""
from __future__ import annotations
import base64
import json
from unittest.mock import MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _make_codex_jwt(account_id: str = "acct-test-123") -> str:
"""Build a syntactically valid Codex-style JWT with the account_id claim."""
def b64url(data: bytes) -> str:
return base64.urlsafe_b64encode(data).rstrip(b"=").decode()
header = b64url(b'{"alg":"RS256","typ":"JWT"}')
claims = {
"sub": "user-xyz",
"exp": 9999999999,
"https://api.openai.com/auth": {
"chatgpt_account_id": account_id,
"chatgpt_plan_type": "plus",
},
}
payload = b64url(json.dumps(claims).encode())
sig = b64url(b"fake-sig")
return f"{header}.{payload}.{sig}"
# ---------------------------------------------------------------------------
# _codex_cloudflare_headers — the shared helper
# ---------------------------------------------------------------------------
class TestCodexCloudflareHeaders:
    """Pin the header set produced by ``_codex_cloudflare_headers``."""

    def test_originator_is_codex_cli_rs(self):
        """Cloudflare whitelists codex_cli_rs — any other value is 403'd."""
        from agent.auxiliary_client import _codex_cloudflare_headers

        headers = _codex_cloudflare_headers(_make_codex_jwt())
        assert headers["originator"] == "codex_cli_rs"

    def test_user_agent_advertises_codex_cli_rs(self):
        # UA must carry the codex_cli_rs/<version> prefix.
        from agent.auxiliary_client import _codex_cloudflare_headers

        headers = _codex_cloudflare_headers(_make_codex_jwt())
        assert headers["User-Agent"].startswith("codex_cli_rs/")

    def test_account_id_extracted_from_jwt(self):
        # Account ID comes from the chatgpt_account_id claim in the JWT.
        from agent.auxiliary_client import _codex_cloudflare_headers

        headers = _codex_cloudflare_headers(_make_codex_jwt("acct-abc-999"))
        # Canonical casing — matches codex-rs auth.rs
        assert headers["ChatGPT-Account-ID"] == "acct-abc-999"

    def test_canonical_header_casing(self):
        """Upstream codex-rs uses PascalCase with trailing -ID. Match exactly."""
        from agent.auxiliary_client import _codex_cloudflare_headers

        headers = _codex_cloudflare_headers(_make_codex_jwt())
        assert "ChatGPT-Account-ID" in headers
        # The lowercase/titlecase variants MUST NOT be used — pin to be explicit
        assert "chatgpt-account-id" not in headers
        assert "ChatGPT-Account-Id" not in headers

    def test_malformed_token_drops_account_id_without_raising(self):
        # Tokens that don't decode as three JWT segments must degrade gracefully.
        from agent.auxiliary_client import _codex_cloudflare_headers

        for bad in ["not-a-jwt", "", "only.one", " ", "...."]:
            headers = _codex_cloudflare_headers(bad)
            # Still returns base headers — never raises
            assert headers["originator"] == "codex_cli_rs"
            assert "ChatGPT-Account-ID" not in headers

    def test_non_string_token_handled(self):
        # Passing None instead of a str must not raise either.
        from agent.auxiliary_client import _codex_cloudflare_headers

        headers = _codex_cloudflare_headers(None)  # type: ignore[arg-type]
        assert headers["originator"] == "codex_cli_rs"
        assert "ChatGPT-Account-ID" not in headers

    def test_jwt_without_chatgpt_account_id_claim(self):
        """A valid JWT that lacks the account_id claim should still return headers."""
        from agent.auxiliary_client import _codex_cloudflare_headers
        import base64 as _b64, json as _json

        def b64url(data: bytes) -> str:
            return _b64.urlsafe_b64encode(data).rstrip(b"=").decode()

        # Minimal claims set: no https://api.openai.com/auth section at all.
        payload = b64url(_json.dumps({"sub": "user-xyz", "exp": 9999999999}).encode())
        token = f"{b64url(b'{}')}.{payload}.{b64url(b'sig')}"
        headers = _codex_cloudflare_headers(token)
        assert headers["originator"] == "codex_cli_rs"
        assert "ChatGPT-Account-ID" not in headers
# ---------------------------------------------------------------------------
# Primary chat client wiring (run_agent.AIAgent)
# ---------------------------------------------------------------------------
class TestPrimaryClientWiring:
    """The primary chat client must emit the Codex headers at both entry
    points in run_agent.py: initial construction and base-URL rotation."""

    def test_init_wires_codex_headers_for_chatgpt_base_url(self):
        from run_agent import AIAgent

        token = _make_codex_jwt("acct-primary-init")
        with patch("run_agent.OpenAI") as mock_openai:
            mock_openai.return_value = MagicMock()
            AIAgent(
                api_key=token,
                base_url="https://chatgpt.com/backend-api/codex",
                provider="openai-codex",
                model="gpt-5.4",
                quiet_mode=True,
                skip_context_files=True,
                skip_memory=True,
            )
            # Headers must be handed to the OpenAI constructor via
            # default_headers.
            headers = mock_openai.call_args.kwargs.get("default_headers") or {}
            assert headers.get("originator") == "codex_cli_rs"
            assert headers.get("ChatGPT-Account-ID") == "acct-primary-init"
            assert headers.get("User-Agent", "").startswith("codex_cli_rs/")

    def test_apply_client_headers_on_base_url_change(self):
        """Credential-rotation / base-url change path must also emit codex headers."""
        from run_agent import AIAgent

        token = _make_codex_jwt("acct-rotation")
        with patch("run_agent.OpenAI") as mock_openai:
            mock_openai.return_value = MagicMock()
            # Start on a non-Codex provider...
            agent = AIAgent(
                api_key="placeholder-openrouter-key",
                base_url="https://openrouter.ai/api/v1",
                provider="openrouter",
                model="anthropic/claude-sonnet-4.6",
                quiet_mode=True,
                skip_context_files=True,
                skip_memory=True,
            )
            # Simulate rotation into a Codex credential
            agent._client_kwargs["api_key"] = token
            agent._apply_client_headers_for_base_url(
                "https://chatgpt.com/backend-api/codex"
            )
            headers = agent._client_kwargs.get("default_headers") or {}
            assert headers.get("originator") == "codex_cli_rs"
            assert headers.get("ChatGPT-Account-ID") == "acct-rotation"
            assert headers.get("User-Agent", "").startswith("codex_cli_rs/")

    def test_apply_client_headers_clears_codex_headers_off_chatgpt(self):
        """Switching AWAY from chatgpt.com must drop the codex headers."""
        from run_agent import AIAgent

        token = _make_codex_jwt()
        with patch("run_agent.OpenAI") as mock_openai:
            mock_openai.return_value = MagicMock()
            agent = AIAgent(
                api_key=token,
                base_url="https://chatgpt.com/backend-api/codex",
                provider="openai-codex",
                model="gpt-5.4",
                quiet_mode=True,
                skip_context_files=True,
                skip_memory=True,
            )
            # Sanity: headers are set initially
            assert "originator" in (agent._client_kwargs.get("default_headers") or {})
            agent._apply_client_headers_for_base_url(
                "https://api.anthropic.com"
            )
            # default_headers should be popped for anthropic base
            assert "default_headers" not in agent._client_kwargs

    def test_openrouter_base_url_does_not_get_codex_headers(self):
        # Negative control: an unrelated provider never advertises codex.
        from run_agent import AIAgent

        with patch("run_agent.OpenAI") as mock_openai:
            mock_openai.return_value = MagicMock()
            AIAgent(
                api_key="sk-or-test",
                base_url="https://openrouter.ai/api/v1",
                provider="openrouter",
                model="anthropic/claude-sonnet-4.6",
                quiet_mode=True,
                skip_context_files=True,
                skip_memory=True,
            )
            headers = mock_openai.call_args.kwargs.get("default_headers") or {}
            assert headers.get("originator") != "codex_cli_rs"
# ---------------------------------------------------------------------------
# Auxiliary client wiring (agent.auxiliary_client)
# ---------------------------------------------------------------------------
class TestAuxiliaryClientWiring:
    """Both Codex entry points in agent/auxiliary_client must emit the
    Cloudflare-mitigation headers."""

    def test_try_codex_passes_codex_headers(self, monkeypatch):
        """_try_codex builds the OpenAI client used for compression / vision /
        title generation when routed through Codex. Must emit codex headers."""
        from agent import auxiliary_client

        token = _make_codex_jwt("acct-aux-try-codex")
        # Force _select_pool_entry to return "no pool" so we fall through to
        # _read_codex_access_token.
        monkeypatch.setattr(
            auxiliary_client, "_select_pool_entry",
            lambda provider: (False, None),
        )
        monkeypatch.setattr(
            auxiliary_client, "_read_codex_access_token",
            lambda: token,
        )
        with patch("agent.auxiliary_client.OpenAI") as mock_openai:
            mock_openai.return_value = MagicMock()
            client, model = auxiliary_client._try_codex()
            assert client is not None
            headers = mock_openai.call_args.kwargs.get("default_headers") or {}
            assert headers.get("originator") == "codex_cli_rs"
            assert headers.get("ChatGPT-Account-ID") == "acct-aux-try-codex"
            assert headers.get("User-Agent", "").startswith("codex_cli_rs/")

    def test_resolve_provider_client_raw_codex_passes_codex_headers(self, monkeypatch):
        """The ``raw_codex=True`` branch (used by the main agent loop for direct
        responses.stream() access) must also emit codex headers."""
        from agent import auxiliary_client

        token = _make_codex_jwt("acct-aux-raw-codex")
        monkeypatch.setattr(
            auxiliary_client, "_read_codex_access_token",
            lambda: token,
        )
        with patch("agent.auxiliary_client.OpenAI") as mock_openai:
            mock_openai.return_value = MagicMock()
            client, model = auxiliary_client.resolve_provider_client(
                "openai-codex", raw_codex=True,
            )
            assert client is not None
            headers = mock_openai.call_args.kwargs.get("default_headers") or {}
            assert headers.get("originator") == "codex_cli_rs"
            assert headers.get("ChatGPT-Account-ID") == "acct-aux-raw-codex"
            assert headers.get("User-Agent", "").startswith("codex_cli_rs/")

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import subprocess
from pathlib import Path
from unittest.mock import patch
import pytest
@ -124,6 +125,31 @@ def test_expand_file_range_and_folder_listing(sample_repo: Path):
assert not result.warnings
def test_folder_listing_falls_back_when_rg_is_blocked(sample_repo: Path):
    """@folder expansion must still list files when ripgrep can't launch."""
    from agent.context_references import preprocess_context_references

    real_run = subprocess.run

    def blocked_rg(*args, **kwargs):
        # Simulate a policy layer that forbids spawning ripgrep, while
        # letting every other subprocess through untouched.
        argv = args[0] if args else kwargs.get("args")
        if isinstance(argv, list) and argv and argv[0] == "rg":
            raise PermissionError("rg blocked by policy")
        return real_run(*args, **kwargs)

    with patch("agent.context_references.subprocess.run", side_effect=blocked_rg):
        result = preprocess_context_references(
            "Review @folder:src/",
            cwd=sample_repo,
            context_length=100_000,
        )

    assert result.expanded
    assert "src/" in result.message
    for filename in ("main.py", "helper.py"):
        assert filename in result.message
    assert not result.warnings
def test_expand_quoted_file_reference_with_spaces(tmp_path: Path):
from agent.context_references import preprocess_context_references

View file

@ -0,0 +1,146 @@
"""Focused regressions for the Copilot ACP shim safety layer."""
from __future__ import annotations
import io
import json
import os
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
from agent.copilot_acp_client import CopilotACPClient
class _FakeProcess:
def __init__(self) -> None:
self.stdin = io.StringIO()
class CopilotACPClientSafetyTests(unittest.TestCase):
    """The ACP shim must never auto-approve permissions, expose internal
    Hermes files or secrets, or write outside the sanctioned roots."""

    def setUp(self) -> None:
        self.client = CopilotACPClient(acp_cwd="/tmp")

    def _dispatch(self, message: dict, *, cwd: str) -> dict:
        """Feed one server→client JSON-RPC message and decode the reply the
        shim writes to the fake process stdin."""
        process = _FakeProcess()
        handled = self.client._handle_server_message(
            message,
            process=process,
            cwd=cwd,
            text_parts=[],
            reasoning_parts=[],
        )
        self.assertTrue(handled)
        payload = process.stdin.getvalue().strip()
        self.assertTrue(payload)
        return json.loads(payload)

    def test_request_permission_is_not_auto_allowed(self) -> None:
        # With no human in the loop, a permission request must come back
        # cancelled — never silently allowed.
        response = self._dispatch(
            {
                "jsonrpc": "2.0",
                "id": 1,
                "method": "session/request_permission",
                "params": {},
            },
            cwd="/tmp",
        )
        outcome = (((response.get("result") or {}).get("outcome") or {}).get("outcome"))
        self.assertEqual(outcome, "cancelled")

    def test_read_text_file_blocks_internal_hermes_hub_files(self) -> None:
        # Reads under the internal ~/.hermes hub tree must be refused.
        with tempfile.TemporaryDirectory() as tmpdir:
            home = Path(tmpdir) / "home"
            blocked = home / ".hermes" / "skills" / ".hub" / "index-cache" / "entry.json"
            blocked.parent.mkdir(parents=True, exist_ok=True)
            blocked.write_text('{"token":"sk-test-secret-1234567890"}')
            with patch.dict(
                os.environ,
                {"HOME": str(home), "HERMES_HOME": str(home / ".hermes")},
                clear=False,
            ):
                response = self._dispatch(
                    {
                        "jsonrpc": "2.0",
                        "id": 2,
                        "method": "fs/read_text_file",
                        "params": {"path": str(blocked)},
                    },
                    cwd=str(home),
                )
            self.assertIn("error", response)

    def test_read_text_file_redacts_sensitive_content(self) -> None:
        # Secret values are redacted but the surrounding structure (the
        # variable name) is kept readable.
        with tempfile.TemporaryDirectory() as tmpdir:
            root = Path(tmpdir)
            secret_file = root / "config.env"
            secret_file.write_text("OPENAI_API_KEY=sk-proj-abc123def456ghi789jkl012")
            response = self._dispatch(
                {
                    "jsonrpc": "2.0",
                    "id": 3,
                    "method": "fs/read_text_file",
                    "params": {"path": str(secret_file)},
                },
                cwd=str(root),
            )
            content = ((response.get("result") or {}).get("content") or "")
            self.assertNotIn("abc123def456", content)
            self.assertIn("OPENAI_API_KEY=", content)

    def test_write_text_file_reuses_write_denylist(self) -> None:
        # The shared write denylist must gate fs/write_text_file too.
        with tempfile.TemporaryDirectory() as tmpdir:
            home = Path(tmpdir) / "home"
            target = home / ".ssh" / "id_rsa"
            target.parent.mkdir(parents=True, exist_ok=True)
            with patch("agent.copilot_acp_client.is_write_denied", return_value=True, create=True):
                response = self._dispatch(
                    {
                        "jsonrpc": "2.0",
                        "id": 4,
                        "method": "fs/write_text_file",
                        "params": {
                            "path": str(target),
                            "content": "fake-private-key",
                        },
                    },
                    cwd=str(home),
                )
            self.assertIn("error", response)
            # The denied write must leave no file behind.
            self.assertFalse(target.exists())

    def test_write_text_file_respects_safe_root(self) -> None:
        # HERMES_WRITE_SAFE_ROOT confines writes; paths outside it error out.
        with tempfile.TemporaryDirectory() as tmpdir:
            root = Path(tmpdir)
            safe_root = root / "workspace"
            safe_root.mkdir()
            outside = root / "outside.txt"
            with patch.dict(os.environ, {"HERMES_WRITE_SAFE_ROOT": str(safe_root)}, clear=False):
                response = self._dispatch(
                    {
                        "jsonrpc": "2.0",
                        "id": 5,
                        "method": "fs/write_text_file",
                        "params": {
                            "path": str(outside),
                            "content": "should-not-write",
                        },
                    },
                    cwd=str(root),
                )
            self.assertIn("error", response)
            self.assertFalse(outside.exists())
if __name__ == "__main__":
unittest.main()

View file

@ -1,129 +1,25 @@
"""Tests for credential pool preservation through smart routing and 429 recovery.
"""Tests for credential pool preservation through turn config and 429 recovery.
Covers:
1. credential_pool flows through resolve_turn_route (no-route and fallback paths)
2. CLI _resolve_turn_agent_config passes credential_pool to primary dict
3. Gateway _resolve_turn_agent_config passes credential_pool to primary dict
4. Eager fallback deferred when credential pool has credentials
5. Eager fallback fires when no credential pool exists
6. Full 429 rotation cycle: retry-same rotate exhaust fallback
1. CLI _resolve_turn_agent_config passes credential_pool to runtime dict
2. Gateway _resolve_turn_agent_config passes credential_pool to runtime dict
3. Eager fallback deferred when credential pool has credentials
4. Eager fallback fires when no credential pool exists
5. Full 429 rotation cycle: retry-same rotate exhaust fallback
"""
import os
import time
from types import SimpleNamespace
from unittest.mock import MagicMock, patch, PropertyMock
import pytest
from unittest.mock import MagicMock, patch
# ---------------------------------------------------------------------------
# 1. smart_model_routing: credential_pool preserved in no-route path
# ---------------------------------------------------------------------------
class TestSmartRoutingPoolPreservation:
def test_no_route_preserves_credential_pool(self):
from agent.smart_model_routing import resolve_turn_route
fake_pool = MagicMock(name="CredentialPool")
primary = {
"model": "gpt-5.4",
"api_key": "sk-test",
"base_url": None,
"provider": "openai-codex",
"api_mode": "codex_responses",
"command": None,
"args": [],
"credential_pool": fake_pool,
}
# routing disabled
result = resolve_turn_route("hello", None, primary)
assert result["runtime"]["credential_pool"] is fake_pool
def test_no_route_none_pool(self):
from agent.smart_model_routing import resolve_turn_route
primary = {
"model": "gpt-5.4",
"api_key": "sk-test",
"base_url": None,
"provider": "openai-codex",
"api_mode": "codex_responses",
"command": None,
"args": [],
}
result = resolve_turn_route("hello", None, primary)
assert result["runtime"]["credential_pool"] is None
def test_routing_disabled_preserves_pool(self):
from agent.smart_model_routing import resolve_turn_route
fake_pool = MagicMock(name="CredentialPool")
primary = {
"model": "gpt-5.4",
"api_key": "sk-test",
"base_url": None,
"provider": "openai-codex",
"api_mode": "codex_responses",
"command": None,
"args": [],
"credential_pool": fake_pool,
}
# routing explicitly disabled
result = resolve_turn_route("hello", {"enabled": False}, primary)
assert result["runtime"]["credential_pool"] is fake_pool
def test_route_fallback_on_resolve_error_preserves_pool(self, monkeypatch):
"""When smart routing picks a cheap model but resolve_runtime_provider
fails, the fallback to primary must still include credential_pool."""
from agent.smart_model_routing import resolve_turn_route
fake_pool = MagicMock(name="CredentialPool")
primary = {
"model": "gpt-5.4",
"api_key": "sk-test",
"base_url": None,
"provider": "openai-codex",
"api_mode": "codex_responses",
"command": None,
"args": [],
"credential_pool": fake_pool,
}
routing_config = {
"enabled": True,
"cheap_model": "openai/gpt-4.1-mini",
"cheap_provider": "openrouter",
"max_tokens": 200,
"patterns": ["^(hi|hello|hey)"],
}
# Force resolve_runtime_provider to fail so it falls back to primary
monkeypatch.setattr(
"hermes_cli.runtime_provider.resolve_runtime_provider",
MagicMock(side_effect=RuntimeError("no credentials")),
)
result = resolve_turn_route("hi", routing_config, primary)
assert result["runtime"]["credential_pool"] is fake_pool
# ---------------------------------------------------------------------------
# 2 & 3. CLI and Gateway _resolve_turn_agent_config include credential_pool
# 1. CLI _resolve_turn_agent_config includes credential_pool
# ---------------------------------------------------------------------------
class TestCliTurnRoutePool:
def test_resolve_turn_includes_pool(self, monkeypatch, tmp_path):
"""CLI's _resolve_turn_agent_config must pass credential_pool to primary."""
from agent.smart_model_routing import resolve_turn_route
captured = {}
def spy_resolve(user_message, routing_config, primary):
captured["primary"] = primary
return resolve_turn_route(user_message, routing_config, primary)
monkeypatch.setattr(
"agent.smart_model_routing.resolve_turn_route", spy_resolve
)
# Build a minimal HermesCLI-like object with the method
def test_resolve_turn_includes_pool(self):
"""CLI's _resolve_turn_agent_config must pass credential_pool in runtime."""
fake_pool = MagicMock(name="FakePool")
shell = SimpleNamespace(
model="gpt-5.4",
api_key="sk-test",
@ -132,58 +28,46 @@ class TestCliTurnRoutePool:
api_mode="codex_responses",
acp_command=None,
acp_args=[],
_credential_pool=MagicMock(name="FakePool"),
_smart_model_routing={"enabled": False},
_credential_pool=fake_pool,
service_tier=None,
)
# Import and bind the real method
from cli import HermesCLI
bound = HermesCLI._resolve_turn_agent_config.__get__(shell)
bound("test message")
route = bound("test message")
assert "credential_pool" in captured["primary"]
assert captured["primary"]["credential_pool"] is shell._credential_pool
assert route["runtime"]["credential_pool"] is fake_pool
# ---------------------------------------------------------------------------
# 2. Gateway _resolve_turn_agent_config includes credential_pool
# ---------------------------------------------------------------------------
class TestGatewayTurnRoutePool:
def test_resolve_turn_includes_pool(self, monkeypatch):
def test_resolve_turn_includes_pool(self):
"""Gateway's _resolve_turn_agent_config must pass credential_pool."""
from agent.smart_model_routing import resolve_turn_route
captured = {}
def spy_resolve(user_message, routing_config, primary):
captured["primary"] = primary
return resolve_turn_route(user_message, routing_config, primary)
monkeypatch.setattr(
"agent.smart_model_routing.resolve_turn_route", spy_resolve
)
from gateway.run import GatewayRunner
runner = SimpleNamespace(
_smart_model_routing={"enabled": False},
)
fake_pool = MagicMock(name="FakePool")
runner = SimpleNamespace(_service_tier=None)
runtime_kwargs = {
"api_key": "sk-test",
"api_key": "***",
"base_url": None,
"provider": "openai-codex",
"api_mode": "codex_responses",
"command": None,
"args": [],
"credential_pool": MagicMock(name="FakePool"),
"credential_pool": fake_pool,
}
bound = GatewayRunner._resolve_turn_agent_config.__get__(runner)
bound("test message", "gpt-5.4", runtime_kwargs)
route = bound("test message", "gpt-5.4", runtime_kwargs)
assert "credential_pool" in captured["primary"]
assert captured["primary"]["credential_pool"] is runtime_kwargs["credential_pool"]
assert route["runtime"]["credential_pool"] is fake_pool
# ---------------------------------------------------------------------------
# 4 & 5. Eager fallback deferred/fires based on credential pool
# 3 & 4. Eager fallback deferred/fires based on credential pool
# ---------------------------------------------------------------------------
class TestEagerFallbackWithPool:
@ -251,7 +135,7 @@ class TestEagerFallbackWithPool:
# ---------------------------------------------------------------------------
# 6. Full 429 rotation cycle via _recover_with_credential_pool
# 5. Full 429 rotation cycle via _recover_with_credential_pool
# ---------------------------------------------------------------------------
class TestPoolRotationCycle:

View file

@ -0,0 +1,27 @@
from __future__ import annotations
from run_agent import AIAgent
def _agent_with_base_url(base_url: str) -> AIAgent:
agent = object.__new__(AIAgent)
agent.base_url = base_url
return agent
def test_direct_openai_url_requires_openai_host():
agent = _agent_with_base_url("https://api.openai.com.example/v1")
assert agent._is_direct_openai_url() is False
def test_direct_openai_url_ignores_path_segment_match():
agent = _agent_with_base_url("https://proxy.example.test/api.openai.com/v1")
assert agent._is_direct_openai_url() is False
def test_direct_openai_url_accepts_native_host():
agent = _agent_with_base_url("https://api.openai.com/v1")
assert agent._is_direct_openai_url() is True

View file

@ -83,6 +83,13 @@ class TestBuildToolPreview:
assert result is not None
assert "user" in result
def test_memory_replace_missing_old_text_marked(self):
# Avoid empty quotes "" in the preview when old_text is missing/None.
result = build_tool_preview("memory", {"action": "replace", "target": "memory"})
assert result == '~memory: "<missing old_text>"'
result = build_tool_preview("memory", {"action": "remove", "target": "memory", "old_text": None})
assert result == '-memory: "<missing old_text>"'
def test_session_search_preview(self):
result = build_tool_preview("session_search", {"query": "find something"})
assert result is not None

View file

@ -298,9 +298,15 @@ class TestClassifyApiError:
assert result.retryable is False
def test_404_generic(self):
# Generic 404 with no "model not found" signal — common for local
# llama.cpp/Ollama/vLLM endpoints with slightly wrong paths. Treat
# as unknown (retryable) so the real error surfaces, rather than
# claiming the model is missing and silently falling back.
e = MockAPIError("Not Found", status_code=404)
result = classify_api_error(e)
assert result.reason == FailoverReason.model_not_found
assert result.reason == FailoverReason.unknown
assert result.retryable is True
assert result.should_fallback is False
# ── Payload too large ──
@ -849,3 +855,97 @@ class TestAdversarialEdgeCases:
)
result = classify_api_error(e, provider="openrouter")
assert result.reason == FailoverReason.model_not_found
# ── Regression: dict-typed message field (Issue #11233) ──
def test_pydantic_dict_message_no_crash(self):
"""Pydantic validation errors return message as dict, not string.
Regression: classify_api_error must not crash when body['message']
is a dict (e.g. {"detail": [...]} from FastAPI/Pydantic). The
'or ""' fallback only handles None/falsy values a non-empty
dict is truthy and passed to .lower(), causing AttributeError.
"""
e = MockAPIError(
"Unprocessable Entity",
status_code=422,
body={
"object": "error",
"message": {
"detail": [
{
"type": "extra_forbidden",
"loc": ["body", "think"],
"msg": "Extra inputs are not permitted",
}
]
},
},
)
result = classify_api_error(e)
assert result.reason == FailoverReason.format_error
assert result.status_code == 422
assert result.retryable is False
def test_nested_error_dict_message_no_crash(self):
"""Nested body['error']['message'] as dict must not crash.
Some providers wrap Pydantic errors in an 'error' object.
"""
e = MockAPIError(
"Validation error",
status_code=400,
body={
"error": {
"message": {
"detail": [
{"type": "missing", "loc": ["body", "required"]}
]
}
}
},
)
result = classify_api_error(e, approx_tokens=1000)
assert result.reason == FailoverReason.format_error
assert result.status_code == 400
def test_metadata_raw_dict_message_no_crash(self):
"""OpenRouter metadata.raw with dict message must not crash."""
e = MockAPIError(
"Provider error",
status_code=400,
body={
"error": {
"message": "Provider error",
"metadata": {
"raw": '{"error":{"message":{"detail":[{"type":"invalid"}]}}}'
}
}
},
)
result = classify_api_error(e)
assert result.reason == FailoverReason.format_error
# Broader non-string type guards — defense against other provider quirks.
def test_list_message_no_crash(self):
"""Some providers return message as a list of error entries."""
e = MockAPIError(
"validation",
status_code=400,
body={"message": [{"msg": "field required"}]},
)
result = classify_api_error(e)
assert result is not None
def test_int_message_no_crash(self):
"""Any non-string type must be coerced safely."""
e = MockAPIError("server error", status_code=500, body={"message": 42})
result = classify_api_error(e)
assert result is not None
def test_none_message_still_works(self):
"""Regression: None fallback (the 'or \"\"' path) must still work."""
e = MockAPIError("server error", status_code=500, body={"message": None})
result = classify_api_error(e)
assert result is not None

View file

@ -652,6 +652,42 @@ class TestBuildGeminiRequest:
assert decls[0]["description"] == "foo"
assert decls[0]["parameters"] == {"type": "object"}
def test_tools_strip_json_schema_only_fields_from_parameters(self):
from agent.gemini_cloudcode_adapter import build_gemini_request
req = build_gemini_request(
messages=[{"role": "user", "content": "hi"}],
tools=[
{"type": "function", "function": {
"name": "fn1",
"description": "foo",
"parameters": {
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"additionalProperties": False,
"properties": {
"city": {
"type": "string",
"$schema": "ignored",
"description": "City name",
"additionalProperties": False,
}
},
"required": ["city"],
},
}},
],
)
params = req["tools"][0]["functionDeclarations"][0]["parameters"]
assert "$schema" not in params
assert "additionalProperties" not in params
assert params["type"] == "object"
assert params["required"] == ["city"]
assert params["properties"]["city"] == {
"type": "string",
"description": "City name",
}
def test_tool_choice_auto(self):
from agent.gemini_cloudcode_adapter import build_gemini_request
@ -814,6 +850,69 @@ class TestTranslateGeminiResponse:
assert _map_gemini_finish_reason("RECITATION") == "content_filter"
class TestTranslateStreamEvent:
def test_parallel_calls_to_same_tool_get_unique_indices(self):
"""Gemini may emit several functionCall parts with the same name in a
single turn (e.g. parallel file reads). Each must get its own OpenAI
``index`` otherwise downstream aggregators collapse them into one.
"""
from agent.gemini_cloudcode_adapter import _translate_stream_event
event = {
"response": {
"candidates": [{
"content": {"parts": [
{"functionCall": {"name": "read_file", "args": {"path": "a"}}},
{"functionCall": {"name": "read_file", "args": {"path": "b"}}},
{"functionCall": {"name": "read_file", "args": {"path": "c"}}},
]},
}],
}
}
counter = [0]
chunks = _translate_stream_event(event, model="gemini-2.5-flash",
tool_call_counter=counter)
indices = [c.choices[0].delta.tool_calls[0].index for c in chunks]
assert indices == [0, 1, 2]
assert counter[0] == 3
def test_counter_persists_across_events(self):
"""Index assignment must continue across SSE events in the same stream."""
from agent.gemini_cloudcode_adapter import _translate_stream_event
def _event(name):
return {"response": {"candidates": [{
"content": {"parts": [{"functionCall": {"name": name, "args": {}}}]},
}]}}
counter = [0]
chunks_a = _translate_stream_event(_event("foo"), model="m", tool_call_counter=counter)
chunks_b = _translate_stream_event(_event("bar"), model="m", tool_call_counter=counter)
chunks_c = _translate_stream_event(_event("foo"), model="m", tool_call_counter=counter)
assert chunks_a[0].choices[0].delta.tool_calls[0].index == 0
assert chunks_b[0].choices[0].delta.tool_calls[0].index == 1
assert chunks_c[0].choices[0].delta.tool_calls[0].index == 2
def test_finish_reason_switches_to_tool_calls_when_any_seen(self):
from agent.gemini_cloudcode_adapter import _translate_stream_event
counter = [0]
# First event emits one tool call.
_translate_stream_event(
{"response": {"candidates": [{
"content": {"parts": [{"functionCall": {"name": "x", "args": {}}}]},
}]}},
model="m", tool_call_counter=counter,
)
# Second event carries only the terminal finishReason.
chunks = _translate_stream_event(
{"response": {"candidates": [{"finishReason": "STOP"}]}},
model="m", tool_call_counter=counter,
)
assert chunks[-1].choices[0].finish_reason == "tool_calls"
class TestGeminiCloudCodeClient:
def test_client_exposes_openai_interface(self):
from agent.gemini_cloudcode_adapter import GeminiCloudCodeClient

View file

@ -0,0 +1,315 @@
"""Tests for the native Google AI Studio Gemini adapter."""
from __future__ import annotations
import json
from types import SimpleNamespace
import pytest
class DummyResponse:
def __init__(self, status_code=200, payload=None, headers=None, text=None):
self.status_code = status_code
self._payload = payload or {}
self.headers = headers or {}
self.text = text if text is not None else json.dumps(self._payload)
def json(self):
return self._payload
def test_build_native_request_preserves_thought_signature_on_tool_replay():
from agent.gemini_native_adapter import build_gemini_request
request = build_gemini_request(
messages=[
{"role": "system", "content": "Be helpful."},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Paris"}',
},
"extra_content": {
"google": {"thought_signature": "sig-123"}
},
}
],
},
],
tools=[],
tool_choice=None,
)
parts = request["contents"][0]["parts"]
assert parts[0]["functionCall"]["name"] == "get_weather"
assert parts[0]["thoughtSignature"] == "sig-123"
def test_build_native_request_uses_original_function_name_for_tool_result():
from agent.gemini_native_adapter import build_gemini_request
request = build_gemini_request(
messages=[
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Paris"}',
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": '{"forecast": "sunny"}',
},
],
tools=[],
tool_choice=None,
)
tool_response = request["contents"][1]["parts"][0]["functionResponse"]
assert tool_response["name"] == "get_weather"
def test_build_native_request_strips_json_schema_only_fields_from_tool_parameters():
from agent.gemini_native_adapter import build_gemini_request
request = build_gemini_request(
messages=[{"role": "user", "content": "Hello"}],
tools=[
{
"type": "function",
"function": {
"name": "lookup_weather",
"description": "Weather lookup",
"parameters": {
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"additionalProperties": False,
"properties": {
"city": {
"type": "string",
"$schema": "ignored",
"description": "City name",
}
},
"required": ["city"],
},
},
}
],
tool_choice=None,
)
params = request["tools"][0]["functionDeclarations"][0]["parameters"]
assert "$schema" not in params
assert "additionalProperties" not in params
assert params["type"] == "object"
assert params["properties"]["city"] == {
"type": "string",
"description": "City name",
}
def test_translate_native_response_surfaces_reasoning_and_tool_calls():
from agent.gemini_native_adapter import translate_gemini_response
payload = {
"candidates": [
{
"content": {
"parts": [
{"thought": True, "text": "thinking..."},
{"functionCall": {"name": "search", "args": {"q": "hermes"}}},
]
},
"finishReason": "STOP",
}
],
"usageMetadata": {
"promptTokenCount": 10,
"candidatesTokenCount": 5,
"totalTokenCount": 15,
},
}
response = translate_gemini_response(payload, model="gemini-2.5-flash")
choice = response.choices[0]
assert choice.finish_reason == "tool_calls"
assert choice.message.reasoning == "thinking..."
assert choice.message.tool_calls[0].function.name == "search"
assert json.loads(choice.message.tool_calls[0].function.arguments) == {"q": "hermes"}
def test_native_client_uses_x_goog_api_key_and_native_models_endpoint(monkeypatch):
from agent.gemini_native_adapter import GeminiNativeClient
recorded = {}
class DummyHTTP:
def post(self, url, json=None, headers=None, timeout=None):
recorded["url"] = url
recorded["json"] = json
recorded["headers"] = headers
return DummyResponse(
payload={
"candidates": [
{
"content": {"parts": [{"text": "hello"}]},
"finishReason": "STOP",
}
],
"usageMetadata": {
"promptTokenCount": 1,
"candidatesTokenCount": 1,
"totalTokenCount": 2,
},
}
)
def close(self):
return None
monkeypatch.setattr("agent.gemini_native_adapter.httpx.Client", lambda *a, **k: DummyHTTP())
client = GeminiNativeClient(api_key="AIza-test", base_url="https://generativelanguage.googleapis.com/v1beta")
response = client.chat.completions.create(
model="gemini-2.5-flash",
messages=[{"role": "user", "content": "Hello"}],
)
assert recorded["url"] == "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent"
assert recorded["headers"]["x-goog-api-key"] == "AIza-test"
assert "Authorization" not in recorded["headers"]
assert response.choices[0].message.content == "hello"
def test_native_http_error_keeps_status_and_retry_after():
from agent.gemini_native_adapter import gemini_http_error
response = DummyResponse(
status_code=429,
headers={"Retry-After": "17"},
payload={
"error": {
"code": 429,
"message": "quota exhausted",
"status": "RESOURCE_EXHAUSTED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": "RESOURCE_EXHAUSTED",
"metadata": {"service": "generativelanguage.googleapis.com"},
}
],
}
},
)
err = gemini_http_error(response)
assert getattr(err, "status_code", None) == 429
assert getattr(err, "retry_after", None) == 17.0
assert "quota exhausted" in str(err)
def test_native_client_accepts_injected_http_client():
from agent.gemini_native_adapter import GeminiNativeClient
injected = SimpleNamespace(close=lambda: None)
client = GeminiNativeClient(api_key="AIza-test", http_client=injected)
assert client._http is injected
@pytest.mark.asyncio
async def test_async_native_client_streams_without_requiring_async_iterator_from_sync_client():
from agent.gemini_native_adapter import AsyncGeminiNativeClient
chunk = SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content="hi"), finish_reason=None)])
sync_stream = iter([chunk])
def _advance(iterator):
try:
return False, next(iterator)
except StopIteration:
return True, None
sync_client = SimpleNamespace(
api_key="AIza-test",
base_url="https://generativelanguage.googleapis.com/v1beta",
chat=SimpleNamespace(completions=SimpleNamespace(create=lambda **kwargs: sync_stream)),
_advance_stream_iterator=_advance,
close=lambda: None,
)
async_client = AsyncGeminiNativeClient(sync_client)
stream = await async_client.chat.completions.create(stream=True)
collected = []
async for item in stream:
collected.append(item)
assert collected == [chunk]
def test_stream_event_translation_emits_tool_call_delta_with_stable_index():
from agent.gemini_native_adapter import translate_stream_event
tool_call_indices = {}
event = {
"candidates": [
{
"content": {
"parts": [
{"functionCall": {"name": "search", "args": {"q": "abc"}}}
]
},
"finishReason": "STOP",
}
]
}
first = translate_stream_event(event, model="gemini-2.5-flash", tool_call_indices=tool_call_indices)
second = translate_stream_event(event, model="gemini-2.5-flash", tool_call_indices=tool_call_indices)
assert first[0].choices[0].delta.tool_calls[0].index == 0
assert second[0].choices[0].delta.tool_calls[0].index == 0
assert first[0].choices[0].delta.tool_calls[0].id == second[0].choices[0].delta.tool_calls[0].id
assert first[0].choices[0].delta.tool_calls[0].function.arguments == '{"q": "abc"}'
assert second[0].choices[0].delta.tool_calls[0].function.arguments == ""
assert first[-1].choices[0].finish_reason == "tool_calls"
def test_stream_event_translation_keeps_identical_calls_in_distinct_parts():
from agent.gemini_native_adapter import translate_stream_event
event = {
"candidates": [
{
"content": {
"parts": [
{"functionCall": {"name": "search", "args": {"q": "abc"}}},
{"functionCall": {"name": "search", "args": {"q": "abc"}}},
]
},
"finishReason": "STOP",
}
]
}
chunks = translate_stream_event(event, model="gemini-2.5-flash", tool_call_indices={})
tool_chunks = [chunk for chunk in chunks if chunk.choices[0].delta.tool_calls]
assert tool_chunks[0].choices[0].delta.tool_calls[0].index == 0
assert tool_chunks[1].choices[0].delta.tool_calls[0].index == 1
assert tool_chunks[0].choices[0].delta.tool_calls[0].id != tool_chunks[1].choices[0].delta.tool_calls[0].id

View file

@ -0,0 +1,111 @@
"""Tests for agent/image_gen_registry.py — provider registration & active lookup."""
from __future__ import annotations
import pytest
from agent import image_gen_registry
from agent.image_gen_provider import ImageGenProvider
class _FakeProvider(ImageGenProvider):
def __init__(self, name: str, available: bool = True):
self._name = name
self._available = available
@property
def name(self) -> str:
return self._name
def is_available(self) -> bool:
return self._available
def generate(self, prompt, aspect_ratio="landscape", **kw):
return {"success": True, "image": f"{self._name}://{prompt}"}
@pytest.fixture(autouse=True)
def _reset_registry():
image_gen_registry._reset_for_tests()
yield
image_gen_registry._reset_for_tests()
class TestRegisterProvider:
def test_register_and_lookup(self):
provider = _FakeProvider("fake")
image_gen_registry.register_provider(provider)
assert image_gen_registry.get_provider("fake") is provider
def test_rejects_non_provider(self):
with pytest.raises(TypeError):
image_gen_registry.register_provider("not a provider") # type: ignore[arg-type]
def test_rejects_empty_name(self):
class Empty(ImageGenProvider):
@property
def name(self) -> str:
return ""
def generate(self, prompt, aspect_ratio="landscape", **kw):
return {}
with pytest.raises(ValueError):
image_gen_registry.register_provider(Empty())
def test_reregister_overwrites(self):
a = _FakeProvider("same")
b = _FakeProvider("same")
image_gen_registry.register_provider(a)
image_gen_registry.register_provider(b)
assert image_gen_registry.get_provider("same") is b
def test_list_is_sorted(self):
image_gen_registry.register_provider(_FakeProvider("zeta"))
image_gen_registry.register_provider(_FakeProvider("alpha"))
names = [p.name for p in image_gen_registry.list_providers()]
assert names == ["alpha", "zeta"]
class TestGetActiveProvider:
def test_single_provider_autoresolves(self, tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
image_gen_registry.register_provider(_FakeProvider("solo"))
active = image_gen_registry.get_active_provider()
assert active is not None and active.name == "solo"
def test_fal_preferred_on_multi_without_config(self, tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
image_gen_registry.register_provider(_FakeProvider("fal"))
image_gen_registry.register_provider(_FakeProvider("openai"))
active = image_gen_registry.get_active_provider()
assert active is not None and active.name == "fal"
def test_explicit_config_wins(self, tmp_path, monkeypatch):
import yaml
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
(tmp_path / "config.yaml").write_text(
yaml.safe_dump({"image_gen": {"provider": "openai"}})
)
image_gen_registry.register_provider(_FakeProvider("fal"))
image_gen_registry.register_provider(_FakeProvider("openai"))
active = image_gen_registry.get_active_provider()
assert active is not None and active.name == "openai"
def test_missing_configured_provider_falls_back(self, tmp_path, monkeypatch):
import yaml
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
(tmp_path / "config.yaml").write_text(
yaml.safe_dump({"image_gen": {"provider": "replicate"}})
)
# Only FAL is registered — configured provider doesn't exist
image_gen_registry.register_provider(_FakeProvider("fal"))
active = image_gen_registry.get_active_provider()
# Falls back to FAL preference (legacy default) rather than None
assert active is not None and active.name == "fal"
def test_none_when_empty(self, tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
assert image_gen_registry.get_active_provider() is None

View file

@ -51,6 +51,12 @@ def populated_db(db):
db.append_message("s1", role="assistant", content="I found the bug. Let me fix it.",
tool_calls=[{"function": {"name": "patch"}}])
db.append_message("s1", role="tool", content="patched successfully", tool_name="patch")
db.append_message(
"s1",
role="assistant",
content="Let me load the PR workflow skill.",
tool_calls=[{"function": {"name": "skill_view", "arguments": '{"name":"github-pr-workflow"}'}}],
)
db.append_message("s1", role="user", content="Thanks!")
db.append_message("s1", role="assistant", content="You're welcome!")
@ -88,6 +94,12 @@ def populated_db(db):
db.append_message("s3", role="assistant", content="And search files",
tool_calls=[{"function": {"name": "search_files"}}])
db.append_message("s3", role="tool", content="found stuff", tool_name="search_files")
db.append_message(
"s3",
role="assistant",
content="Load the debugging skill.",
tool_calls=[{"function": {"name": "skill_view", "arguments": '{"name":"systematic-debugging"}'}}],
)
# Session 4: Discord, same model as s1, ended, 1 day ago
db.create_session(
@ -100,6 +112,15 @@ def populated_db(db):
db.update_token_counts("s4", input_tokens=10000, output_tokens=5000)
db.append_message("s4", role="user", content="Quick question")
db.append_message("s4", role="assistant", content="Sure, go ahead")
db.append_message(
"s4",
role="assistant",
content="Load and update GitHub skills.",
tool_calls=[
{"function": {"name": "skill_view", "arguments": '{"name":"github-pr-workflow"}'}},
{"function": {"name": "skill_manage", "arguments": '{"name":"github-code-review"}'}},
],
)
# Session 5: Old session, 45 days ago (should be excluded from 30-day window)
db.create_session(
@ -332,6 +353,35 @@ class TestInsightsPopulated:
total_pct = sum(t["percentage"] for t in tools)
assert total_pct == pytest.approx(100.0, abs=0.1)
def test_skill_breakdown(self, populated_db):
engine = InsightsEngine(populated_db)
report = engine.generate(days=30)
skills = report["skills"]
assert skills["summary"]["distinct_skills_used"] == 3
assert skills["summary"]["total_skill_loads"] == 3
assert skills["summary"]["total_skill_edits"] == 1
assert skills["summary"]["total_skill_actions"] == 4
top_skill = skills["top_skills"][0]
assert top_skill["skill"] == "github-pr-workflow"
assert top_skill["view_count"] == 2
assert top_skill["manage_count"] == 0
assert top_skill["total_count"] == 2
assert top_skill["last_used_at"] is not None
def test_skill_breakdown_respects_days_filter(self, populated_db):
engine = InsightsEngine(populated_db)
report = engine.generate(days=3)
skills = report["skills"]
assert skills["summary"]["distinct_skills_used"] == 2
assert skills["summary"]["total_skill_loads"] == 2
assert skills["summary"]["total_skill_edits"] == 1
skill_names = [s["skill"] for s in skills["top_skills"]]
assert "systematic-debugging" not in skill_names
def test_activity_patterns(self, populated_db):
engine = InsightsEngine(populated_db)
report = engine.generate(days=30)
@ -401,6 +451,7 @@ class TestTerminalFormatting:
assert "Overview" in text
assert "Models Used" in text
assert "Top Tools" in text
assert "Top Skills" in text
assert "Activity Patterns" in text
assert "Notable Sessions" in text
@ -465,12 +516,12 @@ class TestGatewayFormatting:
assert "**" in text # Markdown bold
def test_gateway_format_hides_cost(self, populated_db):
"""Gateway format omits dollar figures and internal cache details."""
engine = InsightsEngine(populated_db)
report = engine.generate(days=30)
text = engine.format_gateway(report)
assert "$" not in text
assert "Est. cost" not in text
assert "cache" not in text.lower()
def test_gateway_format_shows_models(self, populated_db):

View file

@ -0,0 +1,115 @@
"""Regression guard: don't send Anthropic ``thinking`` to Kimi's /coding endpoint.
Kimi's ``api.kimi.com/coding`` endpoint speaks the Anthropic Messages protocol
but has its own thinking semantics. When ``thinking.enabled`` is present in
the request, Kimi validates the message history and requires every prior
assistant tool-call message to carry OpenAI-style ``reasoning_content``.
The Anthropic path never populates that field, and
``convert_messages_to_anthropic`` strips Anthropic thinking blocks on
third-party endpoints so after one turn with tool calls the next request
fails with HTTP 400::
thinking is enabled but reasoning_content is missing in assistant
tool call message at index N
Kimi on the chat_completions route handles ``thinking`` via ``extra_body`` in
``ChatCompletionsTransport`` (#13503). On the Anthropic route the right
thing to do is drop the parameter entirely and let Kimi drive reasoning
server-side.
"""
from __future__ import annotations
import pytest
class TestKimiCodingSkipsAnthropicThinking:
    """build_anthropic_kwargs must not inject ``thinking`` for Kimi /coding."""

    @staticmethod
    def _kwargs_for(base_url, reasoning_config, model="kimi-k2.5"):
        """Build adapter kwargs with a minimal one-message history."""
        from agent.anthropic_adapter import build_anthropic_kwargs

        return build_anthropic_kwargs(
            model=model,
            messages=[{"role": "user", "content": "hello"}],
            tools=None,
            max_tokens=4096,
            reasoning_config=reasoning_config,
            base_url=base_url,
        )

    @pytest.mark.parametrize(
        "base_url",
        [
            "https://api.kimi.com/coding",
            "https://api.kimi.com/coding/v1",
            "https://api.kimi.com/coding/anthropic",
            "https://api.kimi.com/coding/",
        ],
    )
    def test_kimi_coding_endpoint_omits_thinking(self, base_url: str) -> None:
        """Every /coding URL variant must have the thinking parameter stripped."""
        kwargs = self._kwargs_for(base_url, {"enabled": True, "effort": "medium"})
        assert "thinking" not in kwargs, (
            "Anthropic thinking must not be sent to Kimi /coding — "
            "endpoint requires reasoning_content on history we don't preserve."
        )
        assert "output_config" not in kwargs

    def test_kimi_coding_with_explicit_disabled_also_omits(self) -> None:
        """An explicitly disabled reasoning config also yields no thinking key."""
        kwargs = self._kwargs_for("https://api.kimi.com/coding", {"enabled": False})
        assert "thinking" not in kwargs

    def test_non_kimi_third_party_still_gets_thinking(self) -> None:
        """MiniMax and other third-party Anthropic endpoints must retain thinking."""
        kwargs = self._kwargs_for(
            "https://api.minimax.io/anthropic",
            {"enabled": True, "effort": "medium"},
            model="MiniMax-M2.7",
        )
        assert "thinking" in kwargs
        assert kwargs["thinking"]["type"] == "enabled"

    def test_native_anthropic_still_gets_thinking(self) -> None:
        """First-party Anthropic (no base_url override) keeps thinking enabled."""
        kwargs = self._kwargs_for(
            None,
            {"enabled": True, "effort": "medium"},
            model="claude-sonnet-4-20250514",
        )
        assert "thinking" in kwargs

    def test_kimi_root_endpoint_unaffected(self) -> None:
        """Only the /coding route is special-cased — plain api.kimi.com is not.

        ``api.kimi.com`` without ``/coding`` uses the chat_completions transport
        (see runtime_provider._detect_api_mode_for_url); build_anthropic_kwargs
        should never see it, but if it somehow does we should not suppress
        thinking there — that path has different semantics.
        """
        kwargs = self._kwargs_for(
            "https://api.kimi.com/v1",
            {"enabled": True, "effort": "medium"},
        )
        assert "thinking" in kwargs

View file

@ -971,8 +971,6 @@ class TestHonchoCadenceTracking:
class FakeManager:
def prefetch_context(self, key, query=None):
pass
def prefetch_dialectic(self, key, query):
pass
p._manager = FakeManager()

View file

@ -79,6 +79,28 @@ class TestMemoryManagerUserIdThreading:
assert p._init_kwargs.get("platform") == "telegram"
assert p._init_session_id == "sess-123"
def test_chat_context_forwarded_to_provider(self):
    """Chat-level metadata given to initialize_all must reach each provider."""
    manager = MemoryManager()
    provider = RecordingProvider()
    manager.add_provider(provider)
    manager.initialize_all(
        session_id="sess-chat",
        platform="discord",
        user_id="discord_u_7",
        user_name="fakeusername",
        chat_id="1485316232612941897",
        chat_name="fakeassistantname-forums",
        chat_type="thread",
        thread_id="1491249007475949698",
    )
    expected = {
        "user_name": "fakeusername",
        "chat_id": "1485316232612941897",
        "chat_name": "fakeassistantname-forums",
        "chat_type": "thread",
        "thread_id": "1491249007475949698",
    }
    for key, value in expected.items():
        assert provider._init_kwargs.get(key) == value
def test_no_user_id_when_cli(self):
"""CLI sessions should not have user_id in kwargs."""
mgr = MemoryManager()
@ -208,34 +230,81 @@ class TestMem0UserIdScoping:
class TestHonchoUserIdScoping:
"""Verify Honcho plugin uses gateway user_id for peer_name when provided."""
"""Verify Honcho plugin keeps runtime user scoping separate from config peer_name."""
def test_gateway_user_id_overrides_peer_name(self):
"""When user_id is in kwargs and no explicit peer_name, user_id should be used."""
def test_gateway_user_id_is_passed_as_runtime_peer(self):
"""Gateway user_id should scope Honcho sessions without mutating config peer_name."""
from plugins.memory.honcho import HonchoMemoryProvider
provider = HonchoMemoryProvider()
# Create a mock config with NO explicit peer_name
mock_cfg = MagicMock()
mock_cfg.enabled = True
mock_cfg.api_key = "test-key"
mock_cfg.base_url = None
mock_cfg.peer_name = "" # No explicit peer_name — user_id should fill it
mock_cfg.recall_mode = "tools" # Use tools mode to defer session init
mock_cfg.peer_name = "static-user"
mock_cfg.recall_mode = "context"
mock_cfg.context_tokens = None
mock_cfg.raw = {}
mock_cfg.dialectic_depth = 1
mock_cfg.dialectic_depth_levels = None
mock_cfg.init_on_session_start = False
mock_cfg.ai_peer = "hermes"
mock_cfg.resolve_session_name.return_value = "test-sess"
mock_cfg.session_strategy = "shared"
with patch(
"plugins.memory.honcho.client.HonchoClientConfig.from_global_config",
return_value=mock_cfg,
):
), patch(
"plugins.memory.honcho.client.get_honcho_client",
return_value=MagicMock(),
), patch(
"plugins.memory.honcho.session.HonchoSessionManager",
) as mock_manager_cls:
mock_manager = MagicMock()
mock_manager.get_or_create.return_value = MagicMock(messages=[])
mock_manager_cls.return_value = mock_manager
provider.initialize(
session_id="test-sess",
user_id="discord_user_789",
platform="discord",
)
# The config's peer_name should have been overridden with the user_id
assert mock_cfg.peer_name == "discord_user_789"
assert mock_cfg.peer_name == "static-user"
assert mock_manager_cls.call_args.kwargs["runtime_user_peer_name"] == "discord_user_789"
def test_session_manager_prefers_runtime_user_id_over_config_peer_name(self):
"""Session manager should isolate gateway users even when config peer_name is static."""
from plugins.memory.honcho.session import HonchoSessionManager
mock_cfg = MagicMock()
mock_cfg.peer_name = "static-user"
mock_cfg.ai_peer = "hermes"
mock_cfg.write_frequency = "sync"
mock_cfg.dialectic_reasoning_level = "low"
mock_cfg.dialectic_dynamic = True
mock_cfg.dialectic_max_chars = 600
mock_cfg.observation_mode = "directional"
mock_cfg.user_observe_me = True
mock_cfg.user_observe_others = True
mock_cfg.ai_observe_me = True
mock_cfg.ai_observe_others = True
manager = HonchoSessionManager(
honcho=MagicMock(),
config=mock_cfg,
runtime_user_peer_name="discord_user_789",
)
with patch.object(manager, "_get_or_create_peer", return_value=MagicMock()), patch.object(
manager,
"_get_or_create_honcho_session",
return_value=(MagicMock(), []),
):
session = manager.get_or_create("discord:channel-1")
assert session.user_peer_id == "discord_user_789"
def test_no_user_id_preserves_config_peer_name(self):
"""Without user_id, the config peer_name should be preserved."""
@ -287,3 +356,4 @@ class TestAIAgentUserIdPropagation:
agent = object.__new__(AIAgent)
agent._user_id = None
assert agent._user_id is None

View file

@ -84,38 +84,6 @@ class TestMinimaxAuxModel:
assert "highspeed" not in _API_KEY_PROVIDER_AUX_MODELS["minimax-cn"]
class TestMinimaxModelCatalog:
"""Verify the model catalog matches official Anthropic-compat endpoint models.
Source: https://platform.minimax.io/docs/api-reference/text-anthropic-api
"""
def test_catalog_includes_current_models(self):
from hermes_cli.models import _PROVIDER_MODELS
for provider in ("minimax", "minimax-cn"):
models = _PROVIDER_MODELS[provider]
assert "MiniMax-M2.7" in models
assert "MiniMax-M2.5" in models
assert "MiniMax-M2.1" in models
assert "MiniMax-M2" in models
def test_catalog_excludes_m1_family(self):
"""M1 models are not available on the /anthropic endpoint."""
from hermes_cli.models import _PROVIDER_MODELS
for provider in ("minimax", "minimax-cn"):
models = _PROVIDER_MODELS[provider]
assert "MiniMax-M1" not in models
def test_catalog_excludes_highspeed(self):
"""Highspeed variants are available but not shown in default catalog
(users can still specify them manually)."""
from hermes_cli.models import _PROVIDER_MODELS
for provider in ("minimax", "minimax-cn"):
models = _PROVIDER_MODELS[provider]
assert "MiniMax-M2.7-highspeed" not in models
assert "MiniMax-M2.5-highspeed" not in models
class TestMinimaxBetaHeaders:
"""MiniMax Anthropic-compat endpoints reject fine-grained-tool-streaming beta.

View file

@ -385,6 +385,7 @@ class TestStripProviderPrefix:
assert _strip_provider_prefix("local:my-model") == "my-model"
assert _strip_provider_prefix("openrouter:anthropic/claude-sonnet-4") == "anthropic/claude-sonnet-4"
assert _strip_provider_prefix("anthropic:claude-sonnet-4") == "claude-sonnet-4"
assert _strip_provider_prefix("stepfun:step-3.5-flash") == "step-3.5-flash"
def test_ollama_model_tag_preserved(self):
"""Ollama model:tag format must NOT be stripped."""

View file

@ -424,6 +424,68 @@ class TestQueryLocalContextLengthLmStudio:
)
class TestDetectLocalServerTypeAuth:
    """detect_local_server_type must forward the API key on its probe requests."""

    def test_passes_bearer_token_to_probe_requests(self):
        from agent.model_metadata import detect_local_server_type

        response = MagicMock()
        response.status_code = 200
        fake_client = MagicMock()
        fake_client.__enter__ = lambda s: fake_client
        fake_client.__exit__ = MagicMock(return_value=False)
        fake_client.get.return_value = response
        with patch("httpx.Client", return_value=fake_client) as client_cls:
            detected = detect_local_server_type("http://localhost:1234/v1", api_key="lm-token")
        assert detected == "lm-studio"
        # The bearer token must have been passed when the client was built.
        assert client_cls.call_args.kwargs["headers"] == {
            "Authorization": "Bearer lm-token"
        }
class TestFetchEndpointModelMetadataLmStudio:
    """fetch_endpoint_model_metadata should use LM Studio's native models endpoint."""

    @staticmethod
    def _make_resp(body):
        """Return a mock HTTP response whose .json() yields *body*."""
        resp = MagicMock()
        resp.raise_for_status.return_value = None
        resp.json.return_value = body
        return resp

    def test_uses_native_models_endpoint_only(self):
        from agent.model_metadata import fetch_endpoint_model_metadata

        model_key = "lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf"
        native_resp = self._make_resp(
            {
                "models": [
                    {
                        "key": model_key,
                        "id": model_key,
                        "max_context_length": 131072,
                    }
                ]
            }
        )
        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
             patch("agent.model_metadata.requests.get", return_value=native_resp) as mock_get:
            result = fetch_endpoint_model_metadata(
                "http://localhost:1234/v1",
                api_key="lm-token",
                force_refresh=True,
            )
        # Exactly one request, to the native LM Studio endpoint, with auth.
        assert mock_get.call_count == 1
        assert mock_get.call_args[0][0] == "http://localhost:1234/api/v1/models"
        assert mock_get.call_args.kwargs["headers"] == {
            "Authorization": "Bearer lm-token"
        }
        # Both the full key and the namespace-stripped alias must resolve.
        assert result[model_key]["context_length"] == 131072
        assert result["Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf"]["context_length"] == 131072
class TestQueryLocalContextLengthNetworkError:
"""_query_local_context_length handles network failures gracefully."""

View file

@ -82,6 +82,7 @@ class TestProviderMapping:
def test_known_providers_mapped(self):
    """Spot-check the provider → models.dev slug mapping."""
    expected = {
        "anthropic": "anthropic",
        "copilot": "github-copilot",
        "stepfun": "stepfun",
        "kilocode": "kilo",
        "ai-gateway": "vercel",
    }
    for provider, slug in expected.items():
        assert PROVIDER_TO_MODELS_DEV[provider] == slug

View file

@ -354,6 +354,24 @@ class TestBuildSkillsSystemPrompt:
assert "web-search" in result
assert "old-tool" not in result
def test_rebuilds_prompt_when_disabled_skills_change(self, monkeypatch, tmp_path):
    # The skills prompt is cached between builds; changing the
    # `skills.disabled` config must invalidate that cache so the next
    # build drops the now-disabled skill.
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    skill_dir = tmp_path / "skills" / "tools" / "cached-skill"
    skill_dir.mkdir(parents=True)
    (skill_dir / "SKILL.md").write_text(
        "---\nname: cached-skill\ndescription: Cached skill\n---\n"
    )
    first = build_skills_system_prompt()
    assert "cached-skill" in first
    # Disable the skill via config and rebuild.
    (tmp_path / "config.yaml").write_text(
        "skills:\n disabled: [cached-skill]\n"
    )
    second = build_skills_system_prompt()
    assert "cached-skill" not in second
def test_includes_setup_needed_skills(self, monkeypatch, tmp_path):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
monkeypatch.delenv("MISSING_API_KEY_XYZ", raising=False)
@ -771,6 +789,24 @@ class TestPromptBuilderConstants:
assert "cron" in PLATFORM_HINTS
assert "cli" in PLATFORM_HINTS
def test_cli_hint_does_not_suggest_media_tags(self):
    """Regression: MEDIA:/path tags are intercepted only by messaging
    gateway platforms; on the CLI they render as literal text and
    confuse users, so the CLI hint must steer the agent away from them."""
    hint = PLATFORM_HINTS["cli"]
    assert "MEDIA:" in hint, (
        "CLI hint should mention MEDIA: in order to tell the agent "
        "NOT to use it (negative guidance)."
    )
    # Must contain explicit "don't" language near the MEDIA reference.
    discouraging = ("do not emit media", "not intercepted", "do not", "don't")
    lowered = hint.lower()
    assert any(
        marker in lowered for marker in discouraging
    ), "CLI hint should explicitly discourage MEDIA: tags."
    # Messaging hints should still advertise MEDIA: positively (sanity
    # check that this test is calibrated correctly).
    assert "include MEDIA:" in PLATFORM_HINTS["telegram"]
# =========================================================================
# Environment hints

View file

@ -6,6 +6,8 @@ when proxy env vars or custom endpoint URLs are malformed.
"""
from __future__ import annotations
import os
import pytest
from agent.auxiliary_client import _validate_base_url, _validate_proxy_env_urls
@ -31,6 +33,12 @@ def test_proxy_env_accepts_empty(monkeypatch):
_validate_proxy_env_urls() # should not raise
def test_proxy_env_normalizes_socks_alias(monkeypatch):
    """The ``socks://`` alias should be rewritten to canonical ``socks5://``."""
    monkeypatch.setenv("ALL_PROXY", "socks://127.0.0.1:1080/")
    _validate_proxy_env_urls()  # should not raise
    normalized = os.environ["ALL_PROXY"]
    assert normalized == "socks5://127.0.0.1:1080/"
@pytest.mark.parametrize("key", [
"HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY",
"http_proxy", "https_proxy", "all_proxy",

View file

@ -376,3 +376,138 @@ class TestDiscordMentions:
result = redact_sensitive_text(text)
assert result.startswith("User ")
assert result.endswith(" said hello")
class TestUrlQueryParamRedaction:
    """URL query-string redaction (ported from nearai/ironclaw#2529).

    Matches on parameter NAME rather than value shape, so opaque tokens
    that slip past vendor prefix regexes are still caught.
    """

    def test_oauth_callback_code(self):
        redacted = redact_sensitive_text(
            "GET https://api.example.com/oauth/cb?code=abc123xyz789&state=csrf_ok"
        )
        assert "abc123xyz789" not in redacted
        assert "code=***" in redacted
        assert "state=csrf_ok" in redacted  # state is not sensitive

    def test_access_token_query(self):
        redacted = redact_sensitive_text(
            "Fetching https://example.com/api?access_token=opaque_value_here_1234&format=json"
        )
        assert "opaque_value_here_1234" not in redacted
        assert "access_token=***" in redacted
        assert "format=json" in redacted

    def test_refresh_token_query(self):
        redacted = redact_sensitive_text(
            "https://auth.example.com/token?refresh_token=somerefresh&grant_type=refresh"
        )
        assert "somerefresh" not in redacted
        assert "grant_type=refresh" in redacted

    def test_api_key_query(self):
        redacted = redact_sensitive_text(
            "https://api.example.com/v1/data?api_key=kABCDEF12345&limit=10"
        )
        assert "kABCDEF12345" not in redacted
        assert "limit=10" in redacted

    def test_presigned_signature(self):
        redacted = redact_sensitive_text(
            "https://s3.amazonaws.com/bucket/k?signature=LONG_PRESIGNED_SIG&id=public"
        )
        assert "LONG_PRESIGNED_SIG" not in redacted
        assert "id=public" in redacted

    def test_case_insensitive_param_names(self):
        """Lowercase/mixed-case sensitive param names are redacted."""
        # NOTE: All-caps names like TOKEN= are swallowed by _ENV_ASSIGN_RE
        # (which matches KEY=value patterns greedily) before URL regex runs.
        # This test uses lowercase names to isolate URL-query redaction.
        redacted = redact_sensitive_text(
            "https://example.com?api_key=abcdef&secret=ghijkl"
        )
        for leaked in ("abcdef", "ghijkl"):
            assert leaked not in redacted
        assert "api_key=***" in redacted
        assert "secret=***" in redacted

    def test_substring_match_does_not_trigger(self):
        """`token_count` and `session_id` must NOT match `token` / `session`."""
        redacted = redact_sensitive_text(
            "https://example.com/cb?token_count=42&session_id=xyz&foo=bar"
        )
        assert "token_count=42" in redacted
        assert "session_id=xyz" in redacted

    def test_url_without_query_unchanged(self):
        plain = "https://example.com/path/to/resource"
        assert redact_sensitive_text(plain) == plain

    def test_url_with_fragment(self):
        redacted = redact_sensitive_text("https://example.com/page?token=xyz#section")
        assert "token=xyz" not in redacted
        assert "#section" in redacted

    def test_websocket_url_query(self):
        redacted = redact_sensitive_text("wss://api.example.com/ws?token=opaqueWsToken123")
        assert "opaqueWsToken123" not in redacted
class TestUrlUserinfoRedaction:
    """URL userinfo (`scheme://user:pass@host`) for non-DB schemes."""

    def test_https_userinfo(self):
        redacted = redact_sensitive_text(
            "URL: https://user:supersecretpw@host.example.com/path"
        )
        assert "supersecretpw" not in redacted
        assert "https://user:***@host.example.com" in redacted

    def test_http_userinfo(self):
        redacted = redact_sensitive_text(
            "http://admin:plaintextpass@internal.example.com/api"
        )
        assert "plaintextpass" not in redacted

    def test_ftp_userinfo(self):
        redacted = redact_sensitive_text("ftp://user:ftppass@ftp.example.com/file.txt")
        assert "ftppass" not in redacted

    def test_url_without_userinfo_unchanged(self):
        plain = "https://example.com/path"
        assert redact_sensitive_text(plain) == plain

    def test_db_connstr_still_handled(self):
        """DB schemes are handled by _DB_CONNSTR_RE, not _URL_USERINFO_RE."""
        redacted = redact_sensitive_text("postgres://admin:dbpass@db.internal:5432/app")
        assert "dbpass" not in redacted
class TestFormBodyRedaction:
    """Form-urlencoded body redaction (k=v&k=v with no other text)."""

    def test_pure_form_body(self):
        redacted = redact_sensitive_text(
            "password=mysecret&username=bob&token=opaqueValue"
        )
        for leaked in ("mysecret", "opaqueValue"):
            assert leaked not in redacted
        assert "username=bob" in redacted

    def test_oauth_token_request(self):
        redacted = redact_sensitive_text(
            "grant_type=password&client_id=app&client_secret=topsecret&username=alice&password=alicepw"
        )
        for leaked in ("topsecret", "alicepw"):
            assert leaked not in redacted
        assert "client_id=app" in redacted

    def test_non_form_text_unchanged(self):
        """Sentences with `&` should NOT trigger form redaction."""
        redacted = redact_sensitive_text("I have password=foo and other things")
        # The space breaks the form regex; passthrough expected.
        assert "I have" in redacted

    def test_multiline_text_not_form(self):
        """Multi-line text is never treated as form body."""
        # Should pass through (still subject to other redactors)
        assert "first=1" in redact_sensitive_text("first=1\nsecond=2")

View file

@ -0,0 +1,716 @@
"""Tests for the shell-hooks subprocess bridge (agent.shell_hooks).
These tests focus on the pure translation layer JSON serialisation,
JSON parsing, matcher behaviour, block-schema correctness, and the
subprocess runner's graceful error handling. Consent prompts are
covered in ``test_shell_hooks_consent.py``.
"""
from __future__ import annotations
import json
import os
import stat
from pathlib import Path
from typing import Any, Dict
import pytest
from agent import shell_hooks
# ── helpers ───────────────────────────────────────────────────────────────
def _write_script(tmp_path: Path, name: str, body: str) -> Path:
path = tmp_path / name
path.write_text(body)
path.chmod(0o755)
return path
def _allowlist_pair(monkeypatch, tmp_path, event: str, command: str) -> None:
    """Point HERMES_HOME at a temp dir and pre-approve *event*/*command*."""
    home = tmp_path / "hermes_home"
    monkeypatch.setenv("HERMES_HOME", str(home))
    shell_hooks._record_approval(event, command)
@pytest.fixture(autouse=True)
def _reset_registration_state():
    """Guarantee each test starts and ends with pristine hook registration state."""
    shell_hooks.reset_for_tests()
    try:
        yield
    finally:
        shell_hooks.reset_for_tests()
# ── _parse_response ───────────────────────────────────────────────────────
class TestParseResponse:
    """_parse_response normalises raw hook stdout into canonical directives."""

    def test_block_claude_code_style(self):
        parsed = shell_hooks._parse_response(
            "pre_tool_call",
            '{"decision": "block", "reason": "nope"}',
        )
        assert parsed == {"action": "block", "message": "nope"}

    def test_block_canonical_style(self):
        parsed = shell_hooks._parse_response(
            "pre_tool_call",
            '{"action": "block", "message": "nope"}',
        )
        assert parsed == {"action": "block", "message": "nope"}

    def test_block_canonical_wins_over_claude_style(self):
        parsed = shell_hooks._parse_response(
            "pre_tool_call",
            '{"action": "block", "message": "canonical", '
            '"decision": "block", "reason": "claude"}',
        )
        assert parsed == {"action": "block", "message": "canonical"}

    def test_empty_stdout_returns_none(self):
        for stdout in ("", " "):
            assert shell_hooks._parse_response("pre_tool_call", stdout) is None

    def test_invalid_json_returns_none(self):
        assert shell_hooks._parse_response("pre_tool_call", "not json") is None

    def test_non_dict_json_returns_none(self):
        assert shell_hooks._parse_response("pre_tool_call", "[1, 2]") is None

    def test_non_block_pre_tool_call_returns_none(self):
        parsed = shell_hooks._parse_response("pre_tool_call", '{"decision": "allow"}')
        assert parsed is None

    def test_pre_llm_call_context_passthrough(self):
        parsed = shell_hooks._parse_response(
            "pre_llm_call", '{"context": "today is Friday"}',
        )
        assert parsed == {"context": "today is Friday"}

    def test_subagent_stop_context_passthrough(self):
        parsed = shell_hooks._parse_response(
            "subagent_stop", '{"context": "child role=leaf"}',
        )
        assert parsed == {"context": "child role=leaf"}

    def test_pre_llm_call_block_ignored(self):
        """Only pre_tool_call honors block directives."""
        parsed = shell_hooks._parse_response(
            "pre_llm_call", '{"decision": "block", "reason": "no"}',
        )
        assert parsed is None
# ── _serialize_payload ────────────────────────────────────────────────────
class TestSerializePayload:
    """_serialize_payload produces the wire-format JSON delivered to hook scripts."""

    def test_basic_pre_tool_call_schema(self):
        wire = shell_hooks._serialize_payload(
            "pre_tool_call",
            {
                "tool_name": "terminal",
                "args": {"command": "ls"},
                "session_id": "sess-1",
                "task_id": "t-1",
                "tool_call_id": "c-1",
            },
        )
        decoded = json.loads(wire)
        assert decoded["hook_event_name"] == "pre_tool_call"
        assert decoded["tool_name"] == "terminal"
        assert decoded["tool_input"] == {"command": "ls"}
        assert decoded["session_id"] == "sess-1"
        assert "cwd" in decoded
        # task_id / tool_call_id end up under extra
        assert decoded["extra"]["task_id"] == "t-1"
        assert decoded["extra"]["tool_call_id"] == "c-1"

    def test_args_not_dict_becomes_null(self):
        wire = shell_hooks._serialize_payload(
            "pre_tool_call", {"args": ["not", "a", "dict"]},
        )
        assert json.loads(wire)["tool_input"] is None

    def test_parent_session_id_used_when_no_session_id(self):
        wire = shell_hooks._serialize_payload(
            "subagent_stop", {"parent_session_id": "p-1"},
        )
        assert json.loads(wire)["session_id"] == "p-1"

    def test_unserialisable_extras_stringified(self):
        class Opaque:
            def __repr__(self) -> str:
                return "<weird>"

        wire = shell_hooks._serialize_payload(
            "on_session_start", {"obj": Opaque()},
        )
        assert json.loads(wire)["extra"]["obj"] == "<weird>"
# ── Matcher behaviour ─────────────────────────────────────────────────────
class TestMatcher:
    """ShellHookSpec.matches_tool regex/literal matching semantics."""

    def test_no_matcher_fires_for_any_tool(self):
        hook = shell_hooks.ShellHookSpec(
            event="pre_tool_call", command="echo", matcher=None,
        )
        for tool in ("terminal", "write_file"):
            assert hook.matches_tool(tool)

    def test_single_name_matcher(self):
        hook = shell_hooks.ShellHookSpec(
            event="pre_tool_call", command="echo", matcher="terminal",
        )
        assert hook.matches_tool("terminal")
        assert not hook.matches_tool("web_search")

    def test_alternation_matcher(self):
        hook = shell_hooks.ShellHookSpec(
            event="pre_tool_call", command="echo", matcher="terminal|file",
        )
        assert hook.matches_tool("terminal")
        assert hook.matches_tool("file")
        assert not hook.matches_tool("web")

    def test_invalid_regex_falls_back_to_literal(self):
        hook = shell_hooks.ShellHookSpec(
            event="pre_tool_call", command="echo", matcher="foo[bar",
        )
        assert hook.matches_tool("foo[bar")
        assert not hook.matches_tool("foo")

    def test_matcher_ignored_when_no_tool_name(self):
        hook = shell_hooks.ShellHookSpec(
            event="pre_tool_call", command="echo", matcher="terminal",
        )
        assert not hook.matches_tool(None)

    def test_matcher_leading_whitespace_stripped(self):
        """YAML quirks can introduce leading/trailing whitespace — must
        not silently break the matcher."""
        hook = shell_hooks.ShellHookSpec(
            event="pre_tool_call", command="echo", matcher=" terminal ",
        )
        assert hook.matcher == "terminal"
        assert hook.matches_tool("terminal")

    def test_matcher_trailing_newline_stripped(self):
        hook = shell_hooks.ShellHookSpec(
            event="pre_tool_call", command="echo", matcher="terminal\n",
        )
        assert hook.matches_tool("terminal")

    def test_whitespace_only_matcher_becomes_none(self):
        """A matcher that's pure whitespace is treated as 'no matcher'."""
        hook = shell_hooks.ShellHookSpec(
            event="pre_tool_call", command="echo", matcher=" ",
        )
        assert hook.matcher is None
        assert hook.matches_tool("anything")
# ── End-to-end subprocess behaviour ───────────────────────────────────────
class TestCallbackSubprocess:
    """End-to-end behaviour of the subprocess bridge: real scripts, real exec.

    All failure modes (timeout, bad JSON, missing/non-executable binary)
    must degrade to ``None`` so a broken hook never breaks the agent loop.
    """

    def test_timeout_returns_none(self, tmp_path):
        # Script that sleeps forever; we set a 1s timeout.
        script = _write_script(
            tmp_path, "slow.sh",
            "#!/usr/bin/env bash\nsleep 60\n",
        )
        spec = shell_hooks.ShellHookSpec(
            event="post_tool_call", command=str(script), timeout=1,
        )
        cb = shell_hooks._make_callback(spec)
        # Timeout is swallowed, not raised.
        assert cb(tool_name="terminal") is None

    def test_malformed_json_stdout_returns_none(self, tmp_path):
        script = _write_script(
            tmp_path, "bad_json.sh",
            "#!/usr/bin/env bash\necho 'not json at all'\n",
        )
        spec = shell_hooks.ShellHookSpec(
            event="pre_tool_call", command=str(script),
        )
        cb = shell_hooks._make_callback(spec)
        # Matcher is None so the callback fires for any tool.
        assert cb(tool_name="terminal") is None

    def test_non_zero_exit_with_block_stdout_still_blocks(self, tmp_path):
        """A script that signals failure via exit code AND prints a block
        directive must still block — scripts should be free to mix exit
        codes with parseable output."""
        script = _write_script(
            tmp_path, "exit1_block.sh",
            "#!/usr/bin/env bash\n"
            'printf \'{"decision": "block", "reason": "via exit 1"}\\n\'\n'
            "exit 1\n",
        )
        spec = shell_hooks.ShellHookSpec(
            event="pre_tool_call", command=str(script),
        )
        cb = shell_hooks._make_callback(spec)
        assert cb(tool_name="terminal") == {"action": "block", "message": "via exit 1"}

    def test_block_translation_end_to_end(self, tmp_path):
        """v1 schema-bug regression gate.

        Shell hook returns the Claude-Code-style payload and the bridge
        must translate it to the canonical Hermes block shape so that
        get_pre_tool_call_block_message() surfaces the block.
        """
        script = _write_script(
            tmp_path, "blocker.sh",
            "#!/usr/bin/env bash\n"
            'printf \'{"decision": "block", "reason": "no terminal"}\\n\'\n',
        )
        spec = shell_hooks.ShellHookSpec(
            event="pre_tool_call",
            command=str(script),
            matcher="terminal",
        )
        cb = shell_hooks._make_callback(spec)
        result = cb(tool_name="terminal", args={"command": "rm -rf /"})
        assert result == {"action": "block", "message": "no terminal"}

    def test_block_aggregation_through_plugin_manager(self, tmp_path, monkeypatch):
        """Registering via register_from_config makes
        get_pre_tool_call_block_message surface the block — the real
        end-to-end control flow used by run_agent._invoke_tool."""
        from hermes_cli import plugins

        script = _write_script(
            tmp_path, "block.sh",
            "#!/usr/bin/env bash\n"
            'printf \'{"decision": "block", "reason": "blocked-by-shell"}\\n\'\n',
        )
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))
        monkeypatch.setenv("HERMES_ACCEPT_HOOKS", "1")
        # Fresh manager
        plugins._plugin_manager = plugins.PluginManager()
        cfg = {
            "hooks": {
                "pre_tool_call": [
                    {"matcher": "terminal", "command": str(script)},
                ],
            },
        }
        registered = shell_hooks.register_from_config(cfg, accept_hooks=True)
        assert len(registered) == 1
        msg = plugins.get_pre_tool_call_block_message(
            tool_name="terminal",
            args={"command": "rm"},
        )
        assert msg == "blocked-by-shell"

    def test_matcher_regex_filters_callback(self, tmp_path, monkeypatch):
        """A matcher set to 'terminal' must not fire for 'web_search'."""
        calls = tmp_path / "calls.log"
        script = _write_script(
            tmp_path, "log.sh",
            f"#!/usr/bin/env bash\n"
            f"echo \"$(cat -)\" >> {calls}\n"
            f"printf '{{}}\\n'\n",
        )
        spec = shell_hooks.ShellHookSpec(
            event="pre_tool_call",
            command=str(script),
            matcher="terminal",
        )
        cb = shell_hooks._make_callback(spec)
        cb(tool_name="terminal", args={"command": "ls"})
        cb(tool_name="web_search", args={"q": "x"})
        cb(tool_name="file_read", args={"path": "x"})
        assert calls.exists()
        # Only the terminal call wrote to the log
        assert calls.read_text().count("pre_tool_call") == 1

    def test_payload_schema_delivered(self, tmp_path):
        # Capture the exact JSON the script receives on stdin.
        capture = tmp_path / "payload.json"
        script = _write_script(
            tmp_path, "capture.sh",
            f"#!/usr/bin/env bash\ncat - > {capture}\nprintf '{{}}\\n'\n",
        )
        spec = shell_hooks.ShellHookSpec(
            event="pre_tool_call", command=str(script),
        )
        cb = shell_hooks._make_callback(spec)
        cb(
            tool_name="terminal",
            args={"command": "echo hi"},
            session_id="sess-77",
            task_id="task-77",
        )
        payload = json.loads(capture.read_text())
        assert payload["hook_event_name"] == "pre_tool_call"
        assert payload["tool_name"] == "terminal"
        assert payload["tool_input"] == {"command": "echo hi"}
        assert payload["session_id"] == "sess-77"
        assert "cwd" in payload
        assert payload["extra"]["task_id"] == "task-77"

    def test_pre_llm_call_context_flows_through(self, tmp_path):
        script = _write_script(
            tmp_path, "ctx.sh",
            "#!/usr/bin/env bash\n"
            'printf \'{"context": "env-note"}\\n\'\n',
        )
        spec = shell_hooks.ShellHookSpec(
            event="pre_llm_call", command=str(script),
        )
        cb = shell_hooks._make_callback(spec)
        result = cb(
            session_id="s1", user_message="hello",
            conversation_history=[], is_first_turn=True,
            model="gpt-4", platform="cli",
        )
        assert result == {"context": "env-note"}

    def test_shlex_handles_paths_with_spaces(self, tmp_path):
        dir_with_space = tmp_path / "path with space"
        dir_with_space.mkdir()
        script = _write_script(
            dir_with_space, "ok.sh",
            "#!/usr/bin/env bash\nprintf '{}\\n'\n",
        )
        # Quote the path so shlex keeps it as a single token.
        spec = shell_hooks.ShellHookSpec(
            event="post_tool_call",
            command=f'"{script}"',
        )
        cb = shell_hooks._make_callback(spec)
        # No crash = shlex parsed it correctly.
        assert cb(tool_name="terminal") is None  # empty object parses to None

    def test_missing_binary_logged_not_raised(self, tmp_path):
        spec = shell_hooks.ShellHookSpec(
            event="on_session_start",
            command=str(tmp_path / "does-not-exist"),
        )
        cb = shell_hooks._make_callback(spec)
        # Must not raise — agent loop should continue.
        assert cb(session_id="s") is None

    def test_non_executable_binary_logged_not_raised(self, tmp_path):
        path = tmp_path / "no-exec"
        path.write_text("#!/usr/bin/env bash\necho hi\n")
        # Intentionally do NOT chmod +x.
        spec = shell_hooks.ShellHookSpec(
            event="on_session_start", command=str(path),
        )
        cb = shell_hooks._make_callback(spec)
        assert cb(session_id="s") is None
# ── config parsing ────────────────────────────────────────────────────────
class TestParseHooksBlock:
    """_parse_hooks_block validation: event names, commands, timeouts, matchers."""

    def test_valid_entry(self):
        parsed = shell_hooks._parse_hooks_block({
            "pre_tool_call": [
                {"matcher": "terminal", "command": "/tmp/hook.sh", "timeout": 30},
            ],
        })
        assert len(parsed) == 1
        entry = parsed[0]
        assert entry.event == "pre_tool_call"
        assert entry.matcher == "terminal"
        assert entry.command == "/tmp/hook.sh"
        assert entry.timeout == 30

    def test_unknown_event_skipped(self, caplog):
        parsed = shell_hooks._parse_hooks_block({
            "pre_tools_call": [  # typo
                {"command": "/tmp/hook.sh"},
            ],
        })
        assert parsed == []

    def test_missing_command_skipped(self):
        parsed = shell_hooks._parse_hooks_block({
            "pre_tool_call": [{"matcher": "terminal"}],
        })
        assert parsed == []

    def test_timeout_clamped_to_max(self):
        parsed = shell_hooks._parse_hooks_block({
            "post_tool_call": [
                {"command": "/tmp/slow.sh", "timeout": 9999},
            ],
        })
        assert parsed[0].timeout == shell_hooks.MAX_TIMEOUT_SECONDS

    def test_non_int_timeout_defaulted(self):
        parsed = shell_hooks._parse_hooks_block({
            "post_tool_call": [
                {"command": "/tmp/x.sh", "timeout": "thirty"},
            ],
        })
        assert parsed[0].timeout == shell_hooks.DEFAULT_TIMEOUT_SECONDS

    def test_non_list_event_skipped(self):
        parsed = shell_hooks._parse_hooks_block({
            "pre_tool_call": "not a list",
        })
        assert parsed == []

    def test_none_hooks_block(self):
        for bad in (None, "string", []):
            assert shell_hooks._parse_hooks_block(bad) == []

    def test_non_tool_event_matcher_warns_and_drops(self, caplog):
        """matcher: is only honored for pre/post_tool_call; must warn
        and drop on other events so the spec reflects runtime."""
        import logging

        cfg = {"pre_llm_call": [{"matcher": "terminal", "command": "/bin/echo"}]}
        with caplog.at_level(logging.WARNING, logger=shell_hooks.logger.name):
            parsed = shell_hooks._parse_hooks_block(cfg)
        assert len(parsed) == 1 and parsed[0].matcher is None
        assert any(
            "only honored for pre_tool_call" in record.getMessage()
            and "pre_llm_call" in record.getMessage()
            for record in caplog.records
        )
# ── Idempotent registration ───────────────────────────────────────────────
class TestIdempotentRegistration:
    """Re-registering the same hooks config must not duplicate callbacks."""
    def test_double_call_registers_once(self, tmp_path, monkeypatch):
        """The second register_from_config call for an identical config is a no-op."""
        from hermes_cli import plugins
        hook = _write_script(tmp_path, "h.sh",
                             "#!/usr/bin/env bash\nprintf '{}\\n'\n")
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))
        monkeypatch.setenv("HERMES_ACCEPT_HOOKS", "1")
        plugins._plugin_manager = plugins.PluginManager()
        config = {"hooks": {"on_session_start": [{"command": str(hook)}]}}
        first_pass = shell_hooks.register_from_config(config, accept_hooks=True)
        second_pass = shell_hooks.register_from_config(config, accept_hooks=True)
        assert len(first_pass) == 1
        assert second_pass == []
        # The manager must hold exactly one callback for the event.
        manager = plugins.get_plugin_manager()
        assert len(manager._hooks.get("on_session_start", [])) == 1
    def test_same_command_different_matcher_registers_both(
        self, tmp_path, monkeypatch,
    ):
        """Same script used for different matchers under one event must
        register both callbacks dedupe keys on (event, matcher, command)."""
        from hermes_cli import plugins
        hook = _write_script(tmp_path, "h.sh",
                             "#!/usr/bin/env bash\nprintf '{}\\n'\n")
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))
        monkeypatch.setenv("HERMES_ACCEPT_HOOKS", "1")
        plugins._plugin_manager = plugins.PluginManager()
        config = {
            "hooks": {
                "pre_tool_call": [
                    {"matcher": "terminal", "command": str(hook)},
                    {"matcher": "web_search", "command": str(hook)},
                ],
            },
        }
        registered = shell_hooks.register_from_config(config, accept_hooks=True)
        assert len(registered) == 2
        manager = plugins.get_plugin_manager()
        assert len(manager._hooks.get("pre_tool_call", [])) == 2
# ── Allowlist concurrency ─────────────────────────────────────────────────
class TestAllowlistConcurrency:
    """Regression tests for the Codex#1 finding: simultaneous
    _record_approval() calls used to collide on a fixed tmp path and
    silently lose entries under read-modify-write races."""
    def test_parallel_record_approval_does_not_lose_entries(
        self, tmp_path, monkeypatch,
    ):
        # N threads rendezvous on a barrier, then record distinct
        # approvals simultaneously; every entry must survive the race.
        import threading
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))
        N = 32
        barrier = threading.Barrier(N)
        errors: list = []
        def worker(i: int) -> None:
            try:
                # Barrier maximizes overlap of the read-modify-write windows.
                barrier.wait(timeout=5)
                shell_hooks._record_approval(
                    "on_session_start", f"/bin/hook-{i}.sh",
                )
            except Exception as exc:  # pragma: no cover
                errors.append(exc)
        threads = [threading.Thread(target=worker, args=(i,)) for i in range(N)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        assert not errors, f"worker errors: {errors}"
        data = shell_hooks.load_allowlist()
        commands = {e["command"] for e in data["approvals"]}
        assert commands == {f"/bin/hook-{i}.sh" for i in range(N)}, (
            f"expected all {N} entries, got {len(commands)}"
        )
    def test_non_posix_fallback_does_not_self_deadlock(
        self, tmp_path, monkeypatch,
    ):
        """Regression: on platforms without fcntl, the fallback lock must
        be separate from _registered_lock. register_from_config holds
        _registered_lock while calling _record_approval (via the consent
        prompt path), so a shared non-reentrant lock would self-deadlock."""
        import threading
        # Simulate a non-POSIX platform by nulling out fcntl on the module.
        monkeypatch.setattr(shell_hooks, "fcntl", None)
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))
        completed = threading.Event()
        errors: list = []
        def target() -> None:
            try:
                # Re-create the runtime nesting: _registered_lock held
                # while _record_approval acquires its own lock.
                with shell_hooks._registered_lock:
                    shell_hooks._record_approval(
                        "on_session_start", "/bin/x.sh",
                    )
                completed.set()
            except Exception as exc:  # pragma: no cover
                errors.append(exc)
                completed.set()
        # Daemon thread + timeout keep a real deadlock from hanging pytest.
        t = threading.Thread(target=target, daemon=True)
        t.start()
        if not completed.wait(timeout=3.0):
            pytest.fail(
                "non-POSIX fallback self-deadlocked — "
                "_locked_update_approvals must not reuse _registered_lock",
            )
        t.join(timeout=1.0)
        assert not errors, f"errors: {errors}"
        assert shell_hooks._is_allowlisted(
            "on_session_start", "/bin/x.sh",
        )
    def test_save_allowlist_failure_logs_actionable_warning(
        self, tmp_path, monkeypatch, caplog,
    ):
        """Persistence failures must log the path, errno, and
        re-prompt consequence so "hermes keeps asking" is debuggable."""
        import logging
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))
        # Force mkstemp to fail with errno 28 (ENOSPC) before any file exists.
        monkeypatch.setattr(
            shell_hooks.tempfile, "mkstemp",
            lambda *a, **kw: (_ for _ in ()).throw(OSError(28, "No space")),
        )
        with caplog.at_level(logging.WARNING, logger=shell_hooks.logger.name):
            shell_hooks.save_allowlist({"approvals": []})
        # Pick out the persistence warning, "" if absent so asserts fail cleanly.
        msg = next(
            (r.getMessage() for r in caplog.records
             if "Failed to persist" in r.getMessage()), "",
        )
        assert "shell-hooks-allowlist.json" in msg
        assert "No space" in msg
        assert "re-prompt" in msg
    def test_script_is_executable_handles_interpreter_prefix(self, tmp_path):
        """For ``python3 hook.py`` and similar the interpreter reads
        the script, so X_OK on the script itself is not required
        only R_OK. Bare invocations still require X_OK."""
        script = tmp_path / "hook.py"
        script.write_text("print()\n")  # readable, NOT executable
        # Interpreter prefix: R_OK is enough.
        assert shell_hooks.script_is_executable(f"python3 {script}")
        assert shell_hooks.script_is_executable(f"/usr/bin/env python3 {script}")
        # Bare invocation on the same non-X_OK file: not runnable.
        assert not shell_hooks.script_is_executable(str(script))
        # Flip +x; bare invocation is now runnable too.
        script.chmod(0o755)
        assert shell_hooks.script_is_executable(str(script))
    def test_command_script_path_resolution(self):
        """Regression: ``_command_script_path`` used to return the first
        shlex token, which picked the interpreter (``python3``, ``bash``,
        ``/usr/bin/env``) instead of the actual script for any
        interpreter-prefixed command. That broke
        ``hermes hooks doctor``'s executability check and silently
        disabled mtime drift detection for such hooks."""
        # (command, expected script path) pairs exercising every branch.
        cases = [
            # bare path
            ("/path/hook.sh", "/path/hook.sh"),
            ("/bin/echo hi", "/bin/echo"),
            ("~/hook.sh", "~/hook.sh"),
            ("hook.sh", "hook.sh"),
            # interpreter prefix
            ("python3 /path/hook.py", "/path/hook.py"),
            ("bash /path/hook.sh", "/path/hook.sh"),
            ("bash ~/hook.sh", "~/hook.sh"),
            ("python3 -u /path/hook.py", "/path/hook.py"),
            ("nice -n 10 /path/hook.sh", "/path/hook.sh"),
            # /usr/bin/env shebang form — must find the *script*, not env
            ("/usr/bin/env python3 /path/hook.py", "/path/hook.py"),
            ("/usr/bin/env bash /path/hook.sh", "/path/hook.sh"),
            # no path-like tokens → fallback to first token
            ("my-binary --verbose", "my-binary"),
            ("python3 -c 'print(1)'", "python3"),
            # unparseable (unbalanced quotes) → return command as-is
            ("python3 'unterminated", "python3 'unterminated"),
            # empty
            ("", ""),
        ]
        for command, expected in cases:
            got = shell_hooks._command_script_path(command)
            assert got == expected, f"{command!r} -> {got!r}, expected {expected!r}"
    def test_save_allowlist_uses_unique_tmp_paths(self, tmp_path, monkeypatch):
        """Two save_allowlist calls in flight must use distinct tmp files
        so the loser's os.replace does not ENOENT on the winner's sweep."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))
        p = shell_hooks.allowlist_path()
        p.parent.mkdir(parents=True, exist_ok=True)
        tmp_paths_seen: list = []
        real_mkstemp = shell_hooks.tempfile.mkstemp
        # Spy wrapper: record every tmp path mkstemp hands out.
        def spying_mkstemp(*args, **kwargs):
            fd, path = real_mkstemp(*args, **kwargs)
            tmp_paths_seen.append(path)
            return fd, path
        monkeypatch.setattr(shell_hooks.tempfile, "mkstemp", spying_mkstemp)
        shell_hooks.save_allowlist({"approvals": [{"event": "a", "command": "x"}]})
        shell_hooks.save_allowlist({"approvals": [{"event": "b", "command": "y"}]})
        assert len(tmp_paths_seen) == 2
        assert tmp_paths_seen[0] != tmp_paths_seen[1]

View file

@ -0,0 +1,242 @@
"""Consent-flow tests for the shell-hook allowlist.
Covers the prompt/non-prompt decision tree: TTY vs non-TTY, and the
three accept-hooks channels (--accept-hooks, HERMES_ACCEPT_HOOKS env,
hooks_auto_accept: config key).
"""
from __future__ import annotations
import json
from pathlib import Path
from unittest.mock import patch
import pytest
from agent import shell_hooks
@pytest.fixture(autouse=True)
def _isolated_home(tmp_path, monkeypatch):
    """Point HERMES_HOME at a throwaway dir and reset hook state around
    each test so approvals never leak between tests."""
    isolated_dir = tmp_path / "hermes_home"
    monkeypatch.setenv("HERMES_HOME", str(isolated_dir))
    # Make sure the env-based auto-accept channel is off by default.
    monkeypatch.delenv("HERMES_ACCEPT_HOOKS", raising=False)
    shell_hooks.reset_for_tests()
    yield
    shell_hooks.reset_for_tests()
def _write_hook_script(tmp_path: Path) -> Path:
    """Create an executable no-op hook script under *tmp_path* and return it."""
    hook = tmp_path / "hook.sh"
    hook.write_text("#!/usr/bin/env bash\nprintf '{}\\n'\n")
    # Mark it runnable so the executability checks pass.
    hook.chmod(0o755)
    return hook
# ── TTY prompt flow ───────────────────────────────────────────────────────
class TestTTYPromptFlow:
    """Interactive consent: with a TTY, first use prompts; the answer is
    persisted so later runs stay silent."""
    def test_first_use_prompts_and_approves(self, tmp_path):
        from hermes_cli import plugins
        script = _write_hook_script(tmp_path)
        plugins._plugin_manager = plugins.PluginManager()
        # Fake a TTY and answer "y" at the consent prompt.
        with patch("sys.stdin") as mock_stdin, patch("builtins.input", return_value="y"):
            mock_stdin.isatty.return_value = True
            registered = shell_hooks.register_from_config(
                {"hooks": {"on_session_start": [{"command": str(script)}]}},
                accept_hooks=False,
            )
        assert len(registered) == 1
        # Approval must be persisted to the allowlist.
        entry = shell_hooks.allowlist_entry_for("on_session_start", str(script))
        assert entry is not None
        assert entry["event"] == "on_session_start"
        assert entry["command"] == str(script)
    def test_first_use_prompts_and_rejects(self, tmp_path):
        from hermes_cli import plugins
        script = _write_hook_script(tmp_path)
        plugins._plugin_manager = plugins.PluginManager()
        # Same setup, but the user answers "n": nothing registers or persists.
        with patch("sys.stdin") as mock_stdin, patch("builtins.input", return_value="n"):
            mock_stdin.isatty.return_value = True
            registered = shell_hooks.register_from_config(
                {"hooks": {"on_session_start": [{"command": str(script)}]}},
                accept_hooks=False,
            )
        assert registered == []
        assert shell_hooks.allowlist_entry_for(
            "on_session_start", str(script),
        ) is None
    def test_subsequent_use_does_not_prompt(self, tmp_path):
        """After the first approval, re-registration must be silent."""
        from hermes_cli import plugins
        script = _write_hook_script(tmp_path)
        plugins._plugin_manager = plugins.PluginManager()
        # First call: TTY, approved.
        with patch("sys.stdin") as mock_stdin, patch("builtins.input", return_value="y"):
            mock_stdin.isatty.return_value = True
            shell_hooks.register_from_config(
                {"hooks": {"on_session_start": [{"command": str(script)}]}},
                accept_hooks=False,
            )
        # Reset registration set but keep the allowlist on disk.
        shell_hooks.reset_for_tests()
        # Second call: TTY, input() must NOT be called.
        with patch("sys.stdin") as mock_stdin, patch(
            "builtins.input", side_effect=AssertionError("should not prompt"),
        ):
            mock_stdin.isatty.return_value = True
            registered = shell_hooks.register_from_config(
                {"hooks": {"on_session_start": [{"command": str(script)}]}},
                accept_hooks=False,
            )
        assert len(registered) == 1
# ── non-TTY flow ──────────────────────────────────────────────────────────
class TestNonTTYFlow:
    """Without a TTY there is nobody to prompt: registration is skipped
    unless one of the three accept channels (flag, env, config) is on."""
    def test_no_tty_no_flag_skips_registration(self, tmp_path):
        from hermes_cli import plugins
        hook = _write_hook_script(tmp_path)
        plugins._plugin_manager = plugins.PluginManager()
        with patch("sys.stdin") as fake_stdin:
            fake_stdin.isatty.return_value = False
            result = shell_hooks.register_from_config(
                {"hooks": {"on_session_start": [{"command": str(hook)}]}},
                accept_hooks=False,
            )
        assert result == []
    def test_no_tty_with_argument_flag_accepts(self, tmp_path):
        from hermes_cli import plugins
        hook = _write_hook_script(tmp_path)
        plugins._plugin_manager = plugins.PluginManager()
        with patch("sys.stdin") as fake_stdin:
            fake_stdin.isatty.return_value = False
            # accept_hooks=True is the --accept-hooks CLI channel.
            result = shell_hooks.register_from_config(
                {"hooks": {"on_session_start": [{"command": str(hook)}]}},
                accept_hooks=True,
            )
        assert len(result) == 1
    def test_no_tty_with_env_accepts(self, tmp_path, monkeypatch):
        from hermes_cli import plugins
        hook = _write_hook_script(tmp_path)
        plugins._plugin_manager = plugins.PluginManager()
        # Env channel: HERMES_ACCEPT_HOOKS=1 substitutes for the prompt.
        monkeypatch.setenv("HERMES_ACCEPT_HOOKS", "1")
        with patch("sys.stdin") as fake_stdin:
            fake_stdin.isatty.return_value = False
            result = shell_hooks.register_from_config(
                {"hooks": {"on_session_start": [{"command": str(hook)}]}},
                accept_hooks=False,
            )
        assert len(result) == 1
    def test_no_tty_with_config_accepts(self, tmp_path):
        from hermes_cli import plugins
        hook = _write_hook_script(tmp_path)
        plugins._plugin_manager = plugins.PluginManager()
        with patch("sys.stdin") as fake_stdin:
            fake_stdin.isatty.return_value = False
            # Config channel: hooks_auto_accept in the same config dict.
            result = shell_hooks.register_from_config(
                {
                    "hooks_auto_accept": True,
                    "hooks": {"on_session_start": [{"command": str(hook)}]},
                },
                accept_hooks=False,
            )
        assert len(result) == 1
# ── Allowlist + revoke + mtime ────────────────────────────────────────────
class TestAllowlistOps:
    """Allowlist CRUD: approval snapshots the script mtime, revoke removes
    entries, and duplicates collapse to one entry per (event, command)."""
    def test_mtime_recorded_on_approval(self, tmp_path):
        """Approving a hook records the script's mtime as an ISO-8601 string."""
        script = _write_hook_script(tmp_path)
        shell_hooks._record_approval("on_session_start", str(script))
        entry = shell_hooks.allowlist_entry_for(
            "on_session_start", str(script),
        )
        assert entry is not None
        assert entry["script_mtime_at_approval"] is not None
        # ISO-8601 Z-suffix
        assert entry["script_mtime_at_approval"].endswith("Z")
    def test_revoke_removes_entry(self, tmp_path):
        """revoke() deletes the matching entry and reports the removed count."""
        script = _write_hook_script(tmp_path)
        shell_hooks._record_approval("on_session_start", str(script))
        assert shell_hooks.allowlist_entry_for(
            "on_session_start", str(script),
        ) is not None
        removed = shell_hooks.revoke(str(script))
        assert removed == 1
        assert shell_hooks.allowlist_entry_for(
            "on_session_start", str(script),
        ) is None
    def test_revoke_unknown_returns_zero(self, tmp_path):
        """Revoking a never-approved command is a harmless no-op."""
        assert shell_hooks.revoke(str(tmp_path / "never-approved.sh")) == 0
    def test_tilde_path_approval_records_resolvable_mtime(self, tmp_path, monkeypatch):
        """If the command uses ~ the approval must still find the file."""
        monkeypatch.setenv("HOME", str(tmp_path))
        target = tmp_path / "hook.sh"
        target.write_text("#!/usr/bin/env bash\n")
        target.chmod(0o755)
        shell_hooks._record_approval("on_session_start", "~/hook.sh")
        entry = shell_hooks.allowlist_entry_for(
            "on_session_start", "~/hook.sh",
        )
        assert entry is not None
        # Must not be None — the tilde was expanded before stat().
        assert entry["script_mtime_at_approval"] is not None
    def test_duplicate_approval_replaces_mtime(self, tmp_path):
        """Re-approving the same (event, command) pair replaces the entry
        instead of appending a duplicate.

        Fix: the previous version captured the old mtime into an unused
        local (``new_mtime``); dropped. Comparing old vs new ISO strings
        directly would be flaky if the recorded timestamp's resolution is
        coarser than the 10 ms sleep below.
        """
        script = _write_hook_script(tmp_path)
        shell_hooks._record_approval("on_session_start", str(script))
        original_entry = shell_hooks.allowlist_entry_for(
            "on_session_start", str(script),
        )
        assert original_entry is not None
        # Touch the script to bump its mtime then re-approve.
        import os
        import time
        time.sleep(0.01)
        os.utime(script, None)  # current time
        shell_hooks._record_approval("on_session_start", str(script))
        # Exactly one entry per (event, command).
        approvals = shell_hooks.load_allowlist().get("approvals", [])
        matching = [
            e for e in approvals
            if e.get("event") == "on_session_start"
            and e.get("command") == str(script)
        ]
        assert len(matching) == 1

View file

@ -405,3 +405,191 @@ class TestPlanSkillHelpers:
assert "Add a /plan command" in msg
assert ".hermes/plans/plan.md" in msg
assert "Runtime note:" in msg
class TestSkillDirectoryHeader:
    """The activation message must expose the absolute skill directory and
    explain how to resolve relative paths, so skills with bundled scripts
    don't force the agent into a second ``skill_view()`` round-trip."""
    def test_header_contains_absolute_skill_dir(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            skill_dir = _make_skill(tmp_path, "abs-dir-skill")
            scan_skill_commands()
            msg = build_skill_invocation_message("/abs-dir-skill", "go")
            assert msg is not None
            assert f"[Skill directory: {skill_dir}]" in msg
            assert "Resolve any relative paths" in msg
    def test_supporting_files_shown_with_absolute_paths(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            skill_dir = _make_skill(tmp_path, "scripted-skill")
            (skill_dir / "scripts").mkdir()
            (skill_dir / "scripts" / "run.js").write_text("console.log('hi')")
            scan_skill_commands()
            msg = build_skill_invocation_message("/scripted-skill")
            assert msg is not None
            # The supporting-files block must emit both the relative form (so the
            # agent can call skill_view on it) and the absolute form (so it can
            # run the script directly via terminal).
            assert "scripts/run.js" in msg
            assert str(skill_dir / "scripts" / "run.js") in msg
            # NOTE(review): the file created above is scripts/run.js, yet this
            # asserts scripts/foo.js — presumably _make_skill's default body
            # references foo.js; confirm, otherwise this looks like a
            # copy-paste slip from the template-substitution tests.
            assert f"node {skill_dir}/scripts/foo.js" in msg
class TestTemplateVarSubstitution:
    """``${HERMES_SKILL_DIR}`` and ``${HERMES_SESSION_ID}`` in SKILL.md body
    are replaced before the agent sees the content."""
    def test_substitutes_skill_dir(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            skill_dir = _make_skill(
                tmp_path,
                "templated",
                body="Run: node ${HERMES_SKILL_DIR}/scripts/foo.js",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message("/templated")
            assert msg is not None
            assert f"node {skill_dir}/scripts/foo.js" in msg
            # The literal template token must not leak through.
            # Inspect only the text before the "[Skill directory:" header.
            assert "${HERMES_SKILL_DIR}" not in msg.split("[Skill directory:")[0]
    def test_substitutes_session_id_when_available(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            _make_skill(
                tmp_path,
                "sess-templated",
                body="Session: ${HERMES_SESSION_ID}",
            )
            scan_skill_commands()
            # task_id supplies the session id for substitution.
            msg = build_skill_invocation_message(
                "/sess-templated", task_id="abc-123"
            )
            assert msg is not None
            assert "Session: abc-123" in msg
    def test_leaves_session_id_token_when_missing(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            _make_skill(
                tmp_path,
                "sess-missing",
                body="Session: ${HERMES_SESSION_ID}",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message("/sess-missing", task_id=None)
            assert msg is not None
            # No session — token left intact so the author can spot it.
            assert "Session: ${HERMES_SESSION_ID}" in msg
    def test_disable_template_vars_via_config(self, tmp_path):
        # template_vars=False in the skills config turns substitution off.
        with (
            patch("tools.skills_tool.SKILLS_DIR", tmp_path),
            patch(
                "agent.skill_commands._load_skills_config",
                return_value={"template_vars": False},
            ),
        ):
            _make_skill(
                tmp_path,
                "no-sub",
                body="Run: node ${HERMES_SKILL_DIR}/scripts/foo.js",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message("/no-sub")
            assert msg is not None
            # Template token must survive when substitution is disabled.
            assert "${HERMES_SKILL_DIR}/scripts/foo.js" in msg
class TestInlineShellExpansion:
    """Inline ``!`cmd`` snippets in SKILL.md run before the agent sees the
    content but only when the user has opted in via config."""
    def test_inline_shell_is_off_by_default(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            _make_skill(
                tmp_path,
                "dyn-default-off",
                body="Today is !`echo INLINE_RAN`.",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message("/dyn-default-off")
            assert msg is not None
            # Default config has inline_shell=False — snippet must stay literal.
            assert "!`echo INLINE_RAN`" in msg
            assert "Today is INLINE_RAN." not in msg
    def test_inline_shell_runs_when_enabled(self, tmp_path):
        # Opt in via config: snippet output replaces the snippet text.
        with (
            patch("tools.skills_tool.SKILLS_DIR", tmp_path),
            patch(
                "agent.skill_commands._load_skills_config",
                return_value={"template_vars": True, "inline_shell": True,
                              "inline_shell_timeout": 5},
            ),
        ):
            _make_skill(
                tmp_path,
                "dyn-on",
                body="Marker: !`echo INLINE_RAN`.",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message("/dyn-on")
            assert msg is not None
            assert "Marker: INLINE_RAN." in msg
            assert "!`echo INLINE_RAN`" not in msg
    def test_inline_shell_runs_in_skill_directory(self, tmp_path):
        """Inline snippets get the skill dir as CWD so relative paths work."""
        with (
            patch("tools.skills_tool.SKILLS_DIR", tmp_path),
            patch(
                "agent.skill_commands._load_skills_config",
                return_value={"template_vars": True, "inline_shell": True,
                              "inline_shell_timeout": 5},
            ),
        ):
            skill_dir = _make_skill(
                tmp_path,
                "dyn-cwd",
                body="Here: !`pwd`",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message("/dyn-cwd")
            assert msg is not None
            # `pwd` output must be the skill's own directory.
            assert f"Here: {skill_dir}" in msg
    def test_inline_shell_timeout_does_not_break_message(self, tmp_path):
        # 1 s timeout against a 5 s sleep: the snippet must time out.
        with (
            patch("tools.skills_tool.SKILLS_DIR", tmp_path),
            patch(
                "agent.skill_commands._load_skills_config",
                return_value={"template_vars": True, "inline_shell": True,
                              "inline_shell_timeout": 1},
            ),
        ):
            _make_skill(
                tmp_path,
                "dyn-slow",
                body="Slow: !`sleep 5 && printf DYN_MARKER`",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message("/dyn-slow")
            assert msg is not None
            # Timeout is surfaced as a marker instead of propagating as an error,
            # and the rest of the skill message still renders.
            assert "inline-shell timeout" in msg
            # The command's intended stdout never made it through — only the
            # timeout marker (which echoes the command text) survives.
            assert "DYN_MARKER" not in msg.replace("sleep 5 && printf DYN_MARKER", "")

View file

@ -1,61 +0,0 @@
from agent.smart_model_routing import choose_cheap_model_route
# Shared routing config for these tests: feature enabled, with a cheap
# fallback model hosted on OpenRouter.
_BASE_CONFIG = {
    "enabled": True,
    "cheap_model": {
        "provider": "openrouter",
        "model": "google/gemini-2.5-flash",
    },
}
def test_returns_none_when_disabled():
    """Routing must be a strict no-op when the feature flag is off."""
    disabled_cfg = {**_BASE_CONFIG, "enabled": False}
    assert choose_cheap_model_route("what time is it in tokyo?", disabled_cfg) is None
def test_routes_short_simple_prompt():
    """A short, tool-free prompt routes to the configured cheap model."""
    route = choose_cheap_model_route("what time is it in tokyo?", _BASE_CONFIG)
    assert route is not None
    assert route["provider"] == "openrouter"
    assert route["model"] == "google/gemini-2.5-flash"
    assert route["routing_reason"] == "simple_turn"
def test_skips_long_prompt():
    """Long prompts must stay on the primary model."""
    long_prompt = "please summarize this carefully " * 20
    assert choose_cheap_model_route(long_prompt, _BASE_CONFIG) is None
def test_skips_code_like_prompt():
    """Prompts containing code fences must stay on the primary model."""
    code_prompt = "debug this traceback: ```python\nraise ValueError('bad')\n```"
    assert choose_cheap_model_route(code_prompt, _BASE_CONFIG) is None
def test_skips_tool_heavy_prompt_keywords():
    """Prompts with tool-implying keywords skip the cheap route."""
    tool_heavy_prompt = "implement a patch for this docker error"
    assert choose_cheap_model_route(tool_heavy_prompt, _BASE_CONFIG) is None
def test_resolve_turn_route_falls_back_to_primary_when_route_runtime_cannot_be_resolved(monkeypatch):
    """When the cheap route's runtime provider cannot be resolved, the turn
    must fall back to the caller-supplied primary runtime (no label)."""
    from agent.smart_model_routing import resolve_turn_route
    # Throwing lambda: any attempt to resolve the cheap route's runtime fails.
    monkeypatch.setattr(
        "hermes_cli.runtime_provider.resolve_runtime_provider",
        lambda **kwargs: (_ for _ in ()).throw(RuntimeError("bad route")),
    )
    result = resolve_turn_route(
        "what time is it in tokyo?",
        _BASE_CONFIG,
        {
            "model": "anthropic/claude-sonnet-4",
            "provider": "openrouter",
            "base_url": "https://openrouter.ai/api/v1",
            "api_mode": "chat_completions",
            "api_key": "sk-primary",
        },
    )
    assert result["model"] == "anthropic/claude-sonnet-4"
    assert result["runtime"]["provider"] == "openrouter"
    assert result["label"] is None

View file

@ -193,7 +193,7 @@ class TestBuildChildProgressCallback:
# task_index=0 in a batch of 3 → prefix "[1]"
cb0 = _build_child_progress_callback(0, "test goal", parent, task_count=3)
cb0("web_search", "test")
cb0("tool.started", "web_search", "test", {})
output = buf.getvalue()
assert "[1]" in output
@ -201,7 +201,7 @@ class TestBuildChildProgressCallback:
buf.truncate(0)
buf.seek(0)
cb2 = _build_child_progress_callback(2, "test goal", parent, task_count=3)
cb2("web_search", "test")
cb2("tool.started", "web_search", "test", {})
output = buf.getvalue()
assert "[3]" in output

View file

@ -0,0 +1,224 @@
"""Tests for the subagent_stop hook event.
Covers wire-up from tools.delegate_tool.delegate_task:
* fires once per child in both single-task and batch modes
* runs on the parent thread (no re-entrancy for hook authors)
* carries child_role when the agent exposes _delegate_role
* carries child_role=None when _delegate_role is not set (pre-M3)
"""
from __future__ import annotations
import json
import threading
from unittest.mock import MagicMock, patch
import pytest
from tools.delegate_tool import delegate_task
from hermes_cli import plugins
def _make_parent(depth: int = 0, session_id: str = "parent-1"):
    """Build a MagicMock standing in for a parent agent, carrying every
    attribute delegate_task reads off it."""
    parent = MagicMock()
    # Static connection / identity attributes.
    fixed_attrs = {
        "base_url": "https://openrouter.ai/api/v1",
        "api_key": "***",
        "provider": "openrouter",
        "api_mode": "chat_completions",
        "model": "anthropic/claude-sonnet-4",
        "platform": "cli",
        "providers_allowed": None,
        "providers_ignored": None,
        "providers_order": None,
        "provider_sort": None,
        "_session_db": None,
        "_print_fn": None,
        "tool_progress_callback": None,
        "thinking_callback": None,
        "_memory_manager": None,
    }
    for attr_name, attr_value in fixed_attrs.items():
        setattr(parent, attr_name, attr_value)
    # Per-call state.
    parent._delegate_depth = depth
    parent._active_children = []
    parent._active_children_lock = threading.Lock()
    parent.session_id = session_id
    return parent
@pytest.fixture(autouse=True)
def _fresh_plugin_manager():
    """Each test gets a fresh PluginManager so hook callbacks don't
    leak between tests."""
    saved_manager = plugins._plugin_manager
    plugins._plugin_manager = plugins.PluginManager()
    yield
    # Restore whatever manager was installed before the test ran.
    plugins._plugin_manager = saved_manager
@pytest.fixture(autouse=True)
def _stub_child_builder(monkeypatch):
    """Replace _build_child_agent with a MagicMock factory so delegate_task
    never transitively imports run_agent / openai. Keeps the test runnable
    in environments without heavyweight runtime deps installed."""
    def _stub_factory(task_index, **kwargs):
        fake_child = MagicMock()
        fake_child._delegate_saved_tool_names = []
        fake_child._credential_pool = None
        return fake_child
    monkeypatch.setattr(
        "tools.delegate_tool._build_child_agent", _stub_factory,
    )
def _register_capturing_hook():
    """Attach a subagent_stop hook that records every payload (plus the
    thread it ran on) and return the capture list."""
    payloads = []
    def _capture(**kwargs):
        # Stash the executing thread so tests can assert parent-thread dispatch.
        kwargs["_thread"] = threading.current_thread()
        payloads.append(kwargs)
    manager = plugins.get_plugin_manager()
    manager._hooks.setdefault("subagent_stop", []).append(_capture)
    return payloads
# ── single-task mode ──────────────────────────────────────────────────────
class TestSingleTask:
    """subagent_stop wiring for single-task delegate_task calls."""
    def test_fires_once(self):
        captured = _register_capturing_hook()
        with patch("tools.delegate_tool._run_single_child") as mock_run:
            mock_run.return_value = {
                "task_index": 0,
                "status": "completed",
                "summary": "Done!",
                "api_calls": 3,
                "duration_seconds": 5.0,
                "_child_role": "analyst",
            }
            delegate_task(goal="do X", parent_agent=_make_parent())
        assert len(captured) == 1
        payload = captured[0]
        assert payload["child_role"] == "analyst"
        assert payload["child_status"] == "completed"
        assert payload["child_summary"] == "Done!"
        # 5.0 s duration is exposed to hooks as 5000 ms.
        assert payload["duration_ms"] == 5000
    def test_fires_on_parent_thread(self):
        captured = _register_capturing_hook()
        main_thread = threading.current_thread()
        with patch("tools.delegate_tool._run_single_child") as mock_run:
            mock_run.return_value = {
                "task_index": 0, "status": "completed",
                "summary": "x", "api_calls": 1, "duration_seconds": 0.1,
                "_child_role": None,
            }
            delegate_task(goal="go", parent_agent=_make_parent())
        # The capturing hook stashed the thread it executed on.
        assert captured[0]["_thread"] is main_thread
    def test_payload_includes_parent_session_id(self):
        captured = _register_capturing_hook()
        with patch("tools.delegate_tool._run_single_child") as mock_run:
            mock_run.return_value = {
                "task_index": 0, "status": "completed",
                "summary": "x", "api_calls": 1, "duration_seconds": 0.1,
                "_child_role": None,
            }
            delegate_task(
                goal="go",
                parent_agent=_make_parent(session_id="sess-xyz"),
            )
        assert captured[0]["parent_session_id"] == "sess-xyz"
# ── batch mode ────────────────────────────────────────────────────────────
class TestBatchMode:
    """subagent_stop wiring for batch (multi-task) delegate_task calls."""
    def test_fires_per_child(self):
        captured = _register_capturing_hook()
        with patch("tools.delegate_tool._run_single_child") as mock_run:
            # One canned result per child, in task order.
            mock_run.side_effect = [
                {"task_index": 0, "status": "completed",
                 "summary": "A", "api_calls": 1, "duration_seconds": 1.0,
                 "_child_role": "role-a"},
                {"task_index": 1, "status": "completed",
                 "summary": "B", "api_calls": 2, "duration_seconds": 2.0,
                 "_child_role": "role-b"},
                {"task_index": 2, "status": "completed",
                 "summary": "C", "api_calls": 3, "duration_seconds": 3.0,
                 "_child_role": "role-c"},
            ]
            delegate_task(
                tasks=[
                    {"goal": "A"}, {"goal": "B"}, {"goal": "C"},
                ],
                parent_agent=_make_parent(),
            )
        assert len(captured) == 3
        # Sorted comparison: hook firing order across children is not asserted.
        roles = sorted(c["child_role"] for c in captured)
        assert roles == ["role-a", "role-b", "role-c"]
    def test_all_fires_on_parent_thread(self):
        captured = _register_capturing_hook()
        main_thread = threading.current_thread()
        with patch("tools.delegate_tool._run_single_child") as mock_run:
            mock_run.side_effect = [
                {"task_index": 0, "status": "completed",
                 "summary": "A", "api_calls": 1, "duration_seconds": 1.0,
                 "_child_role": None},
                {"task_index": 1, "status": "completed",
                 "summary": "B", "api_calls": 2, "duration_seconds": 2.0,
                 "_child_role": None},
            ]
            delegate_task(
                tasks=[{"goal": "A"}, {"goal": "B"}],
                parent_agent=_make_parent(),
            )
        # Every hook invocation must have run on the parent (main) thread.
        for payload in captured:
            assert payload["_thread"] is main_thread
# ── payload shape ─────────────────────────────────────────────────────────
class TestPayloadShape:
    """Shape guarantees of the subagent_stop payload and the returned JSON."""
    def test_role_absent_becomes_none(self):
        captured = _register_capturing_hook()
        with patch("tools.delegate_tool._run_single_child") as mock_run:
            mock_run.return_value = {
                "task_index": 0, "status": "completed",
                "summary": "x", "api_calls": 1, "duration_seconds": 0.1,
                # Deliberately omit _child_role — pre-M3 shape.
            }
            delegate_task(goal="do X", parent_agent=_make_parent())
        assert captured[0]["child_role"] is None
    def test_result_does_not_leak_child_role_field(self):
        """The internal _child_role key must be stripped before the
        result dict is serialised to JSON."""
        _register_capturing_hook()
        with patch("tools.delegate_tool._run_single_child") as mock_run:
            mock_run.return_value = {
                "task_index": 0, "status": "completed",
                "summary": "x", "api_calls": 1, "duration_seconds": 0.1,
                "_child_role": "leaf",
            }
            raw = delegate_task(goal="do X", parent_agent=_make_parent())
        # delegate_task returns a JSON string; the private key must be gone.
        parsed = json.loads(raw)
        assert "results" in parsed
        assert "_child_role" not in parsed["results"][0]

View file

View file

@ -0,0 +1,164 @@
"""Tests for the BedrockTransport."""
import json
import pytest
from types import SimpleNamespace
from agent.transports import get_transport
from agent.transports.types import NormalizedResponse, ToolCall
@pytest.fixture
def transport():
    # Importing the module registers the transport as a side effect.
    import agent.transports.bedrock  # noqa: F401
    return get_transport("bedrock_converse")
class TestBedrockBasic:
    """Smoke checks: the transport registers and reports its api_mode."""
    def test_registered(self, transport):
        assert transport is not None
    def test_api_mode(self, transport):
        assert transport.api_mode == "bedrock_converse"
class TestBedrockBuildKwargs:
    """build_kwargs must shape the Bedrock Converse request payload."""
    MODEL_ID = "anthropic.claude-3-5-sonnet-20241022-v2:0"
    def test_basic_kwargs(self, transport):
        kwargs = transport.build_kwargs(
            model=self.MODEL_ID,
            messages=[{"role": "user", "content": "Hello"}],
        )
        assert kwargs["modelId"] == self.MODEL_ID
        assert kwargs["__bedrock_converse__"] is True
        # Region defaults to us-east-1 when none is supplied.
        assert kwargs["__bedrock_region__"] == "us-east-1"
        assert "messages" in kwargs
    def test_custom_region(self, transport):
        kwargs = transport.build_kwargs(
            model=self.MODEL_ID,
            messages=[{"role": "user", "content": "Hi"}],
            region="eu-west-1",
        )
        assert kwargs["__bedrock_region__"] == "eu-west-1"
    def test_max_tokens(self, transport):
        kwargs = transport.build_kwargs(
            model=self.MODEL_ID,
            messages=[{"role": "user", "content": "Hi"}],
            max_tokens=8192,
        )
        assert kwargs["inferenceConfig"]["maxTokens"] == 8192
class TestBedrockConvertTools:
    """OpenAI-style tool schemas must convert to Bedrock toolSpec form."""
    def test_convert_tools(self, transport):
        openai_tool = {
            "type": "function",
            "function": {
                "name": "terminal",
                "description": "Run commands",
                "parameters": {"type": "object", "properties": {"command": {"type": "string"}}},
            }
        }
        converted = transport.convert_tools([openai_tool])
        assert len(converted) == 1
        assert converted[0]["toolSpec"]["name"] == "terminal"
class TestBedrockValidate:
    """validate_response accepts raw Converse dicts with output.message,
    and already-normalized responses with non-empty choices."""

    def test_none(self, transport):
        assert transport.validate_response(None) is False

    def test_raw_dict_valid(self, transport):
        raw = {"output": {"message": {}}}
        assert transport.validate_response(raw) is True

    def test_raw_dict_invalid(self, transport):
        raw = {"error": "fail"}
        assert transport.validate_response(raw) is False

    def test_normalized_valid(self, transport):
        message = SimpleNamespace(content="hi")
        resp = SimpleNamespace(choices=[SimpleNamespace(message=message)])
        assert transport.validate_response(resp) is True
class TestBedrockMapFinishReason:
    """Bedrock stopReason values must map onto OpenAI-style finish reasons."""

    def test_end_turn(self, transport):
        assert transport.map_finish_reason("end_turn") == "stop"

    def test_tool_use(self, transport):
        assert transport.map_finish_reason("tool_use") == "tool_calls"

    def test_max_tokens(self, transport):
        assert transport.map_finish_reason("max_tokens") == "length"

    def test_guardrail(self, transport):
        assert transport.map_finish_reason("guardrail_intervened") == "content_filter"

    def test_unknown(self, transport):
        # Unrecognised reasons fall back to "stop".
        assert transport.map_finish_reason("unknown") == "stop"
class TestBedrockNormalize:
    """normalize_response must handle both raw Converse dicts and
    already-normalized SimpleNamespace responses."""

    def _make_bedrock_response(self, text="Hello", tool_calls=None, stop_reason="end_turn"):
        """Build a raw Bedrock converse response dict."""
        content = []
        if text:
            content.append({"text": text})
        if tool_calls:
            # Converse encodes tool calls as toolUse content blocks.
            for tc in tool_calls:
                content.append({
                    "toolUse": {
                        "toolUseId": tc["id"],
                        "name": tc["name"],
                        "input": tc["input"],
                    }
                })
        return {
            "output": {"message": {"role": "assistant", "content": content}},
            "stopReason": stop_reason,
            "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15},
        }

    def test_text_response(self, transport):
        raw = self._make_bedrock_response(text="Hello world")
        nr = transport.normalize_response(raw)
        assert isinstance(nr, NormalizedResponse)
        assert nr.content == "Hello world"
        assert nr.finish_reason == "stop"

    def test_tool_call_response(self, transport):
        raw = self._make_bedrock_response(
            text=None,
            tool_calls=[{"id": "tool_1", "name": "terminal", "input": {"command": "ls"}}],
            stop_reason="tool_use",
        )
        nr = transport.normalize_response(raw)
        assert nr.finish_reason == "tool_calls"
        assert len(nr.tool_calls) == 1
        assert nr.tool_calls[0].name == "terminal"

    def test_already_normalized_response(self, transport):
        """Test normalize_response handles already-normalized SimpleNamespace (from dispatch site)."""
        pre_normalized = SimpleNamespace(
            choices=[SimpleNamespace(
                message=SimpleNamespace(
                    content="Hello from Bedrock",
                    tool_calls=None,
                    reasoning=None,
                    reasoning_content=None,
                ),
                finish_reason="stop",
            )],
            usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        )
        nr = transport.normalize_response(pre_normalized)
        assert isinstance(nr, NormalizedResponse)
        assert nr.content == "Hello from Bedrock"
        assert nr.finish_reason == "stop"
        assert nr.usage is not None
        assert nr.usage.prompt_tokens == 10

View file

@ -0,0 +1,349 @@
"""Tests for the ChatCompletionsTransport."""
import pytest
from types import SimpleNamespace
from agent.transports import get_transport
from agent.transports.types import NormalizedResponse, ToolCall
@pytest.fixture
def transport():
    """Return the registered Chat Completions transport."""
    # Importing the module registers the transport as a side effect.
    import agent.transports.chat_completions  # noqa: F401
    return get_transport("chat_completions")
class TestChatCompletionsBasic:
    """Registration plus message/tool passthrough behaviour."""

    def test_api_mode(self, transport):
        assert transport.api_mode == "chat_completions"

    def test_registered(self, transport):
        assert transport is not None

    def test_convert_tools_identity(self, transport):
        # Chat Completions tools are already in the native format.
        tools = [{"type": "function", "function": {"name": "test", "parameters": {}}}]
        assert transport.convert_tools(tools) is tools

    def test_convert_messages_no_codex_leaks(self, transport):
        msgs = [{"role": "user", "content": "hi"}]
        result = transport.convert_messages(msgs)
        assert result is msgs  # no copy needed

    def test_convert_messages_strips_codex_fields(self, transport):
        # Codex-only bookkeeping fields must not leak into other providers' payloads.
        msgs = [
            {"role": "assistant", "content": "ok", "codex_reasoning_items": [{"id": "rs_1"}],
             "tool_calls": [{"id": "call_1", "call_id": "call_1", "response_item_id": "fc_1",
                             "type": "function", "function": {"name": "t", "arguments": "{}"}}]},
        ]
        result = transport.convert_messages(msgs)
        assert "codex_reasoning_items" not in result[0]
        assert "call_id" not in result[0]["tool_calls"][0]
        assert "response_item_id" not in result[0]["tool_calls"][0]
        # Original list untouched (deepcopy-on-demand)
        assert "codex_reasoning_items" in msgs[0]
class TestChatCompletionsBuildKwargs:
    """build_kwargs provider-quirk matrix: role swaps, extra_body payloads,
    per-provider max-token defaults, and override precedence."""

    def test_basic_kwargs(self, transport):
        msgs = [{"role": "user", "content": "Hello"}]
        kw = transport.build_kwargs(model="gpt-4o", messages=msgs, timeout=30.0)
        assert kw["model"] == "gpt-4o"
        assert kw["messages"][0]["content"] == "Hello"
        assert kw["timeout"] == 30.0

    def test_developer_role_swap(self, transport):
        # GPT-5-family models take "developer" instead of "system".
        msgs = [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(model="gpt-5.4", messages=msgs, model_lower="gpt-5.4")
        assert kw["messages"][0]["role"] == "developer"

    def test_no_developer_swap_for_non_gpt5(self, transport):
        msgs = [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(model="claude-sonnet-4", messages=msgs, model_lower="claude-sonnet-4")
        assert kw["messages"][0]["role"] == "system"

    def test_tools_included(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        tools = [{"type": "function", "function": {"name": "test", "parameters": {}}}]
        kw = transport.build_kwargs(model="gpt-4o", messages=msgs, tools=tools)
        assert kw["tools"] == tools

    def test_openrouter_provider_prefs(self, transport):
        # OpenRouter routing preferences travel in extra_body.provider.
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-4o", messages=msgs,
            is_openrouter=True,
            provider_preferences={"only": ["openai"]},
        )
        assert kw["extra_body"]["provider"] == {"only": ["openai"]}

    def test_nous_tags(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(model="gpt-4o", messages=msgs, is_nous=True)
        assert kw["extra_body"]["tags"] == ["product=hermes-agent"]

    def test_reasoning_default(self, transport):
        # Reasoning-capable models default to enabled/medium effort.
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-4o", messages=msgs,
            supports_reasoning=True,
        )
        assert kw["extra_body"]["reasoning"] == {"enabled": True, "effort": "medium"}

    def test_nous_omits_disabled_reasoning(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-4o", messages=msgs,
            supports_reasoning=True,
            is_nous=True,
            reasoning_config={"enabled": False},
        )
        # Nous rejects enabled=false; reasoning omitted entirely
        assert "reasoning" not in kw.get("extra_body", {})

    def test_ollama_num_ctx(self, transport):
        # Ollama context size goes in extra_body.options.num_ctx.
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="llama3", messages=msgs,
            ollama_num_ctx=32768,
        )
        assert kw["extra_body"]["options"]["num_ctx"] == 32768

    def test_custom_think_false(self, transport):
        # Custom providers with effort "none" disable thinking via think=False.
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="qwen3", messages=msgs,
            is_custom_provider=True,
            reasoning_config={"effort": "none"},
        )
        assert kw["extra_body"]["think"] is False

    def test_max_tokens_with_fn(self, transport):
        # max_tokens_param_fn decides the parameter name used by the provider.
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-4o", messages=msgs,
            max_tokens=4096,
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        assert kw["max_tokens"] == 4096

    def test_ephemeral_overrides_max_tokens(self, transport):
        # A per-request ephemeral cap wins over the configured max_tokens.
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-4o", messages=msgs,
            max_tokens=4096,
            ephemeral_max_output_tokens=2048,
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        assert kw["max_tokens"] == 2048

    def test_nvidia_default_max_tokens(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="glm-4.7", messages=msgs,
            is_nvidia_nim=True,
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        # NVIDIA default: 16384
        assert kw["max_tokens"] == 16384

    def test_qwen_default_max_tokens(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="qwen3-coder-plus", messages=msgs,
            is_qwen_portal=True,
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        # Qwen default: 65536
        assert kw["max_tokens"] == 65536

    def test_anthropic_max_output_for_claude_on_aggregator(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="anthropic/claude-sonnet-4.6", messages=msgs,
            is_openrouter=True,
            anthropic_max_output=64000,
        )
        # Set as plain max_tokens (not via fn) because the aggregator proxies to
        # Anthropic Messages API which requires the field.
        assert kw["max_tokens"] == 64000

    def test_request_overrides_last(self, transport):
        # request_overrides are applied last and win over everything else.
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-4o", messages=msgs,
            request_overrides={"service_tier": "priority"},
        )
        assert kw["service_tier"] == "priority"

    def test_fixed_temperature(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(model="gpt-4o", messages=msgs, fixed_temperature=0.6)
        assert kw["temperature"] == 0.6

    def test_omit_temperature(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(model="gpt-4o", messages=msgs, omit_temperature=True, fixed_temperature=0.5)
        # omit wins
        assert "temperature" not in kw
class TestChatCompletionsKimi:
    """Regression tests for the Kimi/Moonshot quirks migrated into the transport."""

    def test_kimi_max_tokens_default(self, transport):
        kw = transport.build_kwargs(
            model="kimi-k2", messages=[{"role": "user", "content": "Hi"}],
            is_kimi=True,
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        # Kimi CLI default: 32000
        assert kw["max_tokens"] == 32000

    def test_kimi_reasoning_effort_top_level(self, transport):
        kw = transport.build_kwargs(
            model="kimi-k2", messages=[{"role": "user", "content": "Hi"}],
            is_kimi=True,
            reasoning_config={"effort": "high"},
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        # Kimi requires reasoning_effort as a top-level parameter
        assert kw["reasoning_effort"] == "high"

    def test_kimi_reasoning_effort_omitted_when_thinking_disabled(self, transport):
        kw = transport.build_kwargs(
            model="kimi-k2", messages=[{"role": "user", "content": "Hi"}],
            is_kimi=True,
            reasoning_config={"enabled": False},
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        # Mirror Kimi CLI: omit reasoning_effort entirely when thinking off
        assert "reasoning_effort" not in kw

    def test_kimi_thinking_enabled_extra_body(self, transport):
        # Thinking is on by default; expressed via extra_body.thinking.
        kw = transport.build_kwargs(
            model="kimi-k2", messages=[{"role": "user", "content": "Hi"}],
            is_kimi=True,
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        assert kw["extra_body"]["thinking"] == {"type": "enabled"}

    def test_kimi_thinking_disabled_extra_body(self, transport):
        kw = transport.build_kwargs(
            model="kimi-k2", messages=[{"role": "user", "content": "Hi"}],
            is_kimi=True,
            reasoning_config={"enabled": False},
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        assert kw["extra_body"]["thinking"] == {"type": "disabled"}
class TestChatCompletionsValidate:
    """validate_response requires a non-empty choices list."""

    def test_none(self, transport):
        assert transport.validate_response(None) is False

    def test_no_choices(self, transport):
        resp = SimpleNamespace(choices=None)
        assert transport.validate_response(resp) is False

    def test_empty_choices(self, transport):
        resp = SimpleNamespace(choices=[])
        assert transport.validate_response(resp) is False

    def test_valid(self, transport):
        message = SimpleNamespace(content="hi")
        resp = SimpleNamespace(choices=[SimpleNamespace(message=message)])
        assert transport.validate_response(resp) is True
class TestChatCompletionsNormalize:
    """normalize_response: text, tool calls, provider extras, and reasoning fields."""

    def test_text_response(self, transport):
        r = SimpleNamespace(
            choices=[SimpleNamespace(
                message=SimpleNamespace(content="Hello", tool_calls=None, reasoning_content=None),
                finish_reason="stop",
            )],
            usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        )
        nr = transport.normalize_response(r)
        assert isinstance(nr, NormalizedResponse)
        assert nr.content == "Hello"
        assert nr.finish_reason == "stop"
        assert nr.tool_calls is None

    def test_tool_call_response(self, transport):
        tc = SimpleNamespace(
            id="call_123",
            function=SimpleNamespace(name="terminal", arguments='{"command": "ls"}'),
        )
        r = SimpleNamespace(
            choices=[SimpleNamespace(
                message=SimpleNamespace(content=None, tool_calls=[tc], reasoning_content=None),
                finish_reason="tool_calls",
            )],
            usage=SimpleNamespace(prompt_tokens=10, completion_tokens=20, total_tokens=30),
        )
        nr = transport.normalize_response(r)
        assert len(nr.tool_calls) == 1
        assert nr.tool_calls[0].name == "terminal"
        assert nr.tool_calls[0].id == "call_123"

    def test_tool_call_extra_content_preserved(self, transport):
        """Gemini 3 thinking models attach extra_content with thought_signature
        on tool_calls. Without this replay on the next turn, the API rejects
        the request with 400. The transport MUST surface extra_content so the
        agent loop can write it back into the assistant message."""
        tc = SimpleNamespace(
            id="call_gem",
            function=SimpleNamespace(name="terminal", arguments='{"command": "ls"}'),
            extra_content={"google": {"thought_signature": "SIG_ABC123"}},
        )
        r = SimpleNamespace(
            choices=[SimpleNamespace(
                message=SimpleNamespace(content=None, tool_calls=[tc], reasoning_content=None),
                finish_reason="tool_calls",
            )],
            usage=None,
        )
        nr = transport.normalize_response(r)
        assert nr.tool_calls[0].provider_data == {
            "extra_content": {"google": {"thought_signature": "SIG_ABC123"}}
        }

    def test_reasoning_content_preserved_separately(self, transport):
        """DeepSeek/Moonshot use reasoning_content distinct from reasoning.
        Don't merge them — the thinking-prefill retry check reads each field
        separately."""
        r = SimpleNamespace(
            choices=[SimpleNamespace(
                message=SimpleNamespace(
                    content=None, tool_calls=None,
                    reasoning="summary text",
                    reasoning_content="detailed scratchpad",
                ),
                finish_reason="stop",
            )],
            usage=None,
        )
        nr = transport.normalize_response(r)
        assert nr.reasoning == "summary text"
        assert nr.provider_data == {"reasoning_content": "detailed scratchpad"}
class TestChatCompletionsCacheStats:
    """extract_cache_stats reads usage.prompt_tokens_details when present."""

    def test_no_usage(self, transport):
        resp = SimpleNamespace(usage=None)
        assert transport.extract_cache_stats(resp) is None

    def test_no_details(self, transport):
        usage = SimpleNamespace(prompt_tokens_details=None)
        resp = SimpleNamespace(usage=usage)
        assert transport.extract_cache_stats(resp) is None

    def test_with_cache(self, transport):
        details = SimpleNamespace(cached_tokens=500, cache_write_tokens=100)
        resp = SimpleNamespace(usage=SimpleNamespace(prompt_tokens_details=details))
        stats = transport.extract_cache_stats(resp)
        assert stats == {"cached_tokens": 500, "creation_tokens": 100}

View file

@ -0,0 +1,220 @@
"""Tests for the ResponsesApiTransport (Codex)."""
import json
import pytest
from types import SimpleNamespace
from agent.transports import get_transport
from agent.transports.types import NormalizedResponse, ToolCall
@pytest.fixture
def transport():
    """Return the registered Codex (Responses API) transport."""
    # Importing the module registers the transport as a side effect.
    import agent.transports.codex  # noqa: F401
    return get_transport("codex_responses")
class TestCodexTransportBasic:
    """Registration and tool conversion for the Responses API transport."""

    def test_api_mode(self, transport):
        assert transport.api_mode == "codex_responses"

    def test_registered_on_import(self, transport):
        assert transport is not None

    def test_convert_tools(self, transport):
        # Responses API flattens the nested "function" wrapper to top level.
        tools = [{
            "type": "function",
            "function": {
                "name": "terminal",
                "description": "Run a command",
                "parameters": {"type": "object", "properties": {"command": {"type": "string"}}},
            }
        }]
        result = transport.convert_tools(tools)
        assert len(result) == 1
        assert result[0]["type"] == "function"
        assert result[0]["name"] == "terminal"
class TestCodexBuildKwargs:
    """build_kwargs for the Responses API: instructions extraction, reasoning
    config, caching keys, token limits, and backend-specific quirks."""

    def test_basic_kwargs(self, transport):
        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello"},
        ]
        kw = transport.build_kwargs(
            model="gpt-5.4",
            messages=messages,
            tools=[],
        )
        assert kw["model"] == "gpt-5.4"
        # System message becomes top-level "instructions".
        assert kw["instructions"] == "You are helpful."
        assert "input" in kw
        # Responses are not stored server-side.
        assert kw["store"] is False

    def test_system_extracted_from_messages(self, transport):
        messages = [
            {"role": "system", "content": "Custom system prompt"},
            {"role": "user", "content": "Hi"},
        ]
        kw = transport.build_kwargs(model="gpt-5.4", messages=messages, tools=[])
        assert kw["instructions"] == "Custom system prompt"

    def test_no_system_uses_default(self, transport):
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(model="gpt-5.4", messages=messages, tools=[])
        assert kw["instructions"]  # should be non-empty default

    def test_reasoning_config(self, transport):
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-5.4", messages=messages, tools=[],
            reasoning_config={"effort": "high"},
        )
        assert kw.get("reasoning", {}).get("effort") == "high"

    def test_reasoning_disabled(self, transport):
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-5.4", messages=messages, tools=[],
            reasoning_config={"enabled": False},
        )
        assert "reasoning" not in kw or kw.get("include") == []

    def test_session_id_sets_cache_key(self, transport):
        # session_id doubles as the prompt cache key.
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-5.4", messages=messages, tools=[],
            session_id="test-session-123",
        )
        assert kw.get("prompt_cache_key") == "test-session-123"

    def test_github_responses_no_cache_key(self, transport):
        # GitHub's Responses endpoint doesn't accept prompt_cache_key.
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-5.4", messages=messages, tools=[],
            session_id="test-session",
            is_github_responses=True,
        )
        assert "prompt_cache_key" not in kw

    def test_max_tokens(self, transport):
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-5.4", messages=messages, tools=[],
            max_tokens=4096,
        )
        # Responses API names the limit max_output_tokens.
        assert kw.get("max_output_tokens") == 4096

    def test_codex_backend_no_max_output_tokens(self, transport):
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-5.4", messages=messages, tools=[],
            max_tokens=4096,
            is_codex_backend=True,
        )
        assert "max_output_tokens" not in kw

    def test_xai_headers(self, transport):
        # xAI threads conversations via the x-grok-conv-id header.
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="grok-3", messages=messages, tools=[],
            session_id="conv-123",
            is_xai_responses=True,
        )
        assert kw.get("extra_headers", {}).get("x-grok-conv-id") == "conv-123"

    def test_minimal_effort_clamped(self, transport):
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-5.4", messages=messages, tools=[],
            reasoning_config={"effort": "minimal"},
        )
        # "minimal" should be clamped to "low"
        assert kw.get("reasoning", {}).get("effort") == "low"
class TestCodexValidateResponse:
    """Strict validation: only a non-empty output list counts as valid."""

    def test_none_response(self, transport):
        assert transport.validate_response(None) is False

    def test_empty_output(self, transport):
        resp = SimpleNamespace(output=[], output_text=None)
        assert transport.validate_response(resp) is False

    def test_valid_output(self, transport):
        resp = SimpleNamespace(output=[{"type": "message", "content": []}])
        assert transport.validate_response(resp) is True

    def test_output_text_fallback_not_valid(self, transport):
        """validate_response is strict — output_text doesn't make it valid.
        The caller handles output_text fallback with diagnostic logging."""
        resp = SimpleNamespace(output=None, output_text="Some text")
        assert transport.validate_response(resp) is False
class TestCodexMapFinishReason:
    """Responses API status values map onto OpenAI-style finish reasons."""

    def test_completed(self, transport):
        assert transport.map_finish_reason("completed") == "stop"

    def test_incomplete(self, transport):
        assert transport.map_finish_reason("incomplete") == "length"

    def test_failed(self, transport):
        assert transport.map_finish_reason("failed") == "stop"

    def test_unknown(self, transport):
        # Unrecognised statuses fall back to "stop".
        assert transport.map_finish_reason("unknown_status") == "stop"
class TestCodexNormalizeResponse:
    """normalize_response for Responses API output items (message / function_call)."""

    def test_text_response(self, transport):
        """Normalize a simple text Codex response."""
        r = SimpleNamespace(
            output=[
                SimpleNamespace(
                    type="message",
                    role="assistant",
                    content=[SimpleNamespace(type="output_text", text="Hello world")],
                    status="completed",
                ),
            ],
            status="completed",
            incomplete_details=None,
            usage=SimpleNamespace(input_tokens=10, output_tokens=5,
                                  input_tokens_details=None, output_tokens_details=None),
        )
        nr = transport.normalize_response(r)
        assert isinstance(nr, NormalizedResponse)
        assert nr.content == "Hello world"
        assert nr.finish_reason == "stop"

    def test_tool_call_response(self, transport):
        """Normalize a Codex response with tool calls."""
        r = SimpleNamespace(
            output=[
                SimpleNamespace(
                    type="function_call",
                    call_id="call_abc123",
                    name="terminal",
                    arguments=json.dumps({"command": "ls"}),
                    id="fc_abc123",
                    status="completed",
                ),
            ],
            status="completed",
            incomplete_details=None,
            usage=SimpleNamespace(input_tokens=10, output_tokens=20,
                                  input_tokens_details=None, output_tokens_details=None),
        )
        nr = transport.normalize_response(r)
        assert nr.finish_reason == "tool_calls"
        assert len(nr.tool_calls) == 1
        tc = nr.tool_calls[0]
        assert tc.name == "terminal"
        assert '"command"' in tc.arguments

View file

@ -0,0 +1,220 @@
"""Tests for the transport ABC, registry, and AnthropicTransport."""
import pytest
from types import SimpleNamespace
from unittest.mock import MagicMock
from agent.transports.base import ProviderTransport
from agent.transports.types import NormalizedResponse, ToolCall, Usage
from agent.transports import get_transport, register_transport, _REGISTRY
# ── ABC contract tests ──────────────────────────────────────────────────
class TestProviderTransportABC:
    """Verify the ABC contract is enforceable."""

    def test_cannot_instantiate_abc(self):
        with pytest.raises(TypeError):
            ProviderTransport()

    def test_concrete_must_implement_all_abstract(self):
        # A subclass missing abstract methods must still be uninstantiable.
        class Incomplete(ProviderTransport):
            @property
            def api_mode(self):
                return "test"
        with pytest.raises(TypeError):
            Incomplete()

    def test_minimal_concrete(self):
        # Implementing only the abstract surface gets the default behaviours.
        class Minimal(ProviderTransport):
            @property
            def api_mode(self):
                return "test_minimal"
            def convert_messages(self, messages, **kw):
                return messages
            def convert_tools(self, tools):
                return tools
            def build_kwargs(self, model, messages, tools=None, **params):
                return {"model": model, "messages": messages}
            def normalize_response(self, response, **kw):
                return NormalizedResponse(content="ok", tool_calls=None, finish_reason="stop")
        t = Minimal()
        assert t.api_mode == "test_minimal"
        assert t.validate_response(None) is True  # default
        assert t.extract_cache_stats(None) is None  # default
        assert t.map_finish_reason("end_turn") == "end_turn"  # default passthrough
# ── Registry tests ───────────────────────────────────────────────────────
class TestTransportRegistry:
    """register_transport/get_transport lookup behaviour."""

    def test_get_unregistered_returns_none(self):
        assert get_transport("nonexistent_mode") is None

    def test_anthropic_registered_on_import(self):
        # Importing the module registers the transport as a side effect.
        import agent.transports.anthropic  # noqa: F401
        t = get_transport("anthropic_messages")
        assert t is not None
        assert t.api_mode == "anthropic_messages"

    def test_register_and_get(self):
        class DummyTransport(ProviderTransport):
            @property
            def api_mode(self):
                return "dummy_test"
            def convert_messages(self, messages, **kw):
                return messages
            def convert_tools(self, tools):
                return tools
            def build_kwargs(self, model, messages, tools=None, **params):
                return {}
            def normalize_response(self, response, **kw):
                return NormalizedResponse(content=None, tool_calls=None, finish_reason="stop")
        register_transport("dummy_test", DummyTransport)
        t = get_transport("dummy_test")
        assert t.api_mode == "dummy_test"
        # Cleanup
        _REGISTRY.pop("dummy_test", None)
# ── AnthropicTransport tests ────────────────────────────────────────────
class TestAnthropicTransport:
    """AnthropicTransport: tool conversion, validation, finish-reason mapping,
    cache stats, and response normalization."""

    @pytest.fixture
    def transport(self):
        # Importing the module registers the transport as a side effect.
        import agent.transports.anthropic  # noqa: F401
        return get_transport("anthropic_messages")

    def test_api_mode(self, transport):
        assert transport.api_mode == "anthropic_messages"

    def test_convert_tools_simple(self, transport):
        # OpenAI "parameters" becomes Anthropic "input_schema".
        tools = [{
            "type": "function",
            "function": {
                "name": "test_tool",
                "description": "A test",
                "parameters": {"type": "object", "properties": {}},
            }
        }]
        result = transport.convert_tools(tools)
        assert len(result) == 1
        assert result[0]["name"] == "test_tool"
        assert "input_schema" in result[0]

    def test_validate_response_none(self, transport):
        assert transport.validate_response(None) is False

    def test_validate_response_empty_content(self, transport):
        r = SimpleNamespace(content=[])
        assert transport.validate_response(r) is False

    def test_validate_response_valid(self, transport):
        r = SimpleNamespace(content=[SimpleNamespace(type="text", text="hello")])
        assert transport.validate_response(r) is True

    def test_map_finish_reason(self, transport):
        assert transport.map_finish_reason("end_turn") == "stop"
        assert transport.map_finish_reason("tool_use") == "tool_calls"
        assert transport.map_finish_reason("max_tokens") == "length"
        assert transport.map_finish_reason("stop_sequence") == "stop"
        assert transport.map_finish_reason("refusal") == "content_filter"
        assert transport.map_finish_reason("model_context_window_exceeded") == "length"
        assert transport.map_finish_reason("unknown") == "stop"

    def test_extract_cache_stats_none_usage(self, transport):
        r = SimpleNamespace(usage=None)
        assert transport.extract_cache_stats(r) is None

    def test_extract_cache_stats_with_cache(self, transport):
        usage = SimpleNamespace(cache_read_input_tokens=100, cache_creation_input_tokens=50)
        r = SimpleNamespace(usage=usage)
        result = transport.extract_cache_stats(r)
        assert result == {"cached_tokens": 100, "creation_tokens": 50}

    def test_extract_cache_stats_zero(self, transport):
        # All-zero counters are treated as "no cache stats".
        usage = SimpleNamespace(cache_read_input_tokens=0, cache_creation_input_tokens=0)
        r = SimpleNamespace(usage=usage)
        assert transport.extract_cache_stats(r) is None

    def test_normalize_response_text(self, transport):
        """Test normalization of a simple text response."""
        r = SimpleNamespace(
            content=[SimpleNamespace(type="text", text="Hello world")],
            stop_reason="end_turn",
            usage=SimpleNamespace(input_tokens=10, output_tokens=5),
            model="claude-sonnet-4-6",
        )
        nr = transport.normalize_response(r)
        assert isinstance(nr, NormalizedResponse)
        assert nr.content == "Hello world"
        assert nr.tool_calls is None or nr.tool_calls == []
        assert nr.finish_reason == "stop"

    def test_normalize_response_tool_calls(self, transport):
        """Test normalization of a tool-use response."""
        r = SimpleNamespace(
            content=[
                SimpleNamespace(
                    type="tool_use",
                    id="toolu_123",
                    name="terminal",
                    input={"command": "ls"},
                ),
            ],
            stop_reason="tool_use",
            usage=SimpleNamespace(input_tokens=10, output_tokens=20),
            model="claude-sonnet-4-6",
        )
        nr = transport.normalize_response(r)
        assert nr.finish_reason == "tool_calls"
        assert len(nr.tool_calls) == 1
        tc = nr.tool_calls[0]
        assert tc.name == "terminal"
        assert tc.id == "toolu_123"
        assert '"command"' in tc.arguments

    def test_normalize_response_thinking(self, transport):
        """Test normalization preserves thinking content."""
        r = SimpleNamespace(
            content=[
                SimpleNamespace(type="thinking", thinking="Let me think..."),
                SimpleNamespace(type="text", text="The answer is 42"),
            ],
            stop_reason="end_turn",
            usage=SimpleNamespace(input_tokens=10, output_tokens=15),
            model="claude-sonnet-4-6",
        )
        nr = transport.normalize_response(r)
        assert nr.content == "The answer is 42"
        assert nr.reasoning == "Let me think..."

    def test_build_kwargs_returns_dict(self, transport):
        """Test build_kwargs produces a usable kwargs dict."""
        messages = [{"role": "user", "content": "Hello"}]
        kw = transport.build_kwargs(
            model="claude-sonnet-4-6",
            messages=messages,
            max_tokens=1024,
        )
        assert isinstance(kw, dict)
        assert "model" in kw
        assert "max_tokens" in kw
        assert "messages" in kw

    def test_convert_messages_extracts_system(self, transport):
        """Test convert_messages separates system from messages."""
        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hi"},
        ]
        system, msgs = transport.convert_messages(messages)
        # System should be extracted
        assert system is not None
        # Messages should only have user
        assert len(msgs) >= 1

View file

@ -0,0 +1,151 @@
"""Tests for agent/transports/types.py — dataclass construction + helpers."""
import json
import pytest
from agent.transports.types import (
NormalizedResponse,
ToolCall,
Usage,
build_tool_call,
map_finish_reason,
)
# ---------------------------------------------------------------------------
# ToolCall
# ---------------------------------------------------------------------------
class TestToolCall:
    """ToolCall dataclass construction and field defaults."""

    def test_basic_construction(self):
        call = ToolCall(id="call_abc", name="terminal", arguments='{"cmd": "ls"}')
        assert call.id == "call_abc"
        assert call.name == "terminal"
        assert call.arguments == '{"cmd": "ls"}'
        assert call.provider_data is None

    def test_none_id(self):
        call = ToolCall(id=None, name="read_file", arguments="{}")
        assert call.id is None

    def test_provider_data(self):
        extras = {"call_id": "call_x", "response_item_id": "fc_x"}
        call = ToolCall(id="call_x", name="t", arguments="{}", provider_data=extras)
        assert call.provider_data["call_id"] == "call_x"
        assert call.provider_data["response_item_id"] == "fc_x"
# ---------------------------------------------------------------------------
# Usage
# ---------------------------------------------------------------------------
class TestUsage:
    """Usage dataclass: zero defaults and explicit construction."""

    def test_defaults(self):
        usage = Usage()
        for field in ("prompt_tokens", "completion_tokens", "total_tokens", "cached_tokens"):
            assert getattr(usage, field) == 0

    def test_explicit(self):
        usage = Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150, cached_tokens=80)
        assert usage.total_tokens == 150
# ---------------------------------------------------------------------------
# NormalizedResponse
# ---------------------------------------------------------------------------
class TestNormalizedResponse:
    """NormalizedResponse dataclass construction across its field combinations."""

    def test_text_only(self):
        r = NormalizedResponse(content="hello", tool_calls=None, finish_reason="stop")
        assert r.content == "hello"
        assert r.tool_calls is None
        assert r.finish_reason == "stop"
        # Optional fields default to None.
        assert r.reasoning is None
        assert r.usage is None
        assert r.provider_data is None

    def test_with_tool_calls(self):
        tcs = [ToolCall(id="call_1", name="terminal", arguments='{"cmd":"pwd"}')]
        r = NormalizedResponse(content=None, tool_calls=tcs, finish_reason="tool_calls")
        assert r.finish_reason == "tool_calls"
        assert len(r.tool_calls) == 1
        assert r.tool_calls[0].name == "terminal"

    def test_with_reasoning(self):
        r = NormalizedResponse(
            content="answer",
            tool_calls=None,
            finish_reason="stop",
            reasoning="I thought about it",
        )
        assert r.reasoning == "I thought about it"

    def test_with_provider_data(self):
        r = NormalizedResponse(
            content=None,
            tool_calls=None,
            finish_reason="stop",
            provider_data={"reasoning_details": [{"type": "thinking", "thinking": "hmm"}]},
        )
        assert r.provider_data["reasoning_details"][0]["type"] == "thinking"
# ---------------------------------------------------------------------------
# build_tool_call
# ---------------------------------------------------------------------------
class TestBuildToolCall:
    """build_tool_call helper: argument serialisation and provider fields."""

    def test_dict_arguments_serialized(self):
        # Dict arguments are JSON-serialised into the arguments string.
        tc = build_tool_call(id="call_1", name="terminal", arguments={"cmd": "ls"})
        assert tc.arguments == json.dumps({"cmd": "ls"})
        assert tc.provider_data is None

    def test_string_arguments_passthrough(self):
        tc = build_tool_call(id="call_2", name="read_file", arguments='{"path": "/tmp"}')
        assert tc.arguments == '{"path": "/tmp"}'

    def test_provider_fields(self):
        # call_id / response_item_id are folded into provider_data.
        tc = build_tool_call(
            id="call_3",
            name="terminal",
            arguments="{}",
            call_id="call_3",
            response_item_id="fc_3",
        )
        assert tc.provider_data == {"call_id": "call_3", "response_item_id": "fc_3"}

    def test_none_id(self):
        tc = build_tool_call(id=None, name="t", arguments="{}")
        assert tc.id is None
# ---------------------------------------------------------------------------
# map_finish_reason
# ---------------------------------------------------------------------------
class TestMapFinishReason:
    """map_finish_reason helper against an Anthropic-style mapping table."""

    ANTHROPIC_MAP = {
        "end_turn": "stop",
        "tool_use": "tool_calls",
        "max_tokens": "length",
        "stop_sequence": "stop",
        "refusal": "content_filter",
    }

    def test_known_reason(self):
        expectations = [
            ("end_turn", "stop"),
            ("tool_use", "tool_calls"),
            ("max_tokens", "length"),
            ("refusal", "content_filter"),
        ]
        for raw, mapped in expectations:
            assert map_finish_reason(raw, self.ANTHROPIC_MAP) == mapped

    def test_unknown_reason_defaults_to_stop(self):
        assert map_finish_reason("something_new", self.ANTHROPIC_MAP) == "stop"

    def test_none_reason(self):
        assert map_finish_reason(None, self.ANTHROPIC_MAP) == "stop"

View file

@ -254,3 +254,88 @@ class TestCliApprovalUi:
# Command got truncated with a marker.
assert "(command truncated" in rendered
class TestApprovalCallbackThreadLocalWiring:
    """Regression guard for the thread-local callback freeze (#13617 / #13618).

    After 62348cff made _approval_callback / _sudo_password_callback thread-local
    (ACP GHSA-qg5c-hvr5-hjgr), the CLI agent thread could no longer see callbacks
    registered in the main thread — the dangerous-command prompt silently fell
    back to stdin input() and deadlocked against prompt_toolkit. The fix is to
    register the callbacks INSIDE the agent worker thread (matching the ACP
    pattern). These tests lock in that invariant.
    """

    # Sentinel distinguishing "worker never ran" from a legitimate None result,
    # so a hung/timed-out worker fails with a clear assertion instead of a
    # KeyError when the result dict is read back after join(timeout=...).
    _NOT_RUN = object()

    def test_main_thread_registration_is_invisible_to_child_thread(self):
        """Confirms the underlying threading.local semantics that drove the bug.

        If this ever starts passing as "visible", the thread-local isolation
        is gone and the ACP race GHSA-qg5c-hvr5-hjgr may be back.
        """
        from tools.terminal_tool import (
            set_approval_callback,
            _get_approval_callback,
        )

        def main_cb(_cmd, _desc):
            return "once"

        set_approval_callback(main_cb)
        try:
            seen = {"value": self._NOT_RUN}

            def _child():
                seen["value"] = _get_approval_callback()

            t = threading.Thread(target=_child, daemon=True)
            t.start()
            t.join(timeout=2)
            # Fail loudly on a hang rather than with a confusing KeyError.
            assert not t.is_alive(), "child thread did not finish in time"
            assert seen["value"] is None
        finally:
            # Always unregister so a failure here can't leak into other tests.
            set_approval_callback(None)

    def test_child_thread_registration_is_visible_and_cleared_in_finally(self):
        """The fix pattern: register INSIDE the worker thread, clear in finally.

        This is exactly what cli.py's run_agent() closure does. If this test
        fails, the CLI approval prompt freeze (#13617) has regressed.
        """
        from tools.terminal_tool import (
            set_approval_callback,
            set_sudo_password_callback,
            _get_approval_callback,
            _get_sudo_password_callback,
        )

        def approval_cb(_cmd, _desc):
            return "once"

        def sudo_cb():
            return "hunter2"

        seen = {
            "approval": self._NOT_RUN,
            "sudo": self._NOT_RUN,
            "approval_after": self._NOT_RUN,
            "sudo_after": self._NOT_RUN,
        }

        def _worker():
            # Mimic cli.py's run_agent() thread target.
            set_approval_callback(approval_cb)
            set_sudo_password_callback(sudo_cb)
            try:
                seen["approval"] = _get_approval_callback()
                seen["sudo"] = _get_sudo_password_callback()
            finally:
                set_approval_callback(None)
                set_sudo_password_callback(None)
            seen["approval_after"] = _get_approval_callback()
            seen["sudo_after"] = _get_sudo_password_callback()

        t = threading.Thread(target=_worker, daemon=True)
        t.start()
        t.join(timeout=2)
        # Fail loudly on a hang rather than asserting against the sentinel.
        assert not t.is_alive(), "worker thread did not finish in time"
        assert seen["approval"] is approval_cb
        assert seen["sudo"] is sudo_cb
        # Finally block must clear both slots — otherwise a reused thread
        # would hold a stale reference to a disposed CLI instance.
        assert seen["approval_after"] is None
        assert seen["sudo_after"] is None

View file

@ -0,0 +1,105 @@
"""Tests for CLI external-editor support."""
from unittest.mock import patch
from cli import HermesCLI
class _FakeBuffer:
def __init__(self, text=""):
self.calls = []
self.text = text
self.cursor_position = len(text)
def open_in_editor(self, validate_and_handle=False):
self.calls.append(validate_and_handle)
class _FakeApp:
    """Minimal stand-in for the prompt_toolkit Application: exposes only the
    current_buffer attribute that _open_external_editor reads."""

    def __init__(self):
        self.current_buffer = _FakeBuffer()
def _make_cli(with_app=True):
    """Build a bare HermesCLI carrying only the attributes the editor path reads.

    __init__ is bypassed so no real prompt_toolkit app or I/O is set up;
    ``with_app=False`` simulates a non-TUI (plain stdin) session.
    """
    shell = HermesCLI.__new__(HermesCLI)
    shell._app = _FakeApp() if with_app else None
    shell._command_running = False
    shell._command_status = ""
    shell._command_display = ""
    # No modal prompt is active in the default fixture.
    for attr in ("_sudo_state", "_secret_state", "_approval_state", "_clarify_state"):
        setattr(shell, attr, None)
    shell._skip_paste_collapse = False
    return shell
def test_open_external_editor_uses_prompt_toolkit_buffer_editor():
    """The editor opens via the app's current buffer, once, without validation."""
    shell = _make_cli()
    assert shell._open_external_editor() is True
    assert shell._app.current_buffer.calls == [False]
def test_open_external_editor_rejects_when_no_tui():
    """Without a live TUI app the editor request is refused with a message."""
    shell = _make_cli(with_app=False)
    with patch("cli._cprint") as mock_cprint:
        result = shell._open_external_editor()
    assert result is False
    assert mock_cprint.called
    assert "interactive cli" in str(mock_cprint.call_args).lower()
def test_open_external_editor_rejects_modal_prompts():
    """An active modal prompt (e.g. approval) blocks opening the editor."""
    shell = _make_cli()
    shell._approval_state = {"selected": 0}
    with patch("cli._cprint") as mock_cprint:
        result = shell._open_external_editor()
    assert result is False
    assert mock_cprint.called
    assert "active prompt" in str(mock_cprint.call_args).lower()
def test_open_external_editor_uses_explicit_buffer_when_provided():
    """An explicitly-passed buffer is used instead of the app's current one."""
    shell = _make_cli()
    extra_buffer = _FakeBuffer()
    assert shell._open_external_editor(buffer=extra_buffer) is True
    assert extra_buffer.calls == [False]
    # The app's own buffer must remain untouched.
    assert shell._app.current_buffer.calls == []
def test_expand_paste_references_replaces_placeholder_with_file_contents(tmp_path):
    """A pasted-text placeholder expands to the saved file's contents in place."""
    shell = _make_cli()
    paste_file = tmp_path / "paste.txt"
    paste_file.write_text("line one\nline two", encoding="utf-8")
    placeholder = f"[Pasted text #1: 2 lines → {paste_file}]"
    expanded = shell._expand_paste_references(f"before {placeholder} after")
    assert expanded == "before line one\nline two after"
def test_open_external_editor_expands_paste_placeholders_before_open(tmp_path):
    """Paste placeholders are expanded into the buffer before the editor opens."""
    shell = _make_cli()
    paste_file = tmp_path / "paste.txt"
    paste_file.write_text("alpha\nbeta", encoding="utf-8")
    buf = _FakeBuffer(text=f"[Pasted text #1: 2 lines → {paste_file}]")
    assert shell._open_external_editor(buffer=buf) is True
    expected = "alpha\nbeta"
    assert buf.text == expected
    assert buf.cursor_position == len(expected)
    assert buf.calls == [False]
def test_open_external_editor_sets_skip_collapse_flag_during_expansion(tmp_path):
    """Expanding placeholders sets the skip-collapse flag for the text handler."""
    shell = _make_cli()
    paste_file = tmp_path / "paste.txt"
    paste_file.write_text("a\nb\nc\nd\ne\nf", encoding="utf-8")
    buf = _FakeBuffer(text=f"[Pasted text #1: 6 lines \u2192 {paste_file}]")
    # After expansion the flag should have been set (to prevent re-collapse)
    assert shell._open_external_editor(buffer=buf) is True
    # Flag is consumed by _on_text_changed, but since no handler is attached
    # in tests it stays True until the handler resets it.
    assert shell._skip_paste_collapse is True

View file

@ -147,6 +147,37 @@ class TestEscapedSpaces:
assert result["path"] == tmp_image_with_spaces
assert result["remainder"] == "what is this?"
def test_unquoted_spaces_in_path(self, tmp_image_with_spaces):
    """A bare path containing spaces is still detected as an image drop."""
    dropped = _detect_file_drop(str(tmp_image_with_spaces))
    assert dropped is not None
    assert dropped["path"] == tmp_image_with_spaces
    assert dropped["is_image"] is True
    assert dropped["remainder"] == ""
def test_unquoted_spaces_with_trailing_text(self, tmp_image_with_spaces):
    """Trailing prose after an unquoted spaced path ends up in remainder."""
    dropped = _detect_file_drop(f"{tmp_image_with_spaces} what is this?")
    assert dropped is not None
    assert dropped["path"] == tmp_image_with_spaces
    assert dropped["remainder"] == "what is this?"
def test_mixed_escaped_and_literal_spaces_in_path(self, tmp_path):
    """Paths mixing backslash-escaped and literal spaces still resolve."""
    img = tmp_path / "Screenshot 2026-04-21 at 1.04.43 PM.png"
    img.write_bytes(b"\x89PNG\r\n\x1a\n")
    # Escape some (not all) of the spaces, as shells often produce.
    mixed = str(img)
    for prefix in ("Screenshot ", "2026-04-21 ", "at "):
        mixed = mixed.replace(prefix, prefix[:-1] + "\\ ")
    dropped = _detect_file_drop(mixed)
    assert dropped is not None
    assert dropped["path"] == img
    assert dropped["is_image"] is True
    assert dropped["remainder"] == ""
def test_file_uri_image_path(self, tmp_image_with_spaces):
    """file:// URIs are decoded back to a filesystem path."""
    dropped = _detect_file_drop(tmp_image_with_spaces.as_uri())
    assert dropped is not None
    assert dropped["path"] == tmp_image_with_spaces
    assert dropped["is_image"] is True
def test_tilde_prefixed_path(self, tmp_path, monkeypatch):
home = tmp_path / "home"
img = home / "storage" / "shared" / "Pictures" / "cat.png"

View file

@ -0,0 +1,141 @@
from io import StringIO
from rich.console import Console
from rich.markdown import Markdown
from cli import _render_final_assistant_content
def _render_to_text(renderable) -> str:
    """Render *renderable* on a plain 80-column console and return its text."""
    sink = StringIO()
    console = Console(file=sink, width=80, force_terminal=False, color_system=None)
    console.print(renderable)
    return sink.getvalue()
def test_final_assistant_content_uses_markdown_renderable():
    """Default mode wraps the content in a rich Markdown renderable."""
    renderable = _render_final_assistant_content("# Title\n\n- one\n- two")
    assert isinstance(renderable, Markdown)
    output = _render_to_text(renderable)
    for fragment in ("Title", "one", "two"):
        assert fragment in output
def test_final_assistant_content_strips_ansi_before_markdown_rendering():
    """ANSI escapes are removed before the content hits the renderer."""
    output = _render_to_text(_render_final_assistant_content("\x1b[31m# Title\x1b[0m"))
    assert "Title" in output
    assert "\x1b" not in output
def test_final_assistant_content_can_strip_markdown_syntax():
    """Strip mode keeps the words but removes inline markdown markers."""
    renderable = _render_final_assistant_content(
        "***Bold italic***\n~~Strike~~\n- item\n# Title\n`code`",
        mode="strip",
    )
    output = _render_to_text(renderable)
    for kept in ("Bold italic", "Strike", "item", "Title", "code"):
        assert kept in output
    for marker in ("***", "~~", "`"):
        assert marker not in output
def test_strip_mode_preserves_lists():
    """Strip mode leaves bullet-list structure intact while removing bold."""
    renderable = _render_final_assistant_content(
        "**Formatting**\n- Ran prettier\n- Files changed\n- Verified clean",
        mode="strip",
    )
    output = _render_to_text(renderable)
    for bullet in ("- Ran prettier", "- Files changed", "- Verified clean"):
        assert bullet in output
    assert "**" not in output
def test_strip_mode_preserves_ordered_lists():
    """Strip mode keeps numbered-list markers."""
    renderable = _render_final_assistant_content(
        "1. First item\n2. Second item\n3. Third item",
        mode="strip",
    )
    output = _render_to_text(renderable)
    for item in ("1. First", "2. Second", "3. Third"):
        assert item in output
def test_strip_mode_preserves_blockquotes():
    """Strip mode keeps blockquote prefixes."""
    renderable = _render_final_assistant_content(
        "> This is quoted text\n> Another quoted line",
        mode="strip",
    )
    output = _render_to_text(renderable)
    for quoted in ("> This is quoted", "> Another quoted"):
        assert quoted in output
def test_strip_mode_preserves_checkboxes():
    """Strip mode keeps task-list checkbox syntax."""
    renderable = _render_final_assistant_content(
        "- [ ] Todo item\n- [x] Done item",
        mode="strip",
    )
    output = _render_to_text(renderable)
    for box in ("- [ ] Todo", "- [x] Done"):
        assert box in output
def test_strip_mode_preserves_table_structure_while_cleaning_cell_markdown():
    """Strip mode keeps pipe-table layout but cleans markdown inside cells."""
    renderable = _render_final_assistant_content(
        "| Syntax | Example |\n|---|---|\n| Bold | `**bold**` |\n| Strike | `~~strike~~` |",
        mode="strip",
    )
    output = _render_to_text(renderable)
    for row in (
        "| Syntax | Example |",
        "|---|---|",
        "| Bold | bold |",
        "| Strike | strike |",
    ):
        assert row in output
    for marker in ("**", "~~", "`"):
        assert marker not in output
def test_final_assistant_content_can_leave_markdown_raw():
    """Raw mode passes markdown through untouched."""
    output = _render_to_text(_render_final_assistant_content("***Bold italic***", mode="raw"))
    assert "***Bold italic***" in output
def test_strip_mode_preserves_intraword_underscores_in_snake_case_identifiers():
    """Underscores inside identifiers must not be eaten as emphasis markers."""
    renderable = _render_final_assistant_content(
        "Let me look at test_case_with_underscores and SOME_CONST "
        "then /tmp/snake_case_dir/file_with_name.py",
        mode="strip",
    )
    output = _render_to_text(renderable)
    for identifier in (
        "test_case_with_underscores",
        "SOME_CONST",
        "snake_case_dir",
        "file_with_name",
    ):
        assert identifier in output
def test_strip_mode_still_strips_boundary_underscore_emphasis():
    """Word-boundary underscore emphasis is still stripped in strip mode."""
    renderable = _render_final_assistant_content(
        "say _hi_ and __bold__ now",
        mode="strip",
    )
    assert "say hi and bold now" in _render_to_text(renderable)

View file

@ -207,48 +207,11 @@ def test_cli_turn_routing_uses_primary_when_disabled(monkeypatch):
shell.api_mode = "chat_completions"
shell.base_url = "https://openrouter.ai/api/v1"
shell.api_key = "sk-primary"
shell._smart_model_routing = {"enabled": False}
result = shell._resolve_turn_agent_config("what time is it in tokyo?")
assert result["model"] == "gpt-5"
assert result["runtime"]["provider"] == "openrouter"
assert result["label"] is None
def test_cli_turn_routing_uses_cheap_model_when_simple(monkeypatch):
    """A short/simple prompt routes to the configured cheap model+provider."""
    cli_mod = _import_cli()
    cheap_runtime = {
        "provider": "zai",
        "api_mode": "chat_completions",
        "base_url": "https://open.z.ai/api/v1",
        "api_key": "cheap-key",
        "source": "env/config",
    }

    def _runtime_resolve(**kwargs):
        # Routing must request the cheap provider, not the primary.
        assert kwargs["requested"] == "zai"
        return dict(cheap_runtime)

    monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
    shell = cli_mod.HermesCLI(model="anthropic/claude-sonnet-4", compact=True, max_turns=1)
    shell.provider = "openrouter"
    shell.api_mode = "chat_completions"
    shell.base_url = "https://openrouter.ai/api/v1"
    shell.api_key = "primary-key"
    shell._smart_model_routing = {
        "enabled": True,
        "cheap_model": {"provider": "zai", "model": "glm-5-air"},
        "max_simple_chars": 160,
        "max_simple_words": 28,
    }

    result = shell._resolve_turn_agent_config("what time is it in tokyo?")

    assert result["model"] == "glm-5-air"
    assert result["runtime"]["provider"] == "zai"
    assert result["runtime"]["api_key"] == "cheap-key"
    assert result["label"] is not None
def test_cli_prefers_config_provider_over_stale_env_override(monkeypatch):

View file

@ -0,0 +1,146 @@
"""Regression tests for classic-CLI mid-run /steer dispatch.
Background
----------
/steer sent while the agent is running used to be queued through
``self._pending_input`` alongside ordinary user input. ``process_loop``
pulls from that queue and calls ``process_command()`` — but while the
agent is running, ``process_loop`` is blocked inside ``self.chat()``.
By the time the queued /steer was pulled, ``_agent_running`` had
already flipped back to False, so ``process_command()`` took the idle
fallback (``"No agent running; queued as next turn"``) and delivered
the steer as an ordinary next-turn message.
The fix dispatches /steer inline on the UI thread when the agent is
running — matching the existing pattern for /model — so the steer
reaches ``agent.steer()`` (thread-safe) without touching the queue.
These tests exercise the detector + inline dispatch without starting a
prompt_toolkit app.
"""
from __future__ import annotations
import importlib
import sys
from unittest.mock import MagicMock, patch
def _make_cli():
    """Create a HermesCLI instance with prompt_toolkit stubbed out."""
    config = {
        "model": {
            "default": "anthropic/claude-opus-4.6",
            "base_url": "https://openrouter.ai/api/v1",
            "provider": "auto",
        },
        "display": {"compact": False, "tool_progress": "all"},
        "agent": {},
        "terminal": {"env_type": "local"},
    }
    env_overrides = {"LLM_MODEL": "", "HERMES_MAX_ITERATIONS": ""}
    stub_names = (
        "prompt_toolkit",
        "prompt_toolkit.history",
        "prompt_toolkit.styles",
        "prompt_toolkit.patch_stdout",
        "prompt_toolkit.application",
        "prompt_toolkit.layout",
        "prompt_toolkit.layout.processors",
        "prompt_toolkit.filters",
        "prompt_toolkit.layout.dimension",
        "prompt_toolkit.layout.menus",
        "prompt_toolkit.widgets",
        "prompt_toolkit.key_binding",
        "prompt_toolkit.completion",
        "prompt_toolkit.formatted_text",
        "prompt_toolkit.auto_suggest",
    )
    stubs = {name: MagicMock() for name in stub_names}
    with patch.dict(sys.modules, stubs), patch.dict(
        "os.environ", env_overrides, clear=False
    ):
        import cli as _cli_mod

        _cli_mod = importlib.reload(_cli_mod)
        with patch.object(_cli_mod, "get_tool_definitions", return_value=[]), patch.dict(
            _cli_mod.__dict__, {"CLI_CONFIG": config}
        ):
            return _cli_mod.HermesCLI()
class TestSteerInlineDetector:
    """_should_handle_steer_command_inline gates the busy-path fast dispatch."""

    def test_detects_steer_when_agent_running(self):
        shell = _make_cli()
        shell._agent_running = True
        assert shell._should_handle_steer_command_inline("/steer focus on error handling") is True

    def test_ignores_steer_when_agent_idle(self):
        """Idle-path /steer should fall through to the normal process_loop
        dispatch so the queue-style fallback message is emitted."""
        shell = _make_cli()
        shell._agent_running = False
        assert shell._should_handle_steer_command_inline("/steer do something") is False

    def test_ignores_non_slash_input(self):
        shell = _make_cli()
        shell._agent_running = True
        for text in ("steer without slash", ""):
            assert shell._should_handle_steer_command_inline(text) is False

    def test_ignores_other_slash_commands(self):
        shell = _make_cli()
        shell._agent_running = True
        for command in ("/queue hello", "/stop", "/help"):
            assert shell._should_handle_steer_command_inline(command) is False

    def test_ignores_steer_with_attached_images(self):
        """Image payloads take the normal path; steer doesn't accept images."""
        shell = _make_cli()
        shell._agent_running = True
        assert shell._should_handle_steer_command_inline("/steer text", has_images=True) is False
class TestSteerBusyPathDispatch:
    """When the detector fires, process_command('/steer ...') must call
    agent.steer() directly rather than the idle-path fallback."""

    def test_process_command_routes_to_agent_steer(self):
        """With _agent_running=True and agent.steer present, /steer reaches
        agent.steer(payload), NOT _pending_input."""
        shell = _make_cli()
        shell._agent_running = True
        shell.agent = MagicMock()
        shell.agent.steer = MagicMock(return_value=True)
        # Make sure the idle-path fallback would be observable if taken
        shell._pending_input = MagicMock()

        shell.process_command("/steer focus on errors")

        shell.agent.steer.assert_called_once_with("focus on errors")
        shell._pending_input.put.assert_not_called()

    def test_idle_path_queues_as_next_turn(self):
        """Control — when the agent is NOT running, /steer correctly falls
        back to next-turn queue semantics. Demonstrates why the fix was
        needed: the queue path only works when you can actually drain it."""
        shell = _make_cli()
        shell._agent_running = False
        shell.agent = MagicMock()
        shell.agent.steer = MagicMock(return_value=True)
        shell._pending_input = MagicMock()

        shell.process_command("/steer would-be-next-turn")

        # Idle path does NOT call agent.steer
        shell.agent.steer.assert_not_called()
        # It puts the payload in the queue as a normal next-turn message
        shell._pending_input.put.assert_called_once_with("would-be-next-turn")
if __name__ == "__main__":  # pragma: no cover
    # Propagate pytest's exit status: a bare pytest.main() call discards the
    # return code, so direct invocation would exit 0 even on test failures.
    import pytest

    raise SystemExit(pytest.main([__file__, "-v"]))

View file

@ -0,0 +1,92 @@
import importlib
import os
import sys
from unittest.mock import MagicMock, patch
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
_cli_mod = None
def _make_cli(user_message_preview=None):
    """Create a HermesCLI with prompt_toolkit stubbed and a controlled
    user_message_preview display config (defaults to 2 first / 2 last lines).

    Also publishes the reloaded cli module through the module-global
    ``_cli_mod`` for tests that need to reach into it.
    """
    global _cli_mod
    preview_cfg = user_message_preview or {"first_lines": 2, "last_lines": 2}
    config = {
        "model": {
            "default": "anthropic/claude-opus-4.6",
            "base_url": "https://openrouter.ai/api/v1",
            "provider": "auto",
        },
        "display": {
            "compact": False,
            "tool_progress": "all",
            "user_message_preview": preview_cfg,
        },
        "agent": {},
        "terminal": {"env_type": "local"},
    }
    env_overrides = {"LLM_MODEL": "", "HERMES_MAX_ITERATIONS": ""}
    stub_names = (
        "prompt_toolkit",
        "prompt_toolkit.history",
        "prompt_toolkit.styles",
        "prompt_toolkit.patch_stdout",
        "prompt_toolkit.application",
        "prompt_toolkit.layout",
        "prompt_toolkit.layout.processors",
        "prompt_toolkit.filters",
        "prompt_toolkit.layout.dimension",
        "prompt_toolkit.layout.menus",
        "prompt_toolkit.widgets",
        "prompt_toolkit.key_binding",
        "prompt_toolkit.completion",
        "prompt_toolkit.formatted_text",
        "prompt_toolkit.auto_suggest",
    )
    stubs = {name: MagicMock() for name in stub_names}
    with patch.dict(sys.modules, stubs), patch.dict("os.environ", env_overrides, clear=False):
        import cli as mod

        mod = importlib.reload(mod)
        _cli_mod = mod
        with patch.object(mod, "get_tool_definitions", return_value=[]), patch.dict(
            mod.__dict__, {"CLI_CONFIG": config}
        ):
            return mod.HermesCLI()
class TestSubmittedUserMessagePreview:
    """Preview formatting for submitted multi-line user messages."""

    def test_default_preview_shows_first_two_lines_and_last_two_lines(self):
        cli = _make_cli()
        rendered = cli._format_submitted_user_message_preview(
            "line1\nline2\nline3\nline4\nline5\nline6"
        )
        for visible in ("line1", "line2", "line5", "line6"):
            assert visible in rendered
        for hidden in ("line3", "line4"):
            assert hidden not in rendered
        assert "(+2 more lines)" in rendered

    def test_preview_can_hide_last_lines(self):
        cli = _make_cli({"first_lines": 2, "last_lines": 0})
        rendered = cli._format_submitted_user_message_preview(
            "line1\nline2\nline3\nline4\nline5\nline6"
        )
        for visible in ("line1", "line2"):
            assert visible in rendered
        for hidden in ("line5", "line6"):
            assert hidden not in rendered
        assert "(+4 more lines)" in rendered

    def test_invalid_first_lines_value_falls_back_to_one(self):
        cli = _make_cli({"first_lines": 0, "last_lines": 2})
        rendered = cli._format_submitted_user_message_preview("line1\nline2\nline3\nline4")
        for visible in ("line1", "line3", "line4"):
            assert visible in rendered
        assert "(+1 more line)" in rendered

View file

@ -183,27 +183,10 @@ class TestFastModeRouting(unittest.TestCase):
acp_command=None,
acp_args=[],
_credential_pool=None,
_smart_model_routing={},
service_tier="priority",
)
original_runtime = {
"api_key": "***",
"base_url": "https://openrouter.ai/api/v1",
"provider": "openrouter",
"api_mode": "chat_completions",
"command": None,
"args": [],
"credential_pool": None,
}
with patch("agent.smart_model_routing.resolve_turn_route", return_value={
"model": "gpt-5.4",
"runtime": dict(original_runtime),
"label": None,
"signature": ("gpt-5.4", "openrouter", "https://openrouter.ai/api/v1", "chat_completions", None, ()),
}):
route = cli_mod.HermesCLI._resolve_turn_agent_config(stub, "hi")
route = cli_mod.HermesCLI._resolve_turn_agent_config(stub, "hi")
# Provider should NOT have changed
assert route["runtime"]["provider"] == "openrouter"
@ -222,26 +205,10 @@ class TestFastModeRouting(unittest.TestCase):
acp_command=None,
acp_args=[],
_credential_pool=None,
_smart_model_routing={},
service_tier="priority",
)
primary_route = {
"model": "gpt-5.3-codex",
"runtime": {
"api_key": "***",
"base_url": "https://openrouter.ai/api/v1",
"provider": "openrouter",
"api_mode": "chat_completions",
"command": None,
"args": [],
"credential_pool": None,
},
"label": None,
"signature": ("gpt-5.3-codex", "openrouter", "https://openrouter.ai/api/v1", "chat_completions", None, ()),
}
with patch("agent.smart_model_routing.resolve_turn_route", return_value=primary_route):
route = cli_mod.HermesCLI._resolve_turn_agent_config(stub, "hi")
route = cli_mod.HermesCLI._resolve_turn_agent_config(stub, "hi")
assert route["runtime"]["provider"] == "openrouter"
assert route.get("request_overrides") is None
@ -329,27 +296,10 @@ class TestAnthropicFastMode(unittest.TestCase):
acp_command=None,
acp_args=[],
_credential_pool=None,
_smart_model_routing={},
service_tier="priority",
)
original_runtime = {
"api_key": "***",
"base_url": "https://api.anthropic.com",
"provider": "anthropic",
"api_mode": "anthropic_messages",
"command": None,
"args": [],
"credential_pool": None,
}
with patch("agent.smart_model_routing.resolve_turn_route", return_value={
"model": "claude-opus-4-6",
"runtime": dict(original_runtime),
"label": None,
"signature": ("claude-opus-4-6", "anthropic", "https://api.anthropic.com", "anthropic_messages", None, ()),
}):
route = cli_mod.HermesCLI._resolve_turn_agent_config(stub, "hi")
route = cli_mod.HermesCLI._resolve_turn_agent_config(stub, "hi")
assert route["runtime"]["provider"] == "anthropic"
assert route["request_overrides"] == {"speed": "fast"}

View file

@ -0,0 +1,21 @@
from unittest.mock import MagicMock, patch
def test_gquota_uses_chat_console_when_tui_is_live():
    """With a live TUI app, /gquota output goes to ChatConsole, not cli.console."""
    from agent.google_oauth import GoogleOAuthError
    from cli import HermesCLI

    cli = HermesCLI.__new__(HermesCLI)
    cli.console = MagicMock()
    cli._app = object()
    live_console = MagicMock()
    oauth_error = GoogleOAuthError("No Google OAuth credentials found")

    with patch("cli.ChatConsole", return_value=live_console):
        with patch("agent.google_oauth.get_valid_access_token", side_effect=oauth_error):
            with patch("agent.google_oauth.load_credentials", return_value=None):
                with patch("agent.google_code_assist.retrieve_user_quota"):
                    cli._handle_gquota_command("/gquota")

    assert live_console.print.call_count == 2
    cli.console.print.assert_not_called()

View file

@ -21,6 +21,7 @@ def test_manual_compress_reports_noop_without_success_banner(capsys):
shell.agent = MagicMock()
shell.agent.compression_enabled = True
shell.agent._cached_system_prompt = ""
shell.agent.session_id = shell.session_id # no-op compression: no split
shell.agent._compress_context.return_value = (list(history), "")
def _estimate(messages):
@ -48,6 +49,7 @@ def test_manual_compress_explains_when_token_estimate_rises(capsys):
shell.agent = MagicMock()
shell.agent.compression_enabled = True
shell.agent._cached_system_prompt = ""
shell.agent.session_id = shell.session_id # no-op: no split
shell.agent._compress_context.return_value = (compressed, "")
def _estimate(messages):
@ -64,3 +66,64 @@ def test_manual_compress_explains_when_token_estimate_rises(capsys):
assert "✅ Compressed: 4 → 3 messages" in output
assert "Rough transcript estimate: ~100 → ~120 tokens" in output
assert "denser summaries" in output
def test_manual_compress_syncs_session_id_after_split():
    """Regression for cli.session_id desync after /compress.

    _compress_context ends the parent session and creates a new child session,
    mutating agent.session_id. Without syncing, cli.session_id still points
    at the ended parent — causing /status, /resume, exit summary, and the
    next end_session() call (e.g. from /resume <id>) to target the wrong row.
    """
    shell = _make_cli()
    history = _make_history()
    parent_id = shell.session_id
    child_id = "20260101_000000_child1"
    compressed_history = [
        {"role": "user", "content": "[summary]"},
        history[-1],
    ]
    shell.conversation_history = history
    shell.agent = MagicMock()
    shell.agent.compression_enabled = True
    shell.agent._cached_system_prompt = ""

    # Simulate _compress_context mutating agent.session_id as a side effect.
    def _compress_with_split(*_args, **_kwargs):
        shell.agent.session_id = child_id
        return (compressed_history, "")

    shell.agent._compress_context.side_effect = _compress_with_split
    shell.agent.session_id = parent_id  # starts in sync
    shell._pending_title = "stale title"

    with patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=100):
        shell._manual_compress()

    # CLI session_id must now point at the continuation child, not the parent.
    assert shell.session_id == child_id
    assert shell.session_id != parent_id
    # Pending title must be cleared — titles belong to the parent lineage and
    # get regenerated for the continuation.
    assert shell._pending_title is None
def test_manual_compress_no_sync_when_session_id_unchanged():
    """If compression is a no-op (agent.session_id didn't change), the CLI
    must NOT clear _pending_title or otherwise disturb session state.
    """
    shell = _make_cli()
    history = _make_history()
    shell.conversation_history = history
    shell.agent = MagicMock()
    shell.agent.compression_enabled = True
    shell.agent._cached_system_prompt = ""
    shell.agent.session_id = shell.session_id
    # No-op compression: same history comes back, no session split occurs.
    shell.agent._compress_context.return_value = (list(history), "")
    shell._pending_title = "keep me"

    with patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=100):
        shell._manual_compress()

    # No split → pending title untouched.
    assert shell._pending_title == "keep me"

View file

@ -33,6 +33,20 @@ class TestCLIQuickCommands:
printed = self._printed_plain(cli.console.print.call_args[0][0])
assert printed == "daily-note"
def test_exec_command_uses_chat_console_when_tui_is_live(self):
    """With a live TUI app, exec-command output goes through ChatConsole."""
    cli = self._make_cli({"dn": {"type": "exec", "command": "echo daily-note"}})
    cli._app = object()
    live_console = MagicMock()
    with patch("cli.ChatConsole", return_value=live_console):
        assert cli.process_command("/dn") is True
    live_console.print.assert_called_once()
    assert self._printed_plain(live_console.print.call_args[0][0]) == "daily-note"
    # The regular (non-live) console must stay untouched.
    cli.console.print.assert_not_called()
def test_exec_command_stderr_shown_on_no_stdout(self):
cli = self._make_cli({"err": {"type": "exec", "command": "echo error >&2"}})
result = cli.process_command("/err")

View file

@ -186,6 +186,31 @@ _HERMES_BEHAVIORAL_VARS = frozenset({
"HERMES_HOME_MODE",
"BROWSER_CDP_URL",
"CAMOFOX_URL",
# Platform allowlists — not credentials, but if set from any source
# (user shell, earlier leaky test, CI env), they change gateway auth
# behavior and flake button-authorization tests.
"TELEGRAM_ALLOWED_USERS",
"DISCORD_ALLOWED_USERS",
"WHATSAPP_ALLOWED_USERS",
"SLACK_ALLOWED_USERS",
"SIGNAL_ALLOWED_USERS",
"SIGNAL_GROUP_ALLOWED_USERS",
"EMAIL_ALLOWED_USERS",
"SMS_ALLOWED_USERS",
"MATTERMOST_ALLOWED_USERS",
"MATRIX_ALLOWED_USERS",
"DINGTALK_ALLOWED_USERS",
"FEISHU_ALLOWED_USERS",
"WECOM_ALLOWED_USERS",
"GATEWAY_ALLOWED_USERS",
"GATEWAY_ALLOW_ALL_USERS",
"TELEGRAM_ALLOW_ALL_USERS",
"DISCORD_ALLOW_ALL_USERS",
"WHATSAPP_ALLOW_ALL_USERS",
"SLACK_ALLOW_ALL_USERS",
"SIGNAL_ALLOW_ALL_USERS",
"EMAIL_ALLOW_ALL_USERS",
"SMS_ALLOW_ALL_USERS",
})
@ -258,6 +283,107 @@ def _isolate_hermes_home(_hermetic_environment):
return None
# ── Module-level state reset ───────────────────────────────────────────────
#
# Python modules are singletons per process, and pytest-xdist workers are
# long-lived. Module-level dicts/sets (tool registries, approval state,
# interrupt flags) and ContextVars persist across tests in the same worker,
# causing tests that pass alone to fail when run with siblings.
#
# Each entry in this fixture clears state that belongs to a specific module.
# New state buckets go here too — this is the single gate that prevents
# "works alone, flakes in CI" bugs from state leakage.
#
# The skill `test-suite-cascade-diagnosis` documents the concrete patterns
# this closes; the running example was `test_command_guards` failing 12/15
# CI runs because ``tools.approval._session_approved`` carried approvals
# from one test's session into another's.
@pytest.fixture(autouse=True)
def _reset_module_state():
    """Clear module-level mutable state and ContextVars between tests.

    Keeps state from leaking across tests on the same xdist worker. Modules
    that don't exist yet (test collection before production import) are
    skipped silently — production import later creates fresh empty state.

    Each ``try`` block below resets one module's private state; the broad
    ``except Exception: pass`` is deliberate, since an unimportable module
    simply has no state to clear yet.
    """
    # --- tools.approval — the single biggest source of cross-test pollution ---
    try:
        from tools import approval as _approval_mod

        _approval_mod._session_approved.clear()
        _approval_mod._session_yolo.clear()
        _approval_mod._permanent_approved.clear()
        _approval_mod._pending.clear()
        _approval_mod._gateway_queues.clear()
        _approval_mod._gateway_notify_cbs.clear()
        # ContextVar: reset to empty string so get_current_session_key()
        # falls through to the env var / default path, matching a fresh
        # process.
        _approval_mod._approval_session_key.set("")
    except Exception:
        pass
    # --- tools.interrupt — per-thread interrupt flag set ---
    try:
        from tools import interrupt as _interrupt_mod

        with _interrupt_mod._lock:
            _interrupt_mod._interrupted_threads.clear()
    except Exception:
        pass
    # --- gateway.session_context — the ContextVars that represent
    # the active gateway session. If set in one test and not reset,
    # the next test's get_session_env() reads stale values.
    try:
        from gateway import session_context as _sc_mod

        for _cv in (
            _sc_mod._SESSION_PLATFORM,
            _sc_mod._SESSION_CHAT_ID,
            _sc_mod._SESSION_CHAT_NAME,
            _sc_mod._SESSION_THREAD_ID,
            _sc_mod._SESSION_USER_ID,
            _sc_mod._SESSION_USER_NAME,
            _sc_mod._SESSION_KEY,
            _sc_mod._CRON_AUTO_DELIVER_PLATFORM,
            _sc_mod._CRON_AUTO_DELIVER_CHAT_ID,
            _sc_mod._CRON_AUTO_DELIVER_THREAD_ID,
        ):
            _cv.set(_sc_mod._UNSET)
    except Exception:
        pass
    # --- tools.env_passthrough — ContextVar<set[str]> with no default ---
    # LookupError is normal if the test never set it. Setting it to an
    # empty set unconditionally normalizes the starting state.
    try:
        from tools import env_passthrough as _envp_mod

        _envp_mod._allowed_env_vars_var.set(set())
    except Exception:
        pass
    # --- tools.credential_files — ContextVar<dict> ---
    try:
        from tools import credential_files as _credf_mod

        _credf_mod._registered_files_var.set({})
    except Exception:
        pass
    # --- tools.file_tools — per-task read history + file-ops cache ---
    # _read_tracker accumulates per-task_id read history for loop detection,
    # capped by _READ_HISTORY_CAP. If entries from a prior test persist, the
    # cap is hit faster than expected and capacity-related tests flake.
    try:
        from tools import file_tools as _ft_mod

        with _ft_mod._read_tracker_lock:
            _ft_mod._read_tracker.clear()
        with _ft_mod._file_ops_lock:
            _ft_mod._file_ops_cache.clear()
    except Exception:
        pass
    yield
@pytest.fixture()
def tmp_dir(tmp_path):
"""Provide a temporary directory that is cleaned up automatically."""

View file

@ -152,7 +152,6 @@ def test_gateway_run_agent_codex_path_handles_internal_401_refresh(monkeypatch):
runner._provider_routing = {}
runner._fallback_model = None
runner._running_agents = {}
runner._smart_model_routing = {}
from unittest.mock import MagicMock, AsyncMock
runner.hooks = MagicMock()
runner.hooks.emit = AsyncMock()

View file

@ -772,9 +772,10 @@ class TestRunJobSessionPersistence:
pass
def run_conversation(self, *args, **kwargs):
seen["platform"] = os.getenv("HERMES_CRON_AUTO_DELIVER_PLATFORM")
seen["chat_id"] = os.getenv("HERMES_CRON_AUTO_DELIVER_CHAT_ID")
seen["thread_id"] = os.getenv("HERMES_CRON_AUTO_DELIVER_THREAD_ID")
from gateway.session_context import get_session_env
seen["platform"] = get_session_env("HERMES_CRON_AUTO_DELIVER_PLATFORM") or None
seen["chat_id"] = get_session_env("HERMES_CRON_AUTO_DELIVER_CHAT_ID") or None
seen["thread_id"] = get_session_env("HERMES_CRON_AUTO_DELIVER_THREAD_ID") or None
return {"final_response": "ok"}
with patch("cron.scheduler._hermes_home", tmp_path), \
@ -1024,7 +1025,7 @@ class TestRunJobSkillBacked:
"id": "multi-skill-job",
"name": "multi skill test",
"prompt": "Combine the results.",
"skills": ["blogwatcher", "find-nearby"],
"skills": ["blogwatcher", "maps"],
}
fake_db = MagicMock()
@ -1057,12 +1058,12 @@ class TestRunJobSkillBacked:
assert error is None
assert final_response == "ok"
assert skill_view_mock.call_count == 2
assert [call.args[0] for call in skill_view_mock.call_args_list] == ["blogwatcher", "find-nearby"]
assert [call.args[0] for call in skill_view_mock.call_args_list] == ["blogwatcher", "maps"]
prompt_arg = mock_agent.run_conversation.call_args.args[0]
assert prompt_arg.index("blogwatcher") < prompt_arg.index("find-nearby")
assert prompt_arg.index("blogwatcher") < prompt_arg.index("maps")
assert "Instructions for blogwatcher." in prompt_arg
assert "Instructions for find-nearby." in prompt_arg
assert "Instructions for maps." in prompt_arg
assert "Combine the results." in prompt_arg
@ -1175,6 +1176,204 @@ class TestBuildJobPromptSilentHint:
assert system_pos < prompt_pos
class TestParseWakeGate:
    """Unit tests for _parse_wake_gate — pure function, no side effects."""

    def test_empty_output_wakes(self):
        from cron.scheduler import _parse_wake_gate
        # No script output at all means the agent must wake.
        for empty_value in ("", None):
            assert _parse_wake_gate(empty_value) is True

    def test_whitespace_only_wakes(self):
        from cron.scheduler import _parse_wake_gate
        blank_output = " \n\n \t\n"
        assert _parse_wake_gate(blank_output) is True

    def test_non_json_last_line_wakes(self):
        from cron.scheduler import _parse_wake_gate
        for plain_text in ("hello world", "line 1\nline 2\nplain text"):
            assert _parse_wake_gate(plain_text) is True

    def test_json_non_dict_wakes(self):
        """Bare arrays, numbers, strings must not be interpreted as a gate."""
        from cron.scheduler import _parse_wake_gate
        for non_dict_json in ("[1, 2, 3]", "42", '"wakeAgent"'):
            assert _parse_wake_gate(non_dict_json) is True

    def test_wake_gate_false_skips(self):
        from cron.scheduler import _parse_wake_gate
        gate_line = '{"wakeAgent": false}'
        assert _parse_wake_gate(gate_line) is False

    def test_wake_gate_true_wakes(self):
        from cron.scheduler import _parse_wake_gate
        gate_line = '{"wakeAgent": true}'
        assert _parse_wake_gate(gate_line) is True

    def test_wake_gate_missing_wakes(self):
        """A JSON dict without a wakeAgent key defaults to waking."""
        from cron.scheduler import _parse_wake_gate
        gateless = '{"data": {"foo": "bar"}}'
        assert _parse_wake_gate(gateless) is True

    def test_non_boolean_false_still_wakes(self):
        """Only strict ``False`` skips — truthy/falsy shortcuts are too risky."""
        from cron.scheduler import _parse_wake_gate
        falsy_variants = ('{"wakeAgent": 0}', '{"wakeAgent": null}', '{"wakeAgent": ""}')
        for variant in falsy_variants:
            assert _parse_wake_gate(variant) is True

    def test_only_last_non_empty_line_parsed(self):
        from cron.scheduler import _parse_wake_gate
        output = "\n".join(["some log output", "more output", '{"wakeAgent": false}'])
        assert _parse_wake_gate(output) is False

    def test_trailing_blank_lines_ignored(self):
        from cron.scheduler import _parse_wake_gate
        output = '{"wakeAgent": false}' + "\n\n\n"
        assert _parse_wake_gate(output) is False

    def test_non_last_json_line_does_not_gate(self):
        """A JSON gate on an earlier line with plain text after it does NOT trigger."""
        from cron.scheduler import _parse_wake_gate
        output = '{"wakeAgent": false}\nactually this is the real output'
        assert _parse_wake_gate(output) is True
class TestRunJobWakeGate:
    """Integration tests for run_job wake-gate short-circuit."""
    @pytest.fixture(autouse=True)
    def _stub_runtime_provider(self):
        """Stub ``resolve_runtime_provider`` for wake-gate tests.

        ``run_job`` resolves the runtime provider BEFORE constructing
        ``AIAgent``, so these tests must mock ``resolve_runtime_provider``
        in addition to ``AIAgent`` — otherwise in a hermetic CI env (no
        API keys), the resolver raises and the test fails before the
        patched AIAgent is ever reached.
        """
        # Shape mirrors what run_job consumes from the real resolver.
        fake_runtime = {
            "provider": "openrouter",
            "api_mode": "chat_completions",
            "base_url": "https://openrouter.ai/api/v1",
            "api_key": "test-key",
            "source": "stub",
            "requested_provider": None,
        }
        with patch(
            "hermes_cli.runtime_provider.resolve_runtime_provider",
            return_value=fake_runtime,
        ):
            yield
    def _make_job(self, name="wake-gate-test", script="check.py"):
        """Minimal valid cron job dict for run_job."""
        return {
            "id": f"job_{name}",
            "name": name,
            "prompt": "Do a thing",
            "schedule": "*/5 * * * *",
            "script": script,
        }
    def test_wake_false_skips_agent_and_returns_silent(self, caplog):
        """When _run_job_script output ends with {wakeAgent: false}, the agent
        is not invoked and run_job returns the SILENT marker so delivery is
        suppressed."""
        from cron.scheduler import SILENT_MARKER
        import cron.scheduler as scheduler
        # run_job returns a (success, doc, final_response, error) 4-tuple.
        with patch.object(scheduler, "_run_job_script",
                          return_value=(True, '{"wakeAgent": false}')), \
             patch("run_agent.AIAgent") as agent_cls:
            success, doc, final, err = scheduler.run_job(self._make_job())
        assert success is True
        assert err is None
        assert final == SILENT_MARKER
        assert "Script gate returned `wakeAgent=false`" in doc
        agent_cls.assert_not_called()
    def test_wake_true_runs_agent_with_injected_output(self):
        """When the script returns {wakeAgent: true, data: ...}, the agent is
        invoked and the data line still shows up in the prompt."""
        import cron.scheduler as scheduler
        script_output = '{"wakeAgent": true, "data": {"new": 3}}'
        agent = MagicMock()
        agent.run_conversation = MagicMock(return_value={
            "final_response": "ok", "messages": []
        })
        with patch.object(scheduler, "_run_job_script",
                          return_value=(True, script_output)), \
             patch("run_agent.AIAgent", return_value=agent) as agent_cls:
            success, doc, final, err = scheduler.run_job(self._make_job())
        agent_cls.assert_called_once()
        # The script output should be visible in the prompt passed to
        # run_conversation. The prompt may be positional or passed as the
        # user_message keyword, so check both.
        call_kwargs = agent.run_conversation.call_args
        prompt_arg = call_kwargs.args[0] if call_kwargs.args else call_kwargs.kwargs.get("user_message", "")
        assert script_output in prompt_arg
        assert success is True
        assert err is None
    def test_script_runs_only_once_on_wake(self):
        """Wake-true path must not re-run the script inside _build_job_prompt
        (script would execute twice otherwise, wasting work and risking
        double-side-effects)."""
        import cron.scheduler as scheduler
        call_count = 0
        def _script_stub(path):
            # Counts invocations; returns the same (ran_ok, output) shape
            # as the real _run_job_script.
            nonlocal call_count
            call_count += 1
            return (True, "regular output")
        agent = MagicMock()
        agent.run_conversation = MagicMock(return_value={
            "final_response": "ok", "messages": []
        })
        with patch.object(scheduler, "_run_job_script", side_effect=_script_stub), \
             patch("run_agent.AIAgent", return_value=agent):
            scheduler.run_job(self._make_job())
        assert call_count == 1, f"script ran {call_count}x, expected exactly 1"
    def test_script_failure_does_not_trigger_gate(self):
        """If _run_job_script returns success=False, the gate is NOT evaluated
        and the agent still runs (the failure is reported as context)."""
        import cron.scheduler as scheduler
        # Malicious or broken script whose stderr happens to contain the
        # gate JSON — we must NOT honor it because ran_ok is False.
        agent = MagicMock()
        agent.run_conversation = MagicMock(return_value={
            "final_response": "ok", "messages": []
        })
        with patch.object(scheduler, "_run_job_script",
                          return_value=(False, '{"wakeAgent": false}')), \
             patch("run_agent.AIAgent", return_value=agent) as agent_cls:
            success, doc, final, err = scheduler.run_job(self._make_job())
        agent_cls.assert_called_once()  # Agent DID wake despite the gate-like text
    def test_no_script_path_runs_agent_normally(self):
        """Regression: jobs without a script still work."""
        import cron.scheduler as scheduler
        agent = MagicMock()
        agent.run_conversation = MagicMock(return_value={
            "final_response": "ok", "messages": []
        })
        # Remove the script key entirely — the default from _make_job must go.
        job = self._make_job(script=None)
        job.pop("script", None)
        with patch.object(scheduler, "_run_job_script") as script_fn, \
             patch("run_agent.AIAgent", return_value=agent) as agent_cls:
            scheduler.run_job(job)
        script_fn.assert_not_called()
        agent_cls.assert_called_once()
class TestBuildJobPromptMissingSkill:
"""Verify that a missing skill logs a warning and does not crash the job."""
@ -1259,3 +1458,250 @@ class TestSendMediaViaAdapter:
self._run_with_loop(adapter, "123", media_files, None, {"id": "j3"})
adapter.send_voice.assert_called_once()
adapter.send_image_file.assert_called_once()
class TestParallelTick:
    """Verify that tick() runs due jobs concurrently and isolates ContextVars."""
    @pytest.fixture(autouse=True)
    def _isolate_tick_lock(self, tmp_path):
        """Point the tick file lock at a per-test temp dir to avoid xdist contention."""
        lock_dir = tmp_path / "cron"
        lock_dir.mkdir()
        with patch("cron.scheduler._LOCK_DIR", lock_dir), \
             patch("cron.scheduler._LOCK_FILE", lock_dir / ".tick.lock"):
            yield
    def test_parallel_jobs_run_concurrently(self):
        """Two jobs launched in the same tick should overlap in time."""
        import threading
        import time
        # Barrier of 2: both job threads must be inside run_job at once,
        # or wait() raises BrokenBarrierError after the 5s timeout.
        barrier = threading.Barrier(2, timeout=5)
        call_order = []
        def mock_run_job(job):
            """Each job hits a barrier — both must be active simultaneously."""
            call_order.append(("start", job["id"]))
            barrier.wait()  # blocks until both threads reach here
            call_order.append(("end", job["id"]))
            # Same (success, doc, final_response, error) shape as run_job.
            return (True, "output", "response", None)
        jobs = [
            {"id": "job-a", "name": "a", "deliver": "local"},
            {"id": "job-b", "name": "b", "deliver": "local"},
        ]
        with patch("cron.scheduler.get_due_jobs", return_value=jobs), \
             patch("cron.scheduler.advance_next_run"), \
             patch("cron.scheduler.run_job", side_effect=mock_run_job), \
             patch("cron.scheduler.save_job_output", return_value="/tmp/out.md"), \
             patch("cron.scheduler._deliver_result", return_value=None), \
             patch("cron.scheduler.mark_job_run"):
            from cron.scheduler import tick
            result = tick(verbose=False)
        assert result == 2
        # Both starts happened before both ends — proof of concurrency
        starts = [i for i, (action, _) in enumerate(call_order) if action == "start"]
        ends = [i for i, (action, _) in enumerate(call_order) if action == "end"]
        assert len(starts) == 2
        assert len(ends) == 2
        assert max(starts) < min(ends), f"Jobs not concurrent: {call_order}"
    def test_parallel_jobs_isolated_contextvars(self):
        """Each job's ContextVars must be isolated — no cross-contamination."""
        from gateway.session_context import get_session_env
        seen = {}
        def mock_run_job(job):
            # run_job sets ContextVars — verify each job sees its own even
            # while the other thread is doing the same concurrently.
            origin = job.get("origin", {})
            from gateway.session_context import set_session_vars, clear_session_vars
            tokens = set_session_vars(
                platform=origin.get("platform", ""),
                chat_id=str(origin.get("chat_id", "")),
            )
            import time
            time.sleep(0.05)  # give other thread time to set its vars
            platform = get_session_env("HERMES_SESSION_PLATFORM")
            chat_id = get_session_env("HERMES_SESSION_CHAT_ID")
            seen[job["id"]] = {"platform": platform, "chat_id": chat_id}
            clear_session_vars(tokens)
            return (True, "output", "response", None)
        jobs = [
            {"id": "tg-job", "name": "tg", "deliver": "local",
             "origin": {"platform": "telegram", "chat_id": "111"}},
            {"id": "dc-job", "name": "dc", "deliver": "local",
             "origin": {"platform": "discord", "chat_id": "222"}},
        ]
        with patch("cron.scheduler.get_due_jobs", return_value=jobs), \
             patch("cron.scheduler.advance_next_run"), \
             patch("cron.scheduler.run_job", side_effect=mock_run_job), \
             patch("cron.scheduler.save_job_output", return_value="/tmp/out.md"), \
             patch("cron.scheduler._deliver_result", return_value=None), \
             patch("cron.scheduler.mark_job_run"):
            from cron.scheduler import tick
            tick(verbose=False)
        # Each job observed exactly its own origin — no bleed-through.
        assert seen["tg-job"] == {"platform": "telegram", "chat_id": "111"}
        assert seen["dc-job"] == {"platform": "discord", "chat_id": "222"}
    def test_max_parallel_env_var(self, monkeypatch):
        """HERMES_CRON_MAX_PARALLEL=1 should restore serial behaviour."""
        monkeypatch.setenv("HERMES_CRON_MAX_PARALLEL", "1")
        call_times = []
        def mock_run_job(job):
            # Record wall-clock start/end so ordering can be checked below.
            import time
            call_times.append(("start", job["id"], time.monotonic()))
            time.sleep(0.05)
            call_times.append(("end", job["id"], time.monotonic()))
            return (True, "output", "response", None)
        jobs = [
            {"id": "s1", "name": "s1", "deliver": "local"},
            {"id": "s2", "name": "s2", "deliver": "local"},
        ]
        with patch("cron.scheduler.get_due_jobs", return_value=jobs), \
             patch("cron.scheduler.advance_next_run"), \
             patch("cron.scheduler.run_job", side_effect=mock_run_job), \
             patch("cron.scheduler.save_job_output", return_value="/tmp/out.md"), \
             patch("cron.scheduler._deliver_result", return_value=None), \
             patch("cron.scheduler.mark_job_run"):
            from cron.scheduler import tick
            result = tick(verbose=False)
        assert result == 2
        # With max_workers=1, second job starts after first ends
        end_s1 = [t for action, jid, t in call_times if action == "end" and jid == "s1"][0]
        start_s2 = [t for action, jid, t in call_times if action == "start" and jid == "s2"][0]
        assert start_s2 >= end_s1, "Jobs ran concurrently despite max_parallel=1"
class TestDeliverResultTimeoutCancelsFuture:
    """When future.result(timeout=60) raises TimeoutError in the live
    adapter delivery path, _deliver_result must cancel the orphan
    coroutine so it cannot duplicate-send after the standalone fallback.
    """
    def test_live_adapter_timeout_cancels_future_and_falls_back(self):
        """End-to-end: live adapter hangs past the 60s budget, _deliver_result
        patches the timeout down to a fast value, confirms future.cancel() fires,
        and verifies the standalone fallback path still delivers."""
        from gateway.config import Platform
        from concurrent.futures import Future
        # Live adapter whose send() coroutine never resolves within the budget
        adapter = AsyncMock()
        adapter.send.return_value = MagicMock(success=True)
        pconfig = MagicMock()
        pconfig.enabled = True
        mock_cfg = MagicMock()
        mock_cfg.platforms = {Platform.TELEGRAM: pconfig}
        loop = MagicMock()
        loop.is_running.return_value = True
        # A real concurrent.futures.Future so .cancel() has real semantics,
        # but we override .result() to raise TimeoutError exactly like the
        # 60s wait firing in production.
        captured_future = Future()
        cancel_calls = []
        original_cancel = captured_future.cancel
        def tracking_cancel():
            # Record the call, then delegate to the real cancel().
            cancel_calls.append(True)
            return original_cancel()
        captured_future.cancel = tracking_cancel
        captured_future.result = MagicMock(side_effect=TimeoutError("timed out"))
        def fake_run_coro(coro, _loop):
            # Close the coroutine (avoids "never awaited" warnings) and hand
            # back the instrumented future instead of scheduling anything.
            coro.close()
            return captured_future
        job = {
            "id": "timeout-job",
            "deliver": "origin",
            "origin": {"platform": "telegram", "chat_id": "123"},
        }
        standalone_send = AsyncMock(return_value={"success": True})
        with patch("gateway.config.load_gateway_config", return_value=mock_cfg), \
             patch("cron.scheduler.load_config", return_value={"cron": {"wrap_response": False}}), \
             patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro), \
             patch("tools.send_message_tool._send_to_platform", new=standalone_send):
            result = _deliver_result(
                job,
                "Hello world",
                adapters={Platform.TELEGRAM: adapter},
                loop=loop,
            )
        # 1. The orphan future was cancelled on timeout (the bug fix)
        assert cancel_calls == [True], "future.cancel() must fire on TimeoutError"
        # 2. The standalone fallback delivered — no double send, no silent drop
        assert result is None, f"expected successful delivery, got error: {result!r}"
        standalone_send.assert_awaited_once()
class TestSendMediaTimeoutCancelsFuture:
    """Same orphan-coroutine guarantee for _send_media_via_adapter's
    future.result(timeout=30) call. If this times out mid-batch, the
    in-flight coroutine must be cancelled before the next file is tried.
    """
    def test_media_send_timeout_cancels_future_and_continues(self):
        """End-to-end: _send_media_via_adapter with a future whose .result()
        raises TimeoutError. Assert cancel() fires and the loop proceeds
        to the next file rather than hanging or crashing."""
        from concurrent.futures import Future
        adapter = MagicMock()
        adapter.send_image_file = AsyncMock()
        adapter.send_video = AsyncMock()
        # First file: future that times out. Second file: future that resolves OK.
        timeout_future = Future()
        timeout_cancel_calls = []
        original_cancel = timeout_future.cancel
        def tracking_cancel():
            # Record, then delegate so the real Future cancel semantics apply.
            timeout_cancel_calls.append(True)
            return original_cancel()
        timeout_future.cancel = tracking_cancel
        timeout_future.result = MagicMock(side_effect=TimeoutError("timed out"))
        ok_future = Future()
        ok_future.set_result(MagicMock(success=True))
        # Hand the timing-out future to the first send, the resolved one next.
        futures_iter = iter([timeout_future, ok_future])
        def fake_run_coro(coro, _loop):
            # Close the coroutine (avoids "never awaited" warnings) and
            # return the next scripted future.
            coro.close()
            return next(futures_iter)
        media_files = [
            ("/tmp/slow.png", False),  # times out
            ("/tmp/fast.mp4", False),  # succeeds
        ]
        loop = MagicMock()
        job = {"id": "media-timeout"}
        with patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
            # Should not raise — the except Exception clause swallows the timeout
            _send_media_via_adapter(adapter, "chat-1", media_files, None, loop, job)
        # 1. The timed-out future was cancelled (the bug fix)
        assert timeout_cancel_calls == [True], "future.cancel() must fire on TimeoutError"
        # 2. Second file still got dispatched — one timeout doesn't abort the batch
        adapter.send_video.assert_called_once()
        assert adapter.send_video.call_args[1]["video_path"] == "/tmp/fast.mp4"

View file

@ -12,7 +12,7 @@ No LLM, no real platform connections.
import asyncio
import sys
import uuid
from datetime import datetime
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
@ -22,6 +22,7 @@ from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, SendResult
from gateway.session import SessionEntry, SessionSource, build_session_key
E2E_MESSAGE_SETTLE_DELAY = 0.3
# Platform library mocks
@ -113,8 +114,9 @@ _ensure_telegram_mock()
_ensure_discord_mock()
_ensure_slack_mock()
from gateway.platforms.discord import DiscordAdapter # noqa: E402
import discord # noqa: E402 — mocked above
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
from gateway.platforms.discord import DiscordAdapter # noqa: E402
import gateway.platforms.slack as _slack_mod # noqa: E402
_slack_mod.SLACK_AVAILABLE = True
@ -264,3 +266,140 @@ def runner(platform, session_entry):
@pytest.fixture()
def adapter(platform, runner):
return make_adapter(platform, runner)
# ═══════════════════════════════════════════════════════════════════════════
# Discord helpers and fixtures
# ═══════════════════════════════════════════════════════════════════════════
BOT_USER_ID = 99999
BOT_USER_NAME = "HermesBot"
CHANNEL_ID = 22222
GUILD_ID = 44444
THREAD_ID = 33333
MESSAGE_ID_COUNTER = 0
def _next_message_id() -> int:
global MESSAGE_ID_COUNTER
MESSAGE_ID_COUNTER += 1
return 70000 + MESSAGE_ID_COUNTER
def make_fake_bot_user():
    """Build a SimpleNamespace standing in for the bot's own Discord user."""
    bot = SimpleNamespace(
        id=BOT_USER_ID,
        name=BOT_USER_NAME,
        display_name=BOT_USER_NAME,
        bot=True,
    )
    return bot
def make_fake_guild(guild_id: int = GUILD_ID, name: str = "Test Server"):
    """Build a minimal fake guild (server) with just an id and a name."""
    guild = SimpleNamespace(id=guild_id, name=name)
    return guild
def make_fake_text_channel(channel_id: int = CHANNEL_ID, name: str = "general", guild=None):
    """Build a fake guild text channel, creating a default guild if none given."""
    owning_guild = guild or make_fake_guild()
    channel = SimpleNamespace(
        id=channel_id,
        name=name,
        guild=owning_guild,
        topic=None,
        type=0,
    )
    return channel
def make_fake_dm_channel(channel_id: int = 55555):
    """Build a fake DM channel whose __class__ is discord.DMChannel.

    Reassigning __class__ makes isinstance(ch, discord.DMChannel) checks
    in the adapter succeed without a real discord object.
    """
    dm = MagicMock(spec=[])
    dm.id = channel_id
    dm.name = "DM"
    dm.topic = None
    dm.__class__ = discord.DMChannel
    return dm
def make_fake_thread(thread_id: int = THREAD_ID, name: str = "test-thread", parent=None):
    """Build a fake thread typed as discord.Thread, parented to a text channel."""
    thread = MagicMock(spec=[])
    thread.id = thread_id
    thread.name = name
    # Parent channel supplies the guild; parent_id mirrors discord.py's shape.
    thread.parent = parent or make_fake_text_channel()
    thread.parent_id = thread.parent.id
    thread.guild = thread.parent.guild
    thread.topic = None
    thread.type = 11  # Discord channel type 11 — public thread
    thread.__class__ = discord.Thread
    return thread
def make_discord_message(
    *, content: str = "hello", author=None, channel=None, mentions=None,
    attachments=None, message_id: int = None,
):
    """Build a fake Discord message (SimpleNamespace) with sensible defaults.

    Only the attributes the tests exercise are populated. Mutable defaults
    (mentions, attachments) are created per call so instances never share
    state, and each message gets a unique id unless one is supplied.
    """
    if author is None:
        author = SimpleNamespace(
            id=11111, name="testuser", display_name="testuser", bot=False,
        )
    if channel is None:
        channel = make_fake_text_channel()
    resolved_id = message_id if message_id is not None else _next_message_id()
    return SimpleNamespace(
        id=resolved_id,
        content=content,
        author=author,
        channel=channel,
        mentions=mentions if mentions is not None else [],
        attachments=attachments if attachments is not None else [],
        # getattr tolerates a mocked discord module that lacks MessageType.
        type=getattr(discord, "MessageType", SimpleNamespace()).default,
        reference=None,
        created_at=datetime.now(timezone.utc),
        create_thread=AsyncMock(),
    )
def get_response_text(adapter) -> str | None:
"""Extract the response text from adapter.send() call args, or None if not called."""
if not adapter.send.called:
return None
return adapter.send.call_args[1].get("content") or adapter.send.call_args[0][1]
def _make_discord_adapter_wired(runner=None):
    """Create a DiscordAdapter wired to a GatewayRunner for e2e tests.

    ThreadParticipationTracker._load is patched to return an empty set
    during construction so the adapter starts with no prior thread
    participation (presumably persisted state — confirm in helpers).
    Outbound send/send_typing are replaced with AsyncMocks so tests can
    assert on outgoing traffic without any real Discord connection.
    """
    if runner is None:
        runner = make_runner(Platform.DISCORD)
    config = PlatformConfig(enabled=True, token="e2e-test-token")
    from gateway.platforms.helpers import ThreadParticipationTracker
    with patch.object(ThreadParticipationTracker, "_load", return_value=set()):
        adapter = DiscordAdapter(config)
    bot_user = make_fake_bot_user()
    # Minimal client stub: the bot's identity plus channel-lookup hooks
    # (get_channel always misses; fetch_channel is an awaitable mock).
    adapter._client = SimpleNamespace(
        user=bot_user,
        get_channel=lambda _id: None,
        fetch_channel=AsyncMock(),
    )
    adapter.send = AsyncMock(return_value=SendResult(success=True, message_id="e2e-resp-1"))
    adapter.send_typing = AsyncMock()
    # Inbound messages flow into the runner; registering the adapter lets
    # the runner route responses back out through it.
    adapter.set_message_handler(runner._handle_message)
    runner.adapters[Platform.DISCORD] = adapter
    return adapter, runner
@pytest.fixture()
def discord_setup():
    # Fresh (adapter, runner) pair per test — no cross-test adapter state.
    return _make_discord_adapter_wired()
@pytest.fixture()
def discord_adapter(discord_setup):
    # Convenience accessor for the adapter half of discord_setup.
    return discord_setup[0]
@pytest.fixture()
def discord_runner(discord_setup):
    # Convenience accessor for the runner half of discord_setup.
    return discord_setup[1]
@pytest.fixture()
def bot_user():
    # The bot's own fake user, for building mention lists in tests.
    return make_fake_bot_user()

View file

@ -0,0 +1,106 @@
"""Minimal e2e tests for Discord mention stripping + /command detection.
Covers the fix for slash commands not being recognized when sent via
@mention in a channel, especially after auto-threading.
"""
import asyncio
from unittest.mock import AsyncMock
import pytest
from tests.e2e.conftest import (
BOT_USER_ID,
E2E_MESSAGE_SETTLE_DELAY,
get_response_text,
make_discord_message,
make_fake_dm_channel,
make_fake_thread,
)
pytestmark = pytest.mark.asyncio
async def dispatch(adapter, msg):
    """Feed *msg* through the adapter's handler, then sleep briefly so
    any asynchronous follow-up work settles before the test asserts."""
    await adapter._handle_message(msg)
    await asyncio.sleep(E2E_MESSAGE_SETTLE_DELAY)
class TestMentionStrippedCommandDispatch:
    """Mention stripping must happen before /command detection, across
    channel types (guild channel vs DM) and mention syntaxes."""
    async def test_mention_then_command(self, discord_adapter, bot_user):
        """<@BOT> /help → mention stripped, /help dispatched."""
        msg = make_discord_message(
            content=f"<@{BOT_USER_ID}> /help",
            mentions=[bot_user],
        )
        await dispatch(discord_adapter, msg)
        response = get_response_text(discord_adapter)
        # The /help text is expected to mention /new — proof the command ran.
        assert response is not None
        assert "/new" in response
    async def test_nickname_mention_then_command(self, discord_adapter, bot_user):
        """<@!BOT> /help → nickname mention also stripped, /help works."""
        msg = make_discord_message(
            content=f"<@!{BOT_USER_ID}> /help",
            mentions=[bot_user],
        )
        await dispatch(discord_adapter, msg)
        response = get_response_text(discord_adapter)
        assert response is not None
        assert "/new" in response
    async def test_text_before_command_not_detected(self, discord_adapter, bot_user):
        """'<@BOT> something else /help' → mention stripped, but 'something else /help'
        doesn't start with / so it's treated as text, not a command."""
        msg = make_discord_message(
            content=f"<@{BOT_USER_ID}> something else /help",
            mentions=[bot_user],
        )
        await dispatch(discord_adapter, msg)
        # Message is accepted (not dropped by mention gate), but since it doesn't
        # start with / it's routed as text — no command output, and no agent in this
        # mock setup means no send call either.
        response = get_response_text(discord_adapter)
        assert response is None or "/new" not in response
    async def test_no_mention_in_channel_dropped(self, discord_adapter):
        """Message without @mention in server channel → silently dropped."""
        msg = make_discord_message(content="/help", mentions=[])
        await dispatch(discord_adapter, msg)
        assert get_response_text(discord_adapter) is None
    async def test_dm_no_mention_needed(self, discord_adapter):
        """DMs don't require @mention — /help works directly."""
        dm = make_fake_dm_channel()
        msg = make_discord_message(content="/help", channel=dm, mentions=[])
        await dispatch(discord_adapter, msg)
        response = get_response_text(discord_adapter)
        assert response is not None
        assert "/new" in response
class TestAutoThreadingPreservesCommand:
    """Auto-threading must not undo mention stripping and lose the /command."""
    async def test_command_detected_after_auto_thread(self, discord_adapter, bot_user, monkeypatch):
        """@mention /help in channel with auto-thread → thread created AND command dispatched."""
        monkeypatch.setenv("DISCORD_AUTO_THREAD", "true")
        fake_thread = make_fake_thread(thread_id=90001, name="help")
        msg = make_discord_message(
            content=f"<@{BOT_USER_ID}> /help",
            mentions=[bot_user],
        )
        # Simulate discord.py restoring the original raw content (with mention)
        # after create_thread(), which undoes any prior mention stripping.
        original_content = msg.content
        async def clobber_content(**kwargs):
            msg.content = original_content
            return fake_thread
        msg.create_thread = AsyncMock(side_effect=clobber_content)
        await dispatch(discord_adapter, msg)
        # Thread was created AND the command still produced /help output.
        msg.create_thread.assert_awaited_once()
        response = get_response_text(discord_adapter)
        assert response is not None
        assert "/new" in response

View file

@ -108,6 +108,7 @@ def make_restart_runner(
runner.hooks.emit = AsyncMock()
runner.pairing_store = MagicMock()
runner.session_store = MagicMock()
runner.session_store._entries = {}
runner.delivery_router = MagicMock()
platform_adapter = adapter or RestartTestAdapter()

View file

@ -12,6 +12,7 @@ Tests cover:
- Error handling (invalid JSON, missing fields)
"""
import asyncio
import json
import time
import uuid
@ -25,6 +26,7 @@ from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.api_server import (
APIServerAdapter,
ResponseStore,
_IdempotencyCache,
_CORS_HEADERS,
_derive_chat_session_id,
check_api_server_requirements,
@ -104,6 +106,95 @@ class TestResponseStore:
assert store.delete("resp_missing") is False
# ---------------------------------------------------------------------------
# _IdempotencyCache
# ---------------------------------------------------------------------------
class TestIdempotencyCache:
    """Contract of _IdempotencyCache.get_or_set(key, fingerprint, compute):
    concurrent callers with the same key+fingerprint share one computation;
    a different fingerprint gets its own; a cancelled waiter must not tear
    down the shared in-flight task for later callers."""
    @pytest.mark.asyncio
    async def test_concurrent_same_key_and_fingerprint_runs_once(self):
        cache = _IdempotencyCache()
        # gate holds compute() open; started signals it actually began.
        gate = asyncio.Event()
        started = asyncio.Event()
        calls = 0
        async def compute():
            nonlocal calls
            calls += 1
            started.set()
            await gate.wait()
            return ("response", {"total_tokens": 1})
        first = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
        second = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
        await started.wait()
        # Only one compute() ran even though two callers raced.
        assert calls == 1
        gate.set()
        first_result, second_result = await asyncio.gather(first, second)
        # Both callers observe the single shared result.
        assert first_result == second_result == ("response", {"total_tokens": 1})
    @pytest.mark.asyncio
    async def test_different_fingerprint_does_not_reuse_inflight_task(self):
        cache = _IdempotencyCache()
        gate = asyncio.Event()
        started = asyncio.Event()
        calls = 0
        async def compute():
            nonlocal calls
            calls += 1
            result = calls
            # Only the second invocation blocks, proving it was started
            # independently of the first (no in-flight reuse).
            if calls == 2:
                started.set()
                await gate.wait()
            return result
        first = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
        second = asyncio.create_task(cache.get_or_set("idem-key", "fp-2", compute))
        await started.wait()
        # Distinct fingerprints → two independent computations.
        assert calls == 2
        gate.set()
        results = await asyncio.gather(first, second)
        assert sorted(results) == [1, 2]
    @pytest.mark.asyncio
    async def test_cancelled_waiter_does_not_drop_shared_inflight_task(self):
        cache = _IdempotencyCache()
        gate = asyncio.Event()
        started = asyncio.Event()
        calls = 0
        async def compute():
            nonlocal calls
            calls += 1
            started.set()
            await gate.wait()
            return "response"
        first = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
        await started.wait()
        assert calls == 1
        # Cancel the original waiter while compute() is still in flight.
        first.cancel()
        with pytest.raises(asyncio.CancelledError):
            await first
        # A new caller must attach to the SAME in-flight computation —
        # calls staying at 1 proves the cache kept the shared task alive.
        second = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
        await asyncio.sleep(0)
        assert calls == 1
        gate.set()
        assert await second == "response"
# ---------------------------------------------------------------------------
# Adapter initialization
# ---------------------------------------------------------------------------

View file

@ -20,6 +20,8 @@ from aiohttp.test_utils import TestClient, TestServer
from gateway.config import PlatformConfig
from gateway.platforms.api_server import APIServerAdapter, cors_middleware
_MOD = "gateway.platforms.api_server"
# ---------------------------------------------------------------------------
# Helpers
@ -83,10 +85,10 @@ class TestListJobs:
"""GET /api/jobs returns job list."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(
APIServerAdapter, "_CRON_AVAILABLE", True
), patch.object(
APIServerAdapter, "_cron_list", return_value=[SAMPLE_JOB]
with patch(
f"{_MOD}._CRON_AVAILABLE", True
), patch(
f"{_MOD}._cron_list", return_value=[SAMPLE_JOB]
):
resp = await cli.get("/api/jobs")
assert resp.status == 200
@ -104,10 +106,10 @@ class TestListJobs:
app = _create_app(adapter)
mock_list = MagicMock(return_value=[SAMPLE_JOB])
async with TestClient(TestServer(app)) as cli:
with patch.object(
APIServerAdapter, "_CRON_AVAILABLE", True
), patch.object(
APIServerAdapter, "_cron_list", mock_list
with patch(
f"{_MOD}._CRON_AVAILABLE", True
), patch(
f"{_MOD}._cron_list", mock_list
):
resp = await cli.get("/api/jobs?include_disabled=true")
assert resp.status == 200
@ -119,10 +121,10 @@ class TestListJobs:
app = _create_app(adapter)
mock_list = MagicMock(return_value=[])
async with TestClient(TestServer(app)) as cli:
with patch.object(
APIServerAdapter, "_CRON_AVAILABLE", True
), patch.object(
APIServerAdapter, "_cron_list", mock_list
with patch(
f"{_MOD}._CRON_AVAILABLE", True
), patch(
f"{_MOD}._cron_list", mock_list
):
resp = await cli.get("/api/jobs")
assert resp.status == 200
@ -140,10 +142,10 @@ class TestCreateJob:
app = _create_app(adapter)
mock_create = MagicMock(return_value=SAMPLE_JOB)
async with TestClient(TestServer(app)) as cli:
with patch.object(
APIServerAdapter, "_CRON_AVAILABLE", True
), patch.object(
APIServerAdapter, "_cron_create", mock_create
with patch(
f"{_MOD}._CRON_AVAILABLE", True
), patch(
f"{_MOD}._cron_create", mock_create
):
resp = await cli.post("/api/jobs", json={
"name": "test-job",
@ -164,7 +166,7 @@ class TestCreateJob:
"""POST /api/jobs without name returns 400."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
with patch(f"{_MOD}._CRON_AVAILABLE", True):
resp = await cli.post("/api/jobs", json={
"schedule": "*/5 * * * *",
"prompt": "do something",
@ -178,7 +180,7 @@ class TestCreateJob:
"""POST /api/jobs with name > 200 chars returns 400."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
with patch(f"{_MOD}._CRON_AVAILABLE", True):
resp = await cli.post("/api/jobs", json={
"name": "x" * 201,
"schedule": "*/5 * * * *",
@ -192,7 +194,7 @@ class TestCreateJob:
"""POST /api/jobs with prompt > 5000 chars returns 400."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
with patch(f"{_MOD}._CRON_AVAILABLE", True):
resp = await cli.post("/api/jobs", json={
"name": "test-job",
"schedule": "*/5 * * * *",
@ -207,7 +209,7 @@ class TestCreateJob:
"""POST /api/jobs with repeat=0 returns 400."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
with patch(f"{_MOD}._CRON_AVAILABLE", True):
resp = await cli.post("/api/jobs", json={
"name": "test-job",
"schedule": "*/5 * * * *",
@ -222,7 +224,7 @@ class TestCreateJob:
"""POST /api/jobs without schedule returns 400."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
with patch(f"{_MOD}._CRON_AVAILABLE", True):
resp = await cli.post("/api/jobs", json={
"name": "test-job",
})
@ -242,10 +244,10 @@ class TestGetJob:
app = _create_app(adapter)
mock_get = MagicMock(return_value=SAMPLE_JOB)
async with TestClient(TestServer(app)) as cli:
with patch.object(
APIServerAdapter, "_CRON_AVAILABLE", True
), patch.object(
APIServerAdapter, "_cron_get", mock_get
with patch(
f"{_MOD}._CRON_AVAILABLE", True
), patch(
f"{_MOD}._cron_get", mock_get
):
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
assert resp.status == 200
@ -259,10 +261,10 @@ class TestGetJob:
app = _create_app(adapter)
mock_get = MagicMock(return_value=None)
async with TestClient(TestServer(app)) as cli:
with patch.object(
APIServerAdapter, "_CRON_AVAILABLE", True
), patch.object(
APIServerAdapter, "_cron_get", mock_get
with patch(
f"{_MOD}._CRON_AVAILABLE", True
), patch(
f"{_MOD}._cron_get", mock_get
):
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
assert resp.status == 404
@ -272,7 +274,7 @@ class TestGetJob:
"""GET /api/jobs/{id} with non-hex id returns 400."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
with patch(f"{_MOD}._CRON_AVAILABLE", True):
resp = await cli.get("/api/jobs/not-a-valid-hex!")
assert resp.status == 400
data = await resp.json()
@ -291,10 +293,10 @@ class TestUpdateJob:
updated_job = {**SAMPLE_JOB, "name": "updated-name"}
mock_update = MagicMock(return_value=updated_job)
async with TestClient(TestServer(app)) as cli:
with patch.object(
APIServerAdapter, "_CRON_AVAILABLE", True
), patch.object(
APIServerAdapter, "_cron_update", mock_update
with patch(
f"{_MOD}._CRON_AVAILABLE", True
), patch(
f"{_MOD}._cron_update", mock_update
):
resp = await cli.patch(
f"/api/jobs/{VALID_JOB_ID}",
@ -317,10 +319,10 @@ class TestUpdateJob:
updated_job = {**SAMPLE_JOB, "name": "new-name"}
mock_update = MagicMock(return_value=updated_job)
async with TestClient(TestServer(app)) as cli:
with patch.object(
APIServerAdapter, "_CRON_AVAILABLE", True
), patch.object(
APIServerAdapter, "_cron_update", mock_update
with patch(
f"{_MOD}._CRON_AVAILABLE", True
), patch(
f"{_MOD}._cron_update", mock_update
):
resp = await cli.patch(
f"/api/jobs/{VALID_JOB_ID}",
@ -342,7 +344,7 @@ class TestUpdateJob:
"""PATCH /api/jobs/{id} with only unknown fields returns 400."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
with patch(f"{_MOD}._CRON_AVAILABLE", True):
resp = await cli.patch(
f"/api/jobs/{VALID_JOB_ID}",
json={"evil_field": "malicious"},
@ -363,10 +365,10 @@ class TestDeleteJob:
app = _create_app(adapter)
mock_remove = MagicMock(return_value=True)
async with TestClient(TestServer(app)) as cli:
with patch.object(
APIServerAdapter, "_CRON_AVAILABLE", True
), patch.object(
APIServerAdapter, "_cron_remove", mock_remove
with patch(
f"{_MOD}._CRON_AVAILABLE", True
), patch(
f"{_MOD}._cron_remove", mock_remove
):
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
assert resp.status == 200
@ -380,10 +382,10 @@ class TestDeleteJob:
app = _create_app(adapter)
mock_remove = MagicMock(return_value=False)
async with TestClient(TestServer(app)) as cli:
with patch.object(
APIServerAdapter, "_CRON_AVAILABLE", True
), patch.object(
APIServerAdapter, "_cron_remove", mock_remove
with patch(
f"{_MOD}._CRON_AVAILABLE", True
), patch(
f"{_MOD}._cron_remove", mock_remove
):
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
assert resp.status == 404
@ -401,10 +403,10 @@ class TestPauseJob:
paused_job = {**SAMPLE_JOB, "enabled": False}
mock_pause = MagicMock(return_value=paused_job)
async with TestClient(TestServer(app)) as cli:
with patch.object(
APIServerAdapter, "_CRON_AVAILABLE", True
), patch.object(
APIServerAdapter, "_cron_pause", mock_pause
with patch(
f"{_MOD}._CRON_AVAILABLE", True
), patch(
f"{_MOD}._cron_pause", mock_pause
):
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/pause")
assert resp.status == 200
@ -426,10 +428,10 @@ class TestResumeJob:
resumed_job = {**SAMPLE_JOB, "enabled": True}
mock_resume = MagicMock(return_value=resumed_job)
async with TestClient(TestServer(app)) as cli:
with patch.object(
APIServerAdapter, "_CRON_AVAILABLE", True
), patch.object(
APIServerAdapter, "_cron_resume", mock_resume
with patch(
f"{_MOD}._CRON_AVAILABLE", True
), patch(
f"{_MOD}._cron_resume", mock_resume
):
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/resume")
assert resp.status == 200
@ -451,10 +453,10 @@ class TestRunJob:
triggered_job = {**SAMPLE_JOB, "last_run": "2025-01-01T00:00:00Z"}
mock_trigger = MagicMock(return_value=triggered_job)
async with TestClient(TestServer(app)) as cli:
with patch.object(
APIServerAdapter, "_CRON_AVAILABLE", True
), patch.object(
APIServerAdapter, "_cron_trigger", mock_trigger
with patch(
f"{_MOD}._CRON_AVAILABLE", True
), patch(
f"{_MOD}._cron_trigger", mock_trigger
):
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/run")
assert resp.status == 200
@ -473,7 +475,7 @@ class TestAuthRequired:
"""GET /api/jobs without API key returns 401 when key is set."""
app = _create_app(auth_adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
with patch(f"{_MOD}._CRON_AVAILABLE", True):
resp = await cli.get("/api/jobs")
assert resp.status == 401
@ -482,7 +484,7 @@ class TestAuthRequired:
"""POST /api/jobs without API key returns 401 when key is set."""
app = _create_app(auth_adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
with patch(f"{_MOD}._CRON_AVAILABLE", True):
resp = await cli.post("/api/jobs", json={
"name": "test", "schedule": "* * * * *",
})
@ -493,7 +495,7 @@ class TestAuthRequired:
"""GET /api/jobs/{id} without API key returns 401 when key is set."""
app = _create_app(auth_adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
with patch(f"{_MOD}._CRON_AVAILABLE", True):
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
assert resp.status == 401
@ -502,7 +504,7 @@ class TestAuthRequired:
"""DELETE /api/jobs/{id} without API key returns 401."""
app = _create_app(auth_adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
with patch(f"{_MOD}._CRON_AVAILABLE", True):
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
assert resp.status == 401
@ -512,10 +514,10 @@ class TestAuthRequired:
app = _create_app(auth_adapter)
mock_list = MagicMock(return_value=[])
async with TestClient(TestServer(app)) as cli:
with patch.object(
APIServerAdapter, "_CRON_AVAILABLE", True
), patch.object(
APIServerAdapter, "_cron_list", mock_list
with patch(
f"{_MOD}._CRON_AVAILABLE", True
), patch(
f"{_MOD}._cron_list", mock_list
):
resp = await cli.get(
"/api/jobs",
@ -534,7 +536,7 @@ class TestCronUnavailable:
"""GET /api/jobs returns 501 when _CRON_AVAILABLE is False."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
with patch(f"{_MOD}._CRON_AVAILABLE", False):
resp = await cli.get("/api/jobs")
assert resp.status == 501
data = await resp.json()
@ -551,8 +553,8 @@ class TestCronUnavailable:
return SAMPLE_JOB
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True), patch.object(
APIServerAdapter, "_cron_pause", staticmethod(_plain_pause)
with patch(f"{_MOD}._CRON_AVAILABLE", True), patch(
f"{_MOD}._cron_pause", _plain_pause
):
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/pause")
assert resp.status == 200
@ -571,8 +573,8 @@ class TestCronUnavailable:
return [SAMPLE_JOB]
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True), patch.object(
APIServerAdapter, "_cron_list", staticmethod(_plain_list)
with patch(f"{_MOD}._CRON_AVAILABLE", True), patch(
f"{_MOD}._cron_list", _plain_list
):
resp = await cli.get("/api/jobs?include_disabled=true")
assert resp.status == 200
@ -593,8 +595,8 @@ class TestCronUnavailable:
return updated_job
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True), patch.object(
APIServerAdapter, "_cron_update", staticmethod(_plain_update)
with patch(f"{_MOD}._CRON_AVAILABLE", True), patch(
f"{_MOD}._cron_update", _plain_update
):
resp = await cli.patch(
f"/api/jobs/{VALID_JOB_ID}",
@ -611,7 +613,7 @@ class TestCronUnavailable:
"""POST /api/jobs returns 501 when _CRON_AVAILABLE is False."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
with patch(f"{_MOD}._CRON_AVAILABLE", False):
resp = await cli.post("/api/jobs", json={
"name": "test", "schedule": "* * * * *",
})
@ -622,7 +624,7 @@ class TestCronUnavailable:
"""GET /api/jobs/{id} returns 501 when _CRON_AVAILABLE is False."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
with patch(f"{_MOD}._CRON_AVAILABLE", False):
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
assert resp.status == 501
@ -631,7 +633,7 @@ class TestCronUnavailable:
"""DELETE /api/jobs/{id} returns 501 when _CRON_AVAILABLE is False."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
with patch(f"{_MOD}._CRON_AVAILABLE", False):
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
assert resp.status == 501
@ -640,7 +642,7 @@ class TestCronUnavailable:
"""POST /api/jobs/{id}/pause returns 501 when _CRON_AVAILABLE is False."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
with patch(f"{_MOD}._CRON_AVAILABLE", False):
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/pause")
assert resp.status == 501
@ -649,7 +651,7 @@ class TestCronUnavailable:
"""POST /api/jobs/{id}/resume returns 501 when _CRON_AVAILABLE is False."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
with patch(f"{_MOD}._CRON_AVAILABLE", False):
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/resume")
assert resp.status == 501
@ -658,6 +660,6 @@ class TestCronUnavailable:
"""POST /api/jobs/{id}/run returns 501 when _CRON_AVAILABLE is False."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
with patch(f"{_MOD}._CRON_AVAILABLE", False):
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/run")
assert resp.status == 501

View file

@ -0,0 +1,308 @@
"""End-to-end tests for inline image inputs on /v1/chat/completions and /v1/responses.
Covers the multimodal normalization path added to the API server. Unlike the
adapter-level tests that patch ``_run_agent``, these tests patch
``AIAgent.run_conversation`` instead so the adapter's full request-handling
path (including the ``run_agent`` prologue that used to crash on list content)
executes against a real aiohttp app.
"""
from unittest.mock import MagicMock, patch
import pytest
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
from gateway.config import PlatformConfig
from gateway.platforms.api_server import (
APIServerAdapter,
_content_has_visible_payload,
_normalize_multimodal_content,
cors_middleware,
security_headers_middleware,
)
# ---------------------------------------------------------------------------
# Pure-function tests for _normalize_multimodal_content
# ---------------------------------------------------------------------------
class TestNormalizeMultimodalContent:
    """Pure-function coverage for ``_normalize_multimodal_content``."""

    def test_string_passthrough(self):
        # Plain strings are returned untouched.
        assert _normalize_multimodal_content("hello") == "hello"

    def test_none_returns_empty_string(self):
        assert _normalize_multimodal_content(None) == ""

    def test_text_only_list_collapses_to_string(self):
        parts = [
            {"type": "text", "text": "hi"},
            {"type": "text", "text": "there"},
        ]
        assert _normalize_multimodal_content(parts) == "hi\nthere"

    def test_responses_input_text_canonicalized(self):
        # Responses-style "input_text" is treated like plain "text".
        assert _normalize_multimodal_content(
            [{"type": "input_text", "text": "hello"}]
        ) == "hello"

    def test_image_url_preserved_with_text(self):
        text_part = {"type": "text", "text": "describe this"}
        image_part = {
            "type": "image_url",
            "image_url": {"url": "https://example.com/cat.png", "detail": "high"},
        }
        normalized = _normalize_multimodal_content([text_part, image_part])
        assert isinstance(normalized, list)
        assert normalized == [text_part, image_part]

    def test_input_image_converted_to_canonical_shape(self):
        normalized = _normalize_multimodal_content([
            {"type": "input_text", "text": "hi"},
            {"type": "input_image", "image_url": "https://example.com/cat.png"},
        ])
        assert normalized == [
            {"type": "text", "text": "hi"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ]

    def test_data_image_url_accepted(self):
        data_part = {
            "type": "image_url",
            "image_url": {"url": "data:image/png;base64,AAAA"},
        }
        assert _normalize_multimodal_content([data_part]) == [data_part]

    def test_non_image_data_url_rejected(self):
        # `match` with a ^ anchor is equivalent to the startswith check.
        bad = [{"type": "image_url", "image_url": {"url": "data:text/plain;base64,SGVsbG8="}}]
        with pytest.raises(ValueError, match=r"^unsupported_content_type:"):
            _normalize_multimodal_content(bad)

    def test_file_part_rejected(self):
        with pytest.raises(ValueError, match=r"^unsupported_content_type:"):
            _normalize_multimodal_content([{"type": "file", "file": {"file_id": "f_1"}}])

    def test_input_file_part_rejected(self):
        with pytest.raises(ValueError, match=r"^unsupported_content_type:"):
            _normalize_multimodal_content([{"type": "input_file", "file_id": "f_1"}])

    def test_missing_url_rejected(self):
        with pytest.raises(ValueError, match=r"^invalid_image_url:"):
            _normalize_multimodal_content([{"type": "image_url", "image_url": {}}])

    def test_bad_scheme_rejected(self):
        with pytest.raises(ValueError, match=r"^invalid_image_url:"):
            _normalize_multimodal_content(
                [{"type": "image_url", "image_url": {"url": "ftp://example.com/x.png"}}]
            )

    def test_unknown_part_type_rejected(self):
        with pytest.raises(ValueError, match=r"^unsupported_content_type:"):
            _normalize_multimodal_content([{"type": "audio", "audio": {}}])
class TestContentHasVisiblePayload:
    """Unit tests for ``_content_has_visible_payload``."""

    def test_non_empty_string(self):
        assert _content_has_visible_payload("hello")

    def test_whitespace_only_string(self):
        # Whitespace alone does not count as visible content.
        result = _content_has_visible_payload(" ")
        assert not result

    def test_list_with_image_only(self):
        image_only = [{"type": "image_url", "image_url": {"url": "x"}}]
        assert _content_has_visible_payload(image_only)

    def test_list_with_only_empty_text(self):
        empty_text = [{"type": "text", "text": ""}]
        assert not _content_has_visible_payload(empty_text)
# ---------------------------------------------------------------------------
# HTTP integration — real aiohttp client hitting the adapter handlers
# ---------------------------------------------------------------------------
def _make_adapter() -> APIServerAdapter:
    """Build an enabled APIServerAdapter for the HTTP tests."""
    config = PlatformConfig(enabled=True)
    return APIServerAdapter(config)
def _create_app(adapter: APIServerAdapter) -> web.Application:
    """Wire the adapter's HTTP handlers into a minimal aiohttp application."""
    enabled_middlewares = [
        candidate
        for candidate in (cors_middleware, security_headers_middleware)
        if candidate is not None
    ]
    app = web.Application(middlewares=enabled_middlewares)
    # Handlers look the adapter up on the app, mirroring production wiring.
    app["api_server_adapter"] = adapter
    router = app.router
    router.add_post("/v1/chat/completions", adapter._handle_chat_completions)
    router.add_post("/v1/responses", adapter._handle_responses)
    router.add_get("/v1/responses/{response_id}", adapter._handle_get_response)
    return app
@pytest.fixture
def adapter():
    """Fresh APIServerAdapter per test so no state leaks between cases."""
    return _make_adapter()
class TestChatCompletionsMultimodalHTTP:
    """HTTP-level checks for multimodal content on /v1/chat/completions.

    ``_run_agent`` is stubbed; assertions inspect the kwargs the handler
    forwards to it plus the HTTP status / error envelope.
    """

    @pytest.mark.asyncio
    async def test_inline_image_preserved_to_run_agent(self, adapter):
        """Multimodal user content reaches _run_agent as a list of parts."""
        image_payload = [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png", "detail": "high"}},
        ]
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(
                adapter,
                "_run_agent",
                new=MagicMock(),
            ) as mock_run:
                # Async stub: records the handler's kwargs on the mock and
                # returns the (result, usage) pair _run_agent produces.
                async def _stub(**kwargs):
                    mock_run.captured = kwargs
                    return (
                        {"final_response": "A cat.", "messages": [], "api_calls": 1},
                        {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
                    )
                mock_run.side_effect = _stub
                resp = await cli.post(
                    "/v1/chat/completions",
                    json={
                        "model": "hermes-agent",
                        "messages": [{"role": "user", "content": image_payload}],
                    },
                )
                assert resp.status == 200, await resp.text()
                # Image part must survive normalization byte-for-byte.
                assert mock_run.captured["user_message"] == image_payload

    @pytest.mark.asyncio
    async def test_text_only_array_collapses_to_string(self, adapter):
        """Text-only array becomes a plain string so logging stays unchanged."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new=MagicMock()) as mock_run:
                async def _stub(**kwargs):
                    mock_run.captured = kwargs
                    return (
                        {"final_response": "ok", "messages": [], "api_calls": 1},
                        {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
                    )
                mock_run.side_effect = _stub
                resp = await cli.post(
                    "/v1/chat/completions",
                    json={
                        "model": "hermes-agent",
                        "messages": [
                            {"role": "user", "content": [{"type": "text", "text": "hello"}]},
                        ],
                    },
                )
                assert resp.status == 200, await resp.text()
                # Collapsed to a string, not a one-element list.
                assert mock_run.captured["user_message"] == "hello"

    @pytest.mark.asyncio
    async def test_file_part_returns_400(self, adapter):
        """File parts are unsupported and must be rejected with 400."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                "/v1/chat/completions",
                json={
                    "model": "hermes-agent",
                    "messages": [
                        {"role": "user", "content": [{"type": "file", "file": {"file_id": "f_1"}}]},
                    ],
                },
            )
            assert resp.status == 400
            body = await resp.json()
            # OpenAI-style error envelope with machine-readable code + param.
            assert body["error"]["code"] == "unsupported_content_type"
            assert body["error"]["param"] == "messages[0].content"

    @pytest.mark.asyncio
    async def test_non_image_data_url_returns_400(self, adapter):
        """A data: URL with a non-image MIME type must be rejected with 400."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                "/v1/chat/completions",
                json={
                    "model": "hermes-agent",
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {"url": "data:text/plain;base64,SGVsbG8="},
                                },
                            ],
                        },
                    ],
                },
            )
            assert resp.status == 400
            body = await resp.json()
            assert body["error"]["code"] == "unsupported_content_type"
class TestResponsesMultimodalHTTP:
    """HTTP-level checks for multimodal input on /v1/responses."""

    @pytest.mark.asyncio
    async def test_input_image_canonicalized_and_forwarded(self, adapter):
        """Responses-style input parts are canonicalized before _run_agent."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new=MagicMock()) as mock_run:
                # Async stub: captures the handler's kwargs for inspection.
                async def _stub(**kwargs):
                    mock_run.captured = kwargs
                    return (
                        {"final_response": "ok", "messages": [], "api_calls": 1},
                        {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
                    )
                mock_run.side_effect = _stub
                resp = await cli.post(
                    "/v1/responses",
                    json={
                        "model": "hermes-agent",
                        "input": [
                            {
                                "role": "user",
                                "content": [
                                    {"type": "input_text", "text": "Describe."},
                                    {
                                        "type": "input_image",
                                        "image_url": "https://example.com/cat.png",
                                    },
                                ],
                            }
                        ],
                    },
                )
                assert resp.status == 200, await resp.text()
                # input_text -> text, input_image -> image_url canonical shape.
                expected = [
                    {"type": "text", "text": "Describe."},
                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                ]
                assert mock_run.captured["user_message"] == expected

    @pytest.mark.asyncio
    async def test_input_file_returns_400(self, adapter):
        """input_file parts are unsupported and must be rejected with 400."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                "/v1/responses",
                json={
                    "model": "hermes-agent",
                    "input": [
                        {
                            "role": "user",
                            "content": [{"type": "input_file", "file_id": "f_1"}],
                        }
                    ],
                },
            )
            assert resp.status == 400
            body = await resp.json()
            assert body["error"]["code"] == "unsupported_content_type"

View file

@ -0,0 +1,148 @@
"""Regression test: cancel_background_tasks must drain late-arrival tasks.
During gateway shutdown, a message arriving while
cancel_background_tasks is mid-await can spawn a fresh
_process_message_background task via handle_message, which is added
to self._background_tasks. Without the re-drain loop, the subsequent
_background_tasks.clear() drops the reference; the task runs
untracked against a disconnecting adapter.
"""
import asyncio
from unittest.mock import AsyncMock
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
from gateway.session import SessionSource, build_session_key
class _StubAdapter(BasePlatformAdapter):
    """Minimal concrete adapter: satisfies the abstract interface with
    no-ops so the base class's background-task machinery can be driven
    directly in tests."""

    async def connect(self):
        pass

    async def disconnect(self):
        pass

    async def send(self, chat_id, text, **kwargs):
        return None

    async def get_chat_info(self, chat_id):
        return {}
def _make_adapter():
    """Build a stub Telegram adapter with outbound sends mocked out."""
    stub = _StubAdapter(PlatformConfig(enabled=True, token="t"), Platform.TELEGRAM)
    stub._send_with_retry = AsyncMock(return_value=None)
    return stub
def _event(text, cid="42"):
    """Build a Telegram DM MessageEvent for chat *cid* carrying *text*."""
    source = SessionSource(platform=Platform.TELEGRAM, chat_id=cid, chat_type="dm")
    return MessageEvent(text=text, message_type=MessageType.TEXT, source=source)
@pytest.mark.asyncio
async def test_cancel_background_tasks_drains_late_arrivals():
    """A message that arrives during the gather window must be picked
    up by the re-drain loop, not leaked as an untracked task."""
    adapter = _make_adapter()
    sk = build_session_key(
        SessionSource(platform=Platform.TELEGRAM, chat_id="42", chat_type="dm")
    )
    # Checkpoints used to sequence the race deterministically.
    m1_started = asyncio.Event()
    m1_cleanup_running = asyncio.Event()
    m2_started = asyncio.Event()
    m2_cancelled = asyncio.Event()

    async def handler(event):
        if event.text == "M1":
            m1_started.set()
            try:
                await asyncio.sleep(10)
            except asyncio.CancelledError:
                m1_cleanup_running.set()
                # Widen the gather window with a shielded cleanup
                # delay so M2 can get injected during it.
                await asyncio.shield(asyncio.sleep(0.2))
                raise
        else:  # M2 — the late arrival
            m2_started.set()
            try:
                await asyncio.sleep(10)
            except asyncio.CancelledError:
                m2_cancelled.set()
                raise

    adapter._message_handler = handler

    # Spawn M1.
    await adapter.handle_message(_event("M1"))
    await asyncio.wait_for(m1_started.wait(), timeout=1.0)

    # Kick off shutdown. This will cancel M1 and await its cleanup.
    cancel_task = asyncio.create_task(adapter.cancel_background_tasks())

    # Wait until M1's cleanup is running (inside the shielded sleep).
    # This is the race window: cancel_task is awaiting gather, M1 is
    # shielded in cleanup, the _active_sessions entry has been cleared
    # by M1's own finally.
    await asyncio.wait_for(m1_cleanup_running.wait(), timeout=1.0)

    # Clear the active-session entry (M1's finally hasn't fully run yet,
    # but in production the platform dispatcher would deliver a new
    # message that takes the no-active-session spawn path). For this
    # repro, make it deterministic.
    adapter._active_sessions.pop(sk, None)

    # Inject late arrival — spawns a fresh _process_message_background
    # task and adds it to _background_tasks while cancel_task is still
    # in gather.
    await adapter.handle_message(_event("M2"))
    await asyncio.wait_for(m2_started.wait(), timeout=1.0)

    # Let cancel_task finish. Round 1's gather completes when M1's
    # shielded cleanup finishes. Round 2 should pick up M2.
    await asyncio.wait_for(cancel_task, timeout=5.0)

    # Assert M2 was drained, not leaked.
    assert m2_cancelled.is_set(), (
        "Late-arrival M2 was NOT cancelled by cancel_background_tasks — "
        "the re-drain loop is missing and the task leaked"
    )
    assert adapter._background_tasks == set()
@pytest.mark.asyncio
async def test_cancel_background_tasks_handles_no_tasks():
    """Regression guard: with nothing in flight, shutdown neither hangs nor errors."""
    adapter = _make_adapter()
    await adapter.cancel_background_tasks()
    assert adapter._background_tasks == set()
@pytest.mark.asyncio
async def test_cancel_background_tasks_bounded_rounds():
    """Regression guard: the drain loop is bounded — a single
    cleanly-cancelling task is reaped in one round and the tracking
    set ends up empty."""
    adapter = _make_adapter()

    async def well_behaved():
        # Cancels cleanly — baseline check that the loop exits quickly.
        try:
            await asyncio.sleep(10)
        except asyncio.CancelledError:
            raise

    tracked = asyncio.create_task(well_behaved())
    adapter._background_tasks.add(tracked)

    await adapter.cancel_background_tasks()

    assert tracked.done()
    assert adapter._background_tasks == set()

View file

@ -0,0 +1,91 @@
"""Regression tests for the TUI gateway's `complete.path` handler.
Reported during the TUI v2 blitz retest: typing `@folder:` (and `@folder`
with no colon yet) still surfaced files alongside directories in the
TUI composer, because the gateway-side completion lives in
`tui_gateway/server.py` and was never touched by the earlier fix to
`hermes_cli/commands.py`.
Covers:
- `@folder:` only yields directories
- `@file:` only yields regular files
- Bare `@folder` / `@file` (no colon) lists cwd directly
- Explicit prefix is preserved in the completion text
"""
from __future__ import annotations
from pathlib import Path
from tui_gateway import server
def _fixture(tmp_path: Path):
(tmp_path / "readme.md").write_text("x")
(tmp_path / ".env").write_text("x")
(tmp_path / "src").mkdir()
(tmp_path / "docs").mkdir()
def _items(word: str):
    """Run a complete.path request through the gateway and flatten the items."""
    request = {"id": "1", "method": "complete.path", "params": {"word": word}}
    resp = server.handle_request(request)
    return [
        (item["text"], item["display"], item.get("meta", ""))
        for item in resp["result"]["items"]
    ]
def test_at_folder_colon_only_dirs(tmp_path, monkeypatch):
    """`@folder:` must surface directories only, never files."""
    monkeypatch.chdir(tmp_path)
    _fixture(tmp_path)
    texts = [text for text, _, _ in _items("@folder:")]
    assert all(text.startswith("@folder:") for text in texts), texts
    assert "@folder:src/" in texts
    assert "@folder:docs/" in texts
    assert "@folder:readme.md" not in texts
    assert "@folder:.env" not in texts
def test_at_file_colon_only_files(tmp_path, monkeypatch):
    """`@file:` must surface regular files only, never directories."""
    monkeypatch.chdir(tmp_path)
    _fixture(tmp_path)
    texts = [text for text, _, _ in _items("@file:")]
    assert all(text.startswith("@file:") for text in texts), texts
    assert "@file:readme.md" in texts
    assert "@file:src/" not in texts
    assert "@file:docs/" not in texts
def test_at_folder_bare_without_colon_lists_dirs(tmp_path, monkeypatch):
    """Bare `@folder` (no colon typed yet) already lists cwd directories."""
    monkeypatch.chdir(tmp_path)
    _fixture(tmp_path)
    texts = [text for text, _, _ in _items("@folder")]
    assert "@folder:src/" in texts, texts
    assert "@folder:docs/" in texts, texts
    assert "@folder:readme.md" not in texts
def test_at_file_bare_without_colon_lists_files(tmp_path, monkeypatch):
    """Bare `@file` (no colon typed yet) already lists cwd files."""
    monkeypatch.chdir(tmp_path)
    _fixture(tmp_path)
    texts = [text for text, _, _ in _items("@file")]
    assert "@file:readme.md" in texts, texts
    assert "@file:src/" not in texts
def test_bare_at_still_shows_static_refs(tmp_path, monkeypatch):
    """`@` alone lists the static references so users can discover the
    available prefixes. (Unchanged behaviour; regression guard.)
    """
    monkeypatch.chdir(tmp_path)
    texts = [text for text, _, _ in _items("@")]
    static_refs = ("@diff", "@staged", "@file:", "@folder:", "@url:", "@git:")
    for ref in static_refs:
        assert ref in texts, f"missing static ref {ref!r} in {texts!r}"

View file

@ -75,7 +75,6 @@ def _make_runner():
runner._service_tier = None
runner._provider_routing = {}
runner._fallback_model = None
runner._smart_model_routing = {}
runner._running_agents = {}
runner._pending_model_notes = {}
runner._session_db = None

View file

@ -0,0 +1,79 @@
"""Discord adapter race polish: concurrent join_voice_channel must not
double-invoke channel.connect() on the same guild."""
import asyncio
from unittest.mock import MagicMock, patch
import pytest
from gateway.config import Platform, PlatformConfig
def _make_adapter():
    """Build a DiscordAdapter without running __init__ (skips client setup)."""
    from gateway.platforms.discord import DiscordAdapter

    adapter = object.__new__(DiscordAdapter)
    adapter._platform = Platform.DISCORD
    adapter.config = PlatformConfig(enabled=True, token="t")
    adapter._ready_event = asyncio.Event()
    adapter._allowed_user_ids = set()
    adapter._allowed_role_ids = set()
    # Per-guild voice bookkeeping starts empty; each attribute gets its
    # own fresh dict.
    for attr in (
        "_voice_clients",
        "_voice_locks",
        "_voice_receivers",
        "_voice_listen_tasks",
        "_voice_timeout_tasks",
        "_voice_text_channels",
        "_voice_sources",
    ):
        setattr(adapter, attr, {})
    adapter._client = MagicMock()
    return adapter
@pytest.mark.asyncio
async def test_concurrent_joins_do_not_double_connect():
    """Two concurrent join_voice_channel calls on the same guild must
    serialize through the per-guild lock so only ONE channel.connect()
    actually fires; the second caller sees the _voice_clients entry the
    first just installed."""
    adapter = _make_adapter()
    connect_count = [0]        # mutable cell: counts real connect() invocations
    release = asyncio.Event()  # holds the first connect open to widen the race

    class FakeVC:
        """Stand-in voice client: always 'connected', move_to is a no-op."""

        def __init__(self, channel):
            self.channel = channel

        def is_connected(self):
            return True

        async def move_to(self, _channel):
            return None

    # Free function, not a method — renamed the misleading `self` parameter.
    async def slow_connect(chan):
        connect_count[0] += 1
        await release.wait()
        return FakeVC(chan)

    channel = MagicMock()
    channel.id = 111
    channel.guild.id = 42
    channel.connect = lambda: slow_connect(channel)

    from gateway.platforms import discord as discord_mod
    with patch.object(discord_mod, "VoiceReceiver",
                      MagicMock(return_value=MagicMock(start=lambda: None))):
        with patch.object(discord_mod.asyncio, "ensure_future",
                          lambda _c: asyncio.create_task(asyncio.sleep(0))):
            t1 = asyncio.create_task(adapter.join_voice_channel(channel))
            t2 = asyncio.create_task(adapter.join_voice_channel(channel))
            await asyncio.sleep(0.05)
            release.set()
            r1, r2 = await asyncio.gather(t1, t2)

    # Bug fix: the two message fragments used to concatenate with no
    # separator ("...got 1per-guild lock...").
    assert connect_count[0] == 1, (
        f"expected 1 channel.connect() call, got {connect_count[0]} — "
        "per-guild lock is not serializing join_voice_channel"
    )
    assert r1 is True and r2 is True
    assert 42 in adapter._voice_clients

View file

@ -283,6 +283,48 @@ def test_persist_dm_topic_thread_id_skips_if_already_set(tmp_path):
# ── _get_dm_topic_info ──
def test_persist_dm_topic_thread_id_preserves_config_on_write_failure(tmp_path):
    """Failed writes should leave the original config.yaml intact."""
    import yaml

    config_data = {
        "platforms": {
            "telegram": {
                "extra": {
                    "dm_topics": [
                        {
                            "chat_id": 111,
                            "topics": [{"name": "General", "icon_color": 123}],
                        }
                    ]
                }
            }
        }
    }
    config_file = tmp_path / ".hermes" / "config.yaml"
    config_file.parent.mkdir(parents=True)
    original_text = yaml.dump(config_data)
    config_file.write_text(original_text, encoding="utf-8")

    adapter = _make_adapter()

    def fail_dump(*args, **kwargs):
        raise RuntimeError("boom")

    # Force the serialization step to blow up mid-persist.
    with patch.object(Path, "home", return_value=tmp_path), \
            patch.dict(os.environ, {"HERMES_HOME": str(tmp_path / ".hermes")}), \
            patch("yaml.dump", side_effect=fail_dump):
        adapter._persist_dm_topic_thread_id(111, "General", 999)

    # Original bytes untouched, and no thread_id was written.
    assert config_file.read_text(encoding="utf-8") == original_text
    reloaded = yaml.safe_load(config_file.read_text(encoding="utf-8"))
    saved_topics = reloaded["platforms"]["telegram"]["extra"]["dm_topics"][0]["topics"]
    assert "thread_id" not in saved_topics[0]
def test_get_dm_topic_info_finds_cached_topic():
"""Should return topic config when thread_id is in cache."""
adapter = _make_adapter([

View file

@ -4,7 +4,7 @@ import sys
import threading
import types
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock
import pytest
import yaml
@ -53,7 +53,6 @@ def _make_runner():
runner._service_tier = None
runner._provider_routing = {}
runner._fallback_model = None
runner._smart_model_routing = {}
runner._running_agents = {}
runner._pending_model_notes = {}
runner._session_db = None
@ -97,13 +96,7 @@ def test_turn_route_injects_priority_processing_without_changing_runtime():
"credential_pool": None,
}
with patch("agent.smart_model_routing.resolve_turn_route", return_value={
"model": "gpt-5.4",
"runtime": dict(runtime_kwargs),
"label": None,
"signature": ("gpt-5.4", "openrouter", "https://openrouter.ai/api/v1", "chat_completions", None, ()),
}):
route = gateway_run.GatewayRunner._resolve_turn_agent_config(runner, "hi", "gpt-5.4", runtime_kwargs)
route = gateway_run.GatewayRunner._resolve_turn_agent_config(runner, "hi", "gpt-5.4", runtime_kwargs)
assert route["runtime"]["provider"] == "openrouter"
assert route["runtime"]["api_mode"] == "chat_completions"
@ -123,13 +116,7 @@ def test_turn_route_skips_priority_processing_for_unsupported_models():
"credential_pool": None,
}
with patch("agent.smart_model_routing.resolve_turn_route", return_value={
"model": "gpt-5.3-codex",
"runtime": dict(runtime_kwargs),
"label": None,
"signature": ("gpt-5.3-codex", "openrouter", "https://openrouter.ai/api/v1", "chat_completions", None, ()),
}):
route = gateway_run.GatewayRunner._resolve_turn_agent_config(runner, "hi", "gpt-5.3-codex", runtime_kwargs)
route = gateway_run.GatewayRunner._resolve_turn_agent_config(runner, "hi", "gpt-5.3-codex", runtime_kwargs)
assert route["request_overrides"] is None

View file

@ -10,6 +10,8 @@ from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch
from gateway.platforms.base import ProcessingOutcome
try:
import lark_oapi
_HAS_LARK_OAPI = True
@ -638,83 +640,54 @@ class TestAdapterBehavior(unittest.TestCase):
)
@patch.dict(os.environ, {}, clear=True)
@unittest.skipUnless(_HAS_LARK_OAPI, "lark-oapi not installed")
def test_add_ack_reaction_uses_ok_emoji(self):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
captured = {}
class _ReactionAPI:
def create(self, request):
captured["request"] = request
return SimpleNamespace(
success=lambda: True,
data=SimpleNamespace(reaction_id="r_typing"),
)
adapter._client = SimpleNamespace(
im=SimpleNamespace(v1=SimpleNamespace(message_reaction=_ReactionAPI()))
)
async def _direct(func, *args, **kwargs):
return func(*args, **kwargs)
with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct):
reaction_id = asyncio.run(adapter._add_ack_reaction("om_msg"))
self.assertEqual(reaction_id, "r_typing")
self.assertEqual(captured["request"].request_body.reaction_type["emoji_type"], "OK")
    @patch.dict(os.environ, {}, clear=True)
    def test_add_ack_reaction_logs_warning_on_failure(self):
        # A failing reaction create must be swallowed (None returned) and
        # surfaced only as a WARNING on the adapter's module logger.
        from gateway.config import PlatformConfig
        from gateway.platforms.feishu import FeishuAdapter
        adapter = FeishuAdapter(PlatformConfig())

        class _ReactionAPI:
            def create(self, request):
                # Simulate an SDK-level failure inside the reaction call.
                raise RuntimeError("boom")

        adapter._client = SimpleNamespace(
            im=SimpleNamespace(v1=SimpleNamespace(message_reaction=_ReactionAPI()))
        )

        async def _direct(func, *args, **kwargs):
            # Run the to_thread payload inline — no worker thread involved.
            return func(*args, **kwargs)

        with (
            patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct),
            self.assertLogs("gateway.platforms.feishu", level="WARNING") as logs,
        ):
            reaction_id = asyncio.run(adapter._add_ack_reaction("om_msg"))
        self.assertIsNone(reaction_id)
        self.assertTrue(
            any("Failed to add ack reaction to om_msg" in entry for entry in logs.output),
            logs.output,
        )
@patch.dict(os.environ, {}, clear=True)
def test_ack_reaction_events_are_ignored_to_avoid_feedback_loops(self):
def test_bot_origin_reactions_are_dropped_to_avoid_feedback_loops(self):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
adapter._loop = object()
for emoji in ("Typing", "CrossMark"):
event = SimpleNamespace(
message_id="om_msg",
operator_type="bot",
reaction_type=SimpleNamespace(emoji_type=emoji),
)
data = SimpleNamespace(event=event)
with patch(
"gateway.platforms.feishu.asyncio.run_coroutine_threadsafe"
) as run_threadsafe:
adapter._on_reaction_event("im.message.reaction.created_v1", data)
run_threadsafe.assert_not_called()
@patch.dict(os.environ, {}, clear=True)
def test_user_reaction_with_managed_emoji_is_still_routed(self):
# Operator-origin filter is enough to prevent feedback loops; we must
# not additionally swallow user-origin reactions just because their
# emoji happens to collide with a lifecycle emoji.
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
adapter._loop = SimpleNamespace(is_closed=lambda: False)
event = SimpleNamespace(
message_id="om_msg",
operator_type="user",
reaction_type=SimpleNamespace(emoji_type="OK"),
reaction_type=SimpleNamespace(emoji_type="Typing"),
)
data = SimpleNamespace(event=event)
with patch("gateway.platforms.feishu.asyncio.run_coroutine_threadsafe") as run_threadsafe:
adapter._on_reaction_event("im.message.reaction.created_v1", data)
def _close_coro_and_return_future(coro, _loop):
coro.close()
return SimpleNamespace(add_done_callback=lambda _: None)
run_threadsafe.assert_not_called()
with patch(
"gateway.platforms.feishu.asyncio.run_coroutine_threadsafe",
side_effect=_close_coro_and_return_future,
) as run_threadsafe:
adapter._on_reaction_event("im.message.reaction.created_v1", data)
run_threadsafe.assert_called_once()
@patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "open"}, clear=True)
def test_group_message_requires_mentions_even_when_policy_open(self):
@ -743,6 +716,57 @@ class TestAdapterBehavior(unittest.TestCase):
self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[other_mention]), sender_id, ""))
@patch.dict(
os.environ,
{
"FEISHU_BOT_OPEN_ID": "ou_hermes",
"FEISHU_BOT_USER_ID": "u_hermes",
},
clear=True,
)
def test_other_bot_sender_is_not_treated_as_self_sent_message(self):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
event = SimpleNamespace(
sender=SimpleNamespace(
sender_type="bot",
sender_id=SimpleNamespace(open_id="ou_other_bot", user_id="u_other_bot"),
)
)
self.assertFalse(adapter._is_self_sent_bot_message(event))
@patch.dict(
os.environ,
{
"FEISHU_BOT_OPEN_ID": "ou_hermes",
"FEISHU_BOT_USER_ID": "u_hermes",
},
clear=True,
)
def test_self_bot_sender_is_treated_as_self_sent_message(self):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
by_open_id = SimpleNamespace(
sender=SimpleNamespace(
sender_type="bot",
sender_id=SimpleNamespace(open_id="ou_hermes", user_id="u_other"),
)
)
by_user_id = SimpleNamespace(
sender=SimpleNamespace(
sender_type="app",
sender_id=SimpleNamespace(open_id="ou_other", user_id="u_hermes"),
)
)
self.assertTrue(adapter._is_self_sent_bot_message(by_open_id))
self.assertTrue(adapter._is_self_sent_bot_message(by_user_id))
@patch.dict(
os.environ,
{
@ -2370,6 +2394,134 @@ class TestAdapterBehavior(unittest.TestCase):
elements = payload["zh_cn"]["content"][0]
self.assertEqual(elements, [{"tag": "md", "text": "可以用 **粗体** 和 *斜体*。"}])
    @patch.dict(os.environ, {}, clear=True)
    def test_send_splits_fenced_code_blocks_into_separate_post_rows(self):
        # send() with a fenced ```json block must emit a "post" message whose
        # zh_cn content has the prose before the fence, the fence itself, and
        # the prose after it as three separate rows.
        from gateway.config import PlatformConfig
        from gateway.platforms.feishu import FeishuAdapter
        adapter = FeishuAdapter(PlatformConfig())
        captured = {}

        class _MessageAPI:
            def create(self, request):
                # Record the outbound request so its body can be asserted on.
                captured["request"] = request
                return SimpleNamespace(
                    success=lambda: True,
                    data=SimpleNamespace(message_id="om_codeblock"),
                )

        adapter._client = SimpleNamespace(
            im=SimpleNamespace(
                v1=SimpleNamespace(
                    message=_MessageAPI(),
                )
            )
        )

        async def _direct(func, *args, **kwargs):
            # Run the to_thread payload inline on the test thread.
            return func(*args, **kwargs)

        content = (
            "确认已入库 ✓\n"
            "文件路径:`/root/.hermes/profiles/agent_cto/cron/jobs.json`\n"
            "**解码后的内容:**\n"
            "```json\n"
            '{"cron": "list"}\n'
            "```\n"
            "后续说明仍应保留。"
        )
        with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct):
            result = asyncio.run(
                adapter.send(
                    chat_id="oc_chat",
                    content=content,
                )
            )
        self.assertTrue(result.success)
        self.assertEqual(captured["request"].request_body.msg_type, "post")
        payload = json.loads(captured["request"].request_body.content)
        rows = payload["zh_cn"]["content"]
        # Three rows: leading prose (joined), the fence verbatim, trailing prose.
        self.assertEqual(
            rows,
            [
                [
                    {
                        "tag": "md",
                        "text": "确认已入库 ✓\n文件路径:`/root/.hermes/profiles/agent_cto/cron/jobs.json`\n**解码后的内容:**",
                    }
                ],
                [{"tag": "md", "text": "```json\n{\"cron\": \"list\"}\n```"}],
                [{"tag": "md", "text": "后续说明仍应保留。"}],
            ],
        )
@patch.dict(os.environ, {}, clear=True)
def test_build_post_payload_keeps_fence_like_code_lines_inside_code_block(self):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
payload = json.loads(
adapter._build_post_payload(
"before\n```python\n```oops\n```\nafter"
)
)
self.assertEqual(
payload["zh_cn"]["content"],
[
[{"tag": "md", "text": "before"}],
[{"tag": "md", "text": "```python\n```oops\n```"}],
[{"tag": "md", "text": "after"}],
],
)
@patch.dict(os.environ, {}, clear=True)
def test_build_post_payload_preserves_trailing_spaces_in_code_block(self):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
payload = json.loads(
adapter._build_post_payload(
"before\n```python\nline with two spaces \n```\nafter"
)
)
self.assertEqual(
payload["zh_cn"]["content"],
[
[{"tag": "md", "text": "before"}],
[{"tag": "md", "text": "```python\nline with two spaces \n```"}],
[{"tag": "md", "text": "after"}],
],
)
@patch.dict(os.environ, {}, clear=True)
def test_build_post_payload_splits_multiple_fenced_code_blocks(self):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
payload = json.loads(
adapter._build_post_payload(
"before\n```python\nprint(1)\n```\nmiddle\n```json\n{}\n```\nafter"
)
)
self.assertEqual(
payload["zh_cn"]["content"],
[
[{"tag": "md", "text": "before"}],
[{"tag": "md", "text": "```python\nprint(1)\n```"}],
[{"tag": "md", "text": "middle"}],
[{"tag": "md", "text": "```json\n{}\n```"}],
[{"tag": "md", "text": "after"}],
],
)
@patch.dict(os.environ, {}, clear=True)
def test_send_falls_back_to_text_when_post_payload_is_rejected(self):
from gateway.config import PlatformConfig
@ -2505,6 +2657,135 @@ class TestAdapterBehavior(unittest.TestCase):
)
@unittest.skipUnless(_HAS_LARK_OAPI, "lark-oapi not installed")
class TestHydrateBotIdentity(unittest.TestCase):
    """Hydration of bot identity via /open-apis/bot/v3/info and application info.
    Covers the manual-setup path where FEISHU_BOT_OPEN_ID / FEISHU_BOT_USER_ID
    are not configured. Hydration must populate _bot_open_id so that
    _is_self_sent_bot_message() can filter the adapter's own outbound echoes.
    """

    def _make_adapter(self):
        # Fresh adapter per test; env state is controlled by @patch.dict.
        from gateway.config import PlatformConfig
        from gateway.platforms.feishu import FeishuAdapter
        return FeishuAdapter(PlatformConfig())

    @patch.dict(os.environ, {}, clear=True)
    def test_hydration_populates_open_id_from_bot_info(self):
        adapter = self._make_adapter()
        adapter._client = Mock()
        # Raw JSON body the bot-info probe returns via client.request().
        payload = json.dumps(
            {
                "code": 0,
                "bot": {
                    "bot_name": "Hermes Bot",
                    "open_id": "ou_hermes_hydrated",
                },
            }
        ).encode("utf-8")
        response = SimpleNamespace(content=payload)
        adapter._client.request = Mock(return_value=response)
        asyncio.run(adapter._hydrate_bot_identity())
        self.assertEqual(adapter._bot_open_id, "ou_hermes_hydrated")
        self.assertEqual(adapter._bot_name, "Hermes Bot")
        # Application-info fallback must NOT run when bot_name is already set.
        self.assertFalse(
            adapter._client.application.v6.application.get.called
            if hasattr(adapter._client, "application") else False
        )

    @patch.dict(
        os.environ,
        {
            "FEISHU_BOT_OPEN_ID": "ou_env",
            "FEISHU_BOT_NAME": "Env Hermes",
        },
        clear=True,
    )
    def test_hydration_skipped_when_env_vars_supply_both_fields(self):
        adapter = self._make_adapter()
        adapter._client = Mock()
        adapter._client.request = Mock()
        asyncio.run(adapter._hydrate_bot_identity())
        # Neither probe should run — both fields are already populated.
        adapter._client.request.assert_not_called()
        self.assertEqual(adapter._bot_open_id, "ou_env")
        self.assertEqual(adapter._bot_name, "Env Hermes")

    @patch.dict(os.environ, {"FEISHU_BOT_OPEN_ID": "ou_env"}, clear=True)
    def test_hydration_fills_only_missing_fields(self):
        """Env-var open_id must NOT be overwritten by a different probe value."""
        adapter = self._make_adapter()
        adapter._client = Mock()
        # Probe deliberately reports a DIFFERENT open_id than the env var.
        payload = json.dumps(
            {
                "code": 0,
                "bot": {
                    "bot_name": "Hermes Bot",
                    "open_id": "ou_probe_DIFFERENT",
                },
            }
        ).encode("utf-8")
        adapter._client.request = Mock(return_value=SimpleNamespace(content=payload))
        asyncio.run(adapter._hydrate_bot_identity())
        self.assertEqual(adapter._bot_open_id, "ou_env")  # preserved
        self.assertEqual(adapter._bot_name, "Hermes Bot")  # filled in

    @patch.dict(os.environ, {}, clear=True)
    def test_hydration_tolerates_probe_failure_and_falls_back_to_app_info(self):
        adapter = self._make_adapter()
        adapter._client = Mock()
        adapter._client.request = Mock(side_effect=RuntimeError("network down"))
        # Make the application-info fallback succeed for _bot_name.
        app_response = Mock()
        app_response.success = Mock(return_value=True)
        app_response.data = SimpleNamespace(app=SimpleNamespace(app_name="Fallback Bot"))
        adapter._client.application.v6.application.get = Mock(return_value=app_response)
        adapter._build_get_application_request = Mock(return_value=object())
        asyncio.run(adapter._hydrate_bot_identity())
        # Primary probe failed — open_id stays empty, but bot_name came from app-info.
        self.assertEqual(adapter._bot_open_id, "")
        self.assertEqual(adapter._bot_name, "Fallback Bot")

    @patch.dict(os.environ, {}, clear=True)
    def test_hydrated_open_id_enables_self_send_filter(self):
        """E2E: after hydration, _is_self_sent_bot_message() rejects adapter's own id."""
        adapter = self._make_adapter()
        adapter._client = Mock()
        payload = json.dumps(
            {"code": 0, "bot": {"bot_name": "Hermes", "open_id": "ou_hermes"}}
        ).encode("utf-8")
        adapter._client.request = Mock(return_value=SimpleNamespace(content=payload))
        asyncio.run(adapter._hydrate_bot_identity())
        # Our own open_id is filtered; a peer bot's is not.
        self_event = SimpleNamespace(
            sender=SimpleNamespace(
                sender_type="bot",
                sender_id=SimpleNamespace(open_id="ou_hermes", user_id=""),
            )
        )
        peer_event = SimpleNamespace(
            sender=SimpleNamespace(
                sender_type="bot",
                sender_id=SimpleNamespace(open_id="ou_peer_bot", user_id=""),
            )
        )
        self.assertTrue(adapter._is_self_sent_bot_message(self_event))
        self.assertFalse(adapter._is_self_sent_bot_message(peer_event))
@unittest.skipUnless(_HAS_LARK_OAPI, "lark-oapi not installed")
class TestPendingInboundQueue(unittest.TestCase):
"""Tests for the loop-not-ready race (#5499): inbound events arriving
@ -2970,3 +3251,231 @@ class TestSenderNameResolution(unittest.TestCase):
result = asyncio.run(adapter._resolve_sender_name_from_api("ou_broken"))
self.assertIsNone(result)
@unittest.skipUnless(_HAS_LARK_OAPI, "lark-oapi not installed")
class TestProcessingReactions(unittest.TestCase):
    """Typing on start → removed on SUCCESS, swapped for CrossMark on FAILURE,
    removed (no replacement) on CANCELLED."""

    @staticmethod
    def _run(coro):
        # Drive one coroutine to completion on a fresh event loop.
        return asyncio.run(coro)

    def _build_adapter(
        self,
        create_success: bool = True,
        delete_success: bool = True,
        next_reaction_id: str = "r1",
    ):
        # Build an adapter whose reaction API is a stub that records every
        # create/delete into `tracker` and succeeds/fails per the flags.
        from gateway.config import PlatformConfig
        from gateway.platforms.feishu import FeishuAdapter
        adapter = FeishuAdapter(PlatformConfig())
        tracker = SimpleNamespace(
            create_calls=[],
            delete_calls=[],
            next_reaction_id=next_reaction_id,
            create_success=create_success,
            delete_success=delete_success,
        )

        def _create(request):
            # Record which emoji was requested (e.g. "Typing", "CrossMark").
            tracker.create_calls.append(
                request.request_body.reaction_type["emoji_type"]
            )
            if tracker.create_success:
                return SimpleNamespace(
                    success=lambda: True,
                    data=SimpleNamespace(reaction_id=tracker.next_reaction_id),
                )
            return SimpleNamespace(
                success=lambda: False, code=99, msg="rejected", data=None,
            )

        def _delete(request):
            tracker.delete_calls.append(request.reaction_id)
            return SimpleNamespace(
                success=lambda: tracker.delete_success,
                code=0 if tracker.delete_success else 99,
                msg="success" if tracker.delete_success else "rejected",
            )

        adapter._client = SimpleNamespace(
            im=SimpleNamespace(
                v1=SimpleNamespace(
                    message_reaction=SimpleNamespace(create=_create, delete=_delete),
                ),
            ),
        )
        return adapter, tracker

    @staticmethod
    def _event(message_id: str = "om_msg"):
        # Minimal inbound-event stand-in; only message_id is read by the hooks.
        return SimpleNamespace(message_id=message_id)

    def _patch_to_thread(self):
        # Run asyncio.to_thread payloads inline so the stub API is hit
        # synchronously on the test thread.
        async def _direct(func, *args, **kwargs):
            return func(*args, **kwargs)
        return patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct)

    # ------------------------------------------------------------------ start
    @patch.dict(os.environ, {}, clear=True)
    def test_start_adds_typing_and_caches_reaction_id(self):
        adapter, tracker = self._build_adapter(next_reaction_id="r_typing")
        with self._patch_to_thread():
            self._run(adapter.on_processing_start(self._event()))
        self.assertEqual(tracker.create_calls, ["Typing"])
        self.assertEqual(adapter._pending_processing_reactions["om_msg"], "r_typing")

    @patch.dict(os.environ, {}, clear=True)
    def test_start_is_idempotent_for_same_message_id(self):
        adapter, tracker = self._build_adapter(next_reaction_id="r_typing")
        with self._patch_to_thread():
            self._run(adapter.on_processing_start(self._event()))
            self._run(adapter.on_processing_start(self._event()))
        # Second start for the same message_id must not create a second badge.
        self.assertEqual(tracker.create_calls, ["Typing"])

    @patch.dict(os.environ, {}, clear=True)
    def test_start_does_not_cache_when_create_fails(self):
        adapter, tracker = self._build_adapter(create_success=False)
        with self._patch_to_thread():
            self._run(adapter.on_processing_start(self._event()))
        self.assertEqual(tracker.create_calls, ["Typing"])
        # No reaction_id to clean up later — nothing cached.
        self.assertNotIn("om_msg", adapter._pending_processing_reactions)

    # --------------------------------------------------------------- complete
    @patch.dict(os.environ, {}, clear=True)
    def test_success_removes_typing_and_adds_nothing(self):
        adapter, tracker = self._build_adapter(next_reaction_id="r_typing")
        with self._patch_to_thread():
            self._run(adapter.on_processing_start(self._event()))
            self._run(
                adapter.on_processing_complete(self._event(), ProcessingOutcome.SUCCESS)
            )
        self.assertEqual(tracker.create_calls, ["Typing"])
        self.assertEqual(tracker.delete_calls, ["r_typing"])
        self.assertNotIn("om_msg", adapter._pending_processing_reactions)

    @patch.dict(os.environ, {}, clear=True)
    def test_failure_removes_typing_then_adds_cross_mark(self):
        adapter, tracker = self._build_adapter(next_reaction_id="r_typing")
        with self._patch_to_thread():
            self._run(adapter.on_processing_start(self._event()))
            self._run(
                adapter.on_processing_complete(self._event(), ProcessingOutcome.FAILURE)
            )
        self.assertEqual(tracker.create_calls, ["Typing", "CrossMark"])
        self.assertEqual(tracker.delete_calls, ["r_typing"])

    @patch.dict(os.environ, {}, clear=True)
    def test_cancelled_removes_typing_and_adds_nothing(self):
        adapter, tracker = self._build_adapter(next_reaction_id="r_typing")
        with self._patch_to_thread():
            self._run(adapter.on_processing_start(self._event()))
            self._run(
                adapter.on_processing_complete(self._event(), ProcessingOutcome.CANCELLED)
            )
        self.assertEqual(tracker.create_calls, ["Typing"])
        self.assertEqual(tracker.delete_calls, ["r_typing"])
        self.assertNotIn("om_msg", adapter._pending_processing_reactions)

    @patch.dict(os.environ, {}, clear=True)
    def test_failure_without_preceding_start_still_adds_cross_mark(self):
        adapter, tracker = self._build_adapter()
        with self._patch_to_thread():
            self._run(
                adapter.on_processing_complete(self._event(), ProcessingOutcome.FAILURE)
            )
        self.assertEqual(tracker.create_calls, ["CrossMark"])
        self.assertEqual(tracker.delete_calls, [])

    @patch.dict(os.environ, {}, clear=True)
    def test_success_without_preceding_start_is_full_noop(self):
        adapter, tracker = self._build_adapter()
        with self._patch_to_thread():
            self._run(
                adapter.on_processing_complete(self._event(), ProcessingOutcome.SUCCESS)
            )
        self.assertEqual(tracker.create_calls, [])
        self.assertEqual(tracker.delete_calls, [])

    # ------------------------- delete failure: don't stack badges -----------
    @patch.dict(os.environ, {}, clear=True)
    def test_delete_failure_on_failure_outcome_skips_cross_mark(self):
        # Removing Typing is best-effort — but if it fails, we must NOT
        # additionally add CrossMark, or the UI would show two contradictory
        # badges. The handle stays in the cache for LRU to clean up later.
        adapter, tracker = self._build_adapter(
            next_reaction_id="r_typing", delete_success=False,
        )
        with self._patch_to_thread():
            self._run(adapter.on_processing_start(self._event()))
            self._run(
                adapter.on_processing_complete(self._event(), ProcessingOutcome.FAILURE)
            )
        self.assertEqual(tracker.create_calls, ["Typing"])  # CrossMark NOT added
        self.assertEqual(tracker.delete_calls, ["r_typing"])  # delete was attempted
        self.assertEqual(
            adapter._pending_processing_reactions["om_msg"], "r_typing",
        )  # handle retained

    @patch.dict(os.environ, {}, clear=True)
    def test_delete_failure_on_success_outcome_retains_handle(self):
        adapter, tracker = self._build_adapter(
            next_reaction_id="r_typing", delete_success=False,
        )
        with self._patch_to_thread():
            self._run(adapter.on_processing_start(self._event()))
            self._run(
                adapter.on_processing_complete(self._event(), ProcessingOutcome.SUCCESS)
            )
        self.assertEqual(tracker.create_calls, ["Typing"])
        self.assertEqual(tracker.delete_calls, ["r_typing"])
        self.assertEqual(
            adapter._pending_processing_reactions["om_msg"], "r_typing",
        )

    # ------------------------------------------------------------- env toggle
    @patch.dict(os.environ, {"FEISHU_REACTIONS": "false"}, clear=True)
    def test_env_disable_short_circuits_both_hooks(self):
        adapter, tracker = self._build_adapter()
        with self._patch_to_thread():
            self._run(adapter.on_processing_start(self._event()))
            self._run(
                adapter.on_processing_complete(self._event(), ProcessingOutcome.FAILURE)
            )
        # With reactions disabled, no API traffic at all.
        self.assertEqual(tracker.create_calls, [])
        self.assertEqual(tracker.delete_calls, [])

    # ------------------------------------------------------------- LRU bounds
    @patch.dict(os.environ, {}, clear=True)
    def test_cache_evicts_oldest_entry_beyond_size_limit(self):
        from gateway.platforms.feishu import _FEISHU_PROCESSING_REACTION_CACHE_SIZE
        adapter, _ = self._build_adapter()
        counter = {"n": 0}

        def _create(_request):
            # Hand out a unique reaction_id per start call.
            counter["n"] += 1
            return SimpleNamespace(
                success=lambda: True,
                data=SimpleNamespace(reaction_id=f"r{counter['n']}"),
            )

        adapter._client.im.v1.message_reaction.create = _create
        with self._patch_to_thread():
            for i in range(_FEISHU_PROCESSING_REACTION_CACHE_SIZE + 1):
                self._run(adapter.on_processing_start(self._event(f"om_{i}")))
        # Oldest entry evicted; newest retained; size stays at the cap.
        self.assertNotIn("om_0", adapter._pending_processing_reactions)
        self.assertIn(
            f"om_{_FEISHU_PROCESSING_REACTION_CACHE_SIZE}",
            adapter._pending_processing_reactions,
        )
        self.assertEqual(
            len(adapter._pending_processing_reactions),
            _FEISHU_PROCESSING_REACTION_CACHE_SIZE,
        )

View file

@ -355,8 +355,17 @@ async def test_none_user_id_does_not_generate_pairing_code(monkeypatch, tmp_path
async def test_non_internal_event_without_user_triggers_pairing(monkeypatch, tmp_path):
"""Verify the normal (non-internal) path still triggers pairing for unknown users."""
import gateway.run as gateway_run
import gateway.pairing as pairing_mod
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
# gateway.pairing.PAIRING_DIR is a module-level constant captured at
# import time from whichever HERMES_HOME was set then. Per-test
# HERMES_HOME redirection in conftest doesn't retroactively move it.
# Override directly so pairing rate-limit state lives in this test's
# tmp_path (and so stale state from prior xdist workers can't leak in).
pairing_dir = tmp_path / "pairing"
pairing_dir.mkdir()
monkeypatch.setattr(pairing_mod, "PAIRING_DIR", pairing_dir)
(tmp_path / "config.yaml").write_text("", encoding="utf-8")
# Clear env vars that could let all users through (loaded by

View file

@ -1,13 +1,18 @@
"""Tests for the pending_event None guard in recursive _run_agent calls.
"""Tests for pending follow-up extraction in recursive _run_agent calls.
When pending_event is None (Path B: pending comes from interrupt_message),
accessing pending_event.channel_prompt previously raised AttributeError.
This verifies the fix: channel_prompt is captured inside the
`if pending_event is not None:` block and falls back to None otherwise.
Also verifies that internal control interrupt reasons like "Stop requested"
do not get recycled into the pending-user-message follow-up path.
"""
from types import SimpleNamespace
from gateway.run import _is_control_interrupt_message
def _extract_channel_prompt(pending_event):
"""Reproduce the fixed logic from gateway/run.py.
@ -21,6 +26,15 @@ def _extract_channel_prompt(pending_event):
return next_channel_prompt
def _extract_pending_text(interrupted, pending_event, interrupt_message):
    """Reproduce the fixed pending-text selection from gateway/run.py."""
    # Only an interrupt with no pending event and a non-empty message is
    # eligible to be recycled as follow-up user input.
    eligible = interrupted and pending_event is None and bool(interrupt_message)
    if not eligible:
        return None
    # Internal control reasons ("Stop requested", ...) are never user input.
    if _is_control_interrupt_message(interrupt_message):
        return None
    return interrupt_message
class TestPendingEventNoneChannelPrompt:
"""Guard against AttributeError when pending_event is None."""
@ -40,3 +54,19 @@ class TestPendingEventNoneChannelPrompt:
event = SimpleNamespace()
result = _extract_channel_prompt(event)
assert result is None
class TestControlInterruptMessages:
    """Control interrupt reasons must not become follow-up user input."""

    def test_stop_requested_is_not_treated_as_pending_user_message(self):
        assert _extract_pending_text(True, None, "Stop requested") is None

    def test_session_reset_requested_is_not_treated_as_pending_user_message(self):
        assert _extract_pending_text(True, None, "Session reset requested") is None

    def test_real_user_interrupt_message_still_requeues(self):
        requeued = _extract_pending_text(True, None, "actually use postgres instead")
        assert requeued == "actually use postgres instead"

View file

@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import Platform, StreamingConfig
from gateway.platforms.base import resolve_proxy_url
from gateway.run import GatewayRunner
from gateway.session import SessionSource
@ -19,6 +20,7 @@ def _make_runner(proxy_url=None):
runner.config = MagicMock()
runner.config.streaming = StreamingConfig()
runner._running_agents = {}
runner._session_run_generation = {}
runner._session_model_overrides = {}
runner._agent_cache = {}
runner._agent_cache_lock = None
@ -132,6 +134,15 @@ class TestGetProxyUrl:
assert runner._get_proxy_url() is None
class TestResolveProxyUrl:
    def test_normalizes_socks_alias_from_all_proxy(self, monkeypatch):
        """A bare socks:// scheme taken from ALL_PROXY comes back as socks5://."""
        proxy_vars = (
            "HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY",
            "https_proxy", "http_proxy", "all_proxy",
        )
        # Clear every proxy variable first so only our ALL_PROXY is visible.
        for name in proxy_vars:
            monkeypatch.delenv(name, raising=False)
        monkeypatch.setenv("ALL_PROXY", "socks://127.0.0.1:1080/")
        assert resolve_proxy_url() == "socks5://127.0.0.1:1080/"
class TestRunAgentProxyDispatch:
"""Test that _run_agent() delegates to proxy when configured."""
@ -160,10 +171,12 @@ class TestRunAgentProxyDispatch:
source=source,
session_id="test-session-123",
session_key="test-key",
run_generation=7,
)
assert result["final_response"] == "Hello from remote!"
runner._run_agent_via_proxy.assert_called_once()
assert runner._run_agent_via_proxy.call_args.kwargs["run_generation"] == 7
@pytest.mark.asyncio
async def test_run_agent_skips_proxy_when_not_configured(self, monkeypatch):
@ -370,6 +383,40 @@ class TestRunAgentViaProxy:
assert "session_id" in result
assert result["session_id"] == "sess-123"
    @pytest.mark.asyncio
    async def test_proxy_stale_generation_returns_empty_result(self, monkeypatch):
        # A proxy run whose run_generation (1) is older than the session's
        # recorded generation (2) must discard the streamed output and return
        # an empty result instead of the "stale" content.
        monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642")
        monkeypatch.delenv("GATEWAY_PROXY_KEY", raising=False)
        runner = _make_runner()
        source = _make_source()
        runner._session_run_generation["test-key"] = 2
        resp = _FakeSSEResponse(
            status=200,
            sse_chunks=[
                'data: {"choices":[{"delta":{"content":"stale"}}]}\n\n',
                "data: [DONE]\n\n",
            ],
        )
        session = _FakeSession(resp)
        with patch("gateway.run._load_gateway_config", return_value={}):
            with _patch_aiohttp(session):
                with patch("aiohttp.ClientTimeout"):
                    result = await runner._run_agent_via_proxy(
                        message="hi",
                        context_prompt="",
                        history=[],
                        source=source,
                        session_id="sess-123",
                        session_key="test-key",
                        run_generation=1,  # stale: session is already at 2
                    )
        # The "stale" delta streamed above must not leak into the result.
        assert result["final_response"] == ""
        assert result["messages"] == []
        assert result["api_calls"] == 0
@pytest.mark.asyncio
async def test_no_auth_header_without_key(self, monkeypatch):
monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642")

View file

@ -0,0 +1,159 @@
"""Tests for reply-to pointer injection in _prepare_inbound_message_text.
The `[Replying to: "..."]` prefix is a *disambiguation pointer*, not
deduplication. It must always be injected when the user explicitly replies
to a prior message even when the quoted text already exists somewhere
in the conversation history. History can contain the same or similar text
multiple times, and without an explicit pointer the agent has to guess
which prior message the user is referencing.
"""
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent
from gateway.run import GatewayRunner
from gateway.session import SessionSource
def _make_runner() -> GatewayRunner:
    """Build a bare GatewayRunner (bypassing __init__) with just enough
    state for _prepare_inbound_message_text."""
    runner = object.__new__(GatewayRunner)
    telegram_cfg = PlatformConfig(enabled=True, token="fake")
    runner.config = GatewayConfig(platforms={Platform.TELEGRAM: telegram_cfg})
    runner.adapters = {}
    runner._model = "openai/gpt-4.1-mini"
    runner._base_url = None
    return runner
def _source() -> SessionSource:
    """A private Telegram DM source shared by the tests in this module."""
    return SessionSource(
        platform=Platform.TELEGRAM,
        chat_type="private",
        chat_id="123",
        chat_name="DM",
        user_name="Alice",
    )
@pytest.mark.asyncio
async def test_reply_prefix_injected_when_text_absent_from_history():
    """The pointer prefix is injected when history never saw the quoted text."""
    runner = _make_runner()
    source = _source()
    quoted = "Japan is great for culture, food, and efficiency."
    event = MessageEvent(
        text="What's the best time to go?",
        source=source,
        reply_to_message_id="42",
        reply_to_text=quoted,
    )
    prepared = await runner._prepare_inbound_message_text(
        event=event,
        source=source,
        history=[{"role": "user", "content": "unrelated"}],
    )
    assert prepared is not None
    assert prepared.startswith(f'[Replying to: "{quoted}"]')
    assert prepared.endswith("What's the best time to go?")
@pytest.mark.asyncio
async def test_reply_prefix_still_injected_when_text_in_history():
    """Regression test: the pointer must survive even when the quoted text
    already appears in history. Previously a `found_in_history` guard
    silently dropped the prefix, leaving the agent to guess which prior
    message the user was referencing."""
    runner = _make_runner()
    source = _source()
    quoted = "Japan is great for culture, food, and efficiency."
    event = MessageEvent(
        text="What's the best time to go?",
        source=source,
        reply_to_message_id="42",
        reply_to_text=quoted,
    )
    # History deliberately contains `quoted` verbatim inside an assistant turn —
    # exactly the situation that used to suppress the prefix.
    history = [
        {"role": "user", "content": "I'm thinking of going to Japan or Italy."},
        {
            "role": "assistant",
            "content": (
                f"{quoted} Italy is better if you prefer a relaxed pace."
            ),
        },
        {"role": "user", "content": "How long should I stay?"},
        {"role": "assistant", "content": "For Japan, 10-14 days is ideal."},
    ]
    result = await runner._prepare_inbound_message_text(
        event=event,
        source=source,
        history=history,
    )
    assert result is not None
    # Pointer prefix leads; the user's own text closes the message.
    assert result.startswith(f'[Replying to: "{quoted}"]')
    assert result.endswith("What's the best time to go?")
@pytest.mark.asyncio
async def test_no_prefix_without_reply_context():
    """A plain message with no reply_to_* fields passes through untouched."""
    runner = _make_runner()
    src = _source()
    prepared = await runner._prepare_inbound_message_text(
        event=MessageEvent(text="hello", source=src),
        source=src,
        history=[],
    )
    assert prepared == "hello"
@pytest.mark.asyncio
async def test_no_prefix_when_reply_to_text_is_empty():
    """reply_to_message_id alone without text (e.g. a reply to a media-only
    message) should not produce an empty `[Replying to: ""]` prefix."""
    runner = _make_runner()
    src = _source()
    media_reply = MessageEvent(
        text="hi",
        source=src,
        reply_to_message_id="42",
        reply_to_text=None,
    )
    prepared = await runner._prepare_inbound_message_text(
        event=media_reply,
        source=src,
        history=[],
    )
    assert prepared == "hi"
@pytest.mark.asyncio
async def test_reply_snippet_truncated_to_500_chars():
    """Quoted reply snippets are capped at 500 characters inside the prefix."""
    runner = _make_runner()
    src = _source()
    oversized = "x" * 800
    follow_up = MessageEvent(
        text="follow-up",
        source=src,
        reply_to_message_id="42",
        reply_to_text=oversized,
    )
    prepared = await runner._prepare_inbound_message_text(
        event=follow_up,
        source=src,
        history=[],
    )
    assert prepared is not None
    expected_head = '[Replying to: "' + "x" * 500 + '"]'
    assert prepared.startswith(expected_head)
    # Not a single character beyond the 500-char cap may survive.
    assert "x" * 501 not in prepared

View file

@ -1,6 +1,7 @@
import asyncio
import shutil
import subprocess
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
@ -8,7 +9,7 @@ import pytest
import gateway.run as gateway_run
from gateway.platforms.base import MessageEvent, MessageType
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
from gateway.session import build_session_key
from gateway.session import SessionEntry, build_session_key
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
@ -242,3 +243,31 @@ async def test_shutdown_notification_send_failure_does_not_block():
# Should not raise
await runner._notify_active_sessions_of_shutdown()
@pytest.mark.asyncio
async def test_shutdown_notification_uses_persisted_origin_for_colon_ids():
    """Shutdown notifications should route from persisted origin, not reparsed keys."""
    runner, adapter = make_restart_runner()
    adapter.send = AsyncMock()
    # Matrix room ids contain colons, which a naive session-key re-parse
    # would split incorrectly — the persisted origin must win.
    origin = make_restart_source(chat_id="!room123:example.org", chat_type="group")
    origin.platform = gateway_run.Platform.MATRIX
    key = build_session_key(origin)
    runner._running_agents[key] = MagicMock()
    entry = SessionEntry(
        session_key=key,
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        origin=origin,
        platform=origin.platform,
        chat_type=origin.chat_type,
    )
    runner.session_store._entries = {key: entry}
    runner.adapters = {gateway_run.Platform.MATRIX: adapter}
    await runner._notify_active_sessions_of_shutdown()
    assert adapter.send.await_count == 1
    assert adapter.send.await_args.args[0] == "!room123:example.org"

View file

@ -51,6 +51,9 @@ class ProgressCaptureAdapter(BasePlatformAdapter):
async def send_typing(self, chat_id, metadata=None) -> None:
    # Record each typing-start so tests can count indicator activity.
    self.typing.append({"chat_id": chat_id, "metadata": metadata})
async def stop_typing(self, chat_id) -> None:
    # Stop events carry a sentinel metadata dict so tests can tell them
    # apart from typing-start entries appended to the same list.
    self.typing.append({"chat_id": chat_id, "metadata": {"stopped": True}})
async def get_chat_info(self, chat_id: str):
    # Minimal stub payload — just echoes the id back.
    return {"id": chat_id}
@ -90,6 +93,40 @@ class LongPreviewAgent:
}
class DelayedProgressAgent:
    """Fake agent emitting two tool-progress events separated by a delay,
    giving tests a window to invalidate the run generation mid-stream."""

    def __init__(self, **kwargs):
        self.tool_progress_callback = kwargs.get("tool_progress_callback")
        self.tools = []

    def run_conversation(self, message, conversation_history=None, task_id=None):
        emit = self.tool_progress_callback
        emit("tool.started", "terminal", "first command", {})
        time.sleep(0.45)
        emit("tool.started", "terminal", "second command", {})
        time.sleep(0.1)
        return {"final_response": "done", "messages": [], "api_calls": 1}
class DelayedInterimAgent:
def __init__(self, **kwargs):
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
self.tools = []
def run_conversation(self, message, conversation_history=None, task_id=None):
self.interim_assistant_callback("first interim")
time.sleep(0.45)
self.interim_assistant_callback("second interim")
time.sleep(0.1)
return {
"final_response": "done",
"messages": [],
"api_calls": 1,
}
def _make_runner(adapter):
gateway_run = importlib.import_module("gateway.run")
GatewayRunner = gateway_run.GatewayRunner
@ -104,6 +141,7 @@ def _make_runner(adapter):
runner._fallback_model = None
runner._session_db = None
runner._running_agents = {}
runner._session_run_generation = {}
runner.hooks = SimpleNamespace(loaded_hooks=False)
runner.config = SimpleNamespace(
thread_sessions_per_user=False,
@ -744,6 +782,154 @@ async def test_base_processing_releases_post_delivery_callback_after_main_send()
assert released == [True]
@pytest.mark.asyncio
async def test_run_agent_drops_tool_progress_after_generation_invalidation(monkeypatch, tmp_path):
    # Regression: once a session's run generation is invalidated (e.g. by a
    # stop), later tool-progress events from the still-running agent must be
    # dropped instead of reaching the chat adapter.
    import yaml
    # Enable full tool-progress display so progress events are forwarded.
    (tmp_path / "config.yaml").write_text(
        yaml.dump({"display": {"tool_progress": "all"}}),
        encoding="utf-8",
    )
    # Stub dotenv so importing the agent module has no env side effects.
    fake_dotenv = types.ModuleType("dotenv")
    fake_dotenv.load_dotenv = lambda *args, **kwargs: None
    monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
    # Swap the real agent for one that emits two delayed progress events.
    fake_run_agent = types.ModuleType("run_agent")
    fake_run_agent.AIAgent = DelayedProgressAgent
    monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
    import tools.terminal_tool  # noqa: F401 - register terminal tool metadata
    adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
    runner = _make_runner(adapter)
    gateway_run = importlib.import_module("gateway.run")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
    source = SessionSource(
        platform=Platform.DISCORD,
        chat_id="dm-1",
        chat_type="dm",
        thread_id=None,
    )
    session_key = "agent:main:discord:dm:dm-1"
    runner._session_run_generation[session_key] = 1
    original_send = adapter.send
    invalidated = {"done": False}

    # Invalidate the run generation as soon as the first progress message is
    # delivered — the second event then arrives on a stale generation.
    async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None):
        result = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata)
        if "first command" in content and not invalidated["done"]:
            invalidated["done"] = True
            runner._invalidate_session_run_generation(session_key, reason="test_stop")
        return result

    adapter.send = send_and_invalidate
    result = await runner._run_agent(
        message="hello",
        context_prompt="",
        history=[],
        source=source,
        session_id="sess-progress-stop",
        session_key=session_key,
        run_generation=1,
    )
    # Gather everything the adapter sent or edited; only the progress line
    # emitted before invalidation may appear.
    all_progress_text = " ".join(call["content"] for call in adapter.sent)
    all_progress_text += " ".join(call["content"] for call in adapter.edits)
    assert result["final_response"] == "done"
    assert 'first command' in all_progress_text
    assert 'second command' not in all_progress_text
@pytest.mark.asyncio
async def test_run_agent_drops_interim_commentary_after_generation_invalidation(monkeypatch, tmp_path):
    # Regression: interim assistant commentary emitted after the session's
    # run generation was invalidated must be dropped, not delivered.
    import yaml
    # Progress display off, interim commentary on — isolates the interim path.
    (tmp_path / "config.yaml").write_text(
        yaml.dump({"display": {"tool_progress": "off", "interim_assistant_messages": True}}),
        encoding="utf-8",
    )
    # Stub dotenv so importing the agent module has no env side effects.
    fake_dotenv = types.ModuleType("dotenv")
    fake_dotenv.load_dotenv = lambda *args, **kwargs: None
    monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
    # Swap the real agent for one that emits two delayed interim messages.
    fake_run_agent = types.ModuleType("run_agent")
    fake_run_agent.AIAgent = DelayedInterimAgent
    monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
    adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
    runner = _make_runner(adapter)
    gateway_run = importlib.import_module("gateway.run")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
    source = SessionSource(
        platform=Platform.DISCORD,
        chat_id="dm-2",
        chat_type="dm",
        thread_id=None,
    )
    session_key = "agent:main:discord:dm:dm-2"
    runner._session_run_generation[session_key] = 1
    original_send = adapter.send
    invalidated = {"done": False}

    # Invalidate the generation right after the first interim message lands,
    # so the second one arrives on a stale generation.
    async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None):
        result = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata)
        if content == "first interim" and not invalidated["done"]:
            invalidated["done"] = True
            runner._invalidate_session_run_generation(session_key, reason="test_stop")
        return result

    adapter.send = send_and_invalidate
    result = await runner._run_agent(
        message="hello",
        context_prompt="",
        history=[],
        source=source,
        session_id="sess-commentary-stop",
        session_key=session_key,
        run_generation=1,
    )
    sent_texts = [call["content"] for call in adapter.sent]
    assert result["final_response"] == "done"
    assert "first interim" in sent_texts
    assert "second interim" not in sent_texts
@pytest.mark.asyncio
async def test_keep_typing_stops_immediately_when_interrupt_event_is_set():
    """Setting the stop event must end the typing loop long before the 30s
    interval would next fire, emitting exactly one start and one stop."""
    adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
    interrupt = asyncio.Event()
    typing_task = asyncio.create_task(
        adapter._keep_typing(
            "dm-typing-stop",
            interval=30.0,
            stop_event=interrupt,
        )
    )
    await asyncio.sleep(0.05)
    interrupt.set()
    # If the loop honored the interrupt, the task finishes well inside 0.5s.
    await asyncio.wait_for(typing_task, timeout=0.5)
    stop_marker = {"stopped": True}
    started = [c for c in adapter.typing if c.get("metadata") != stop_marker]
    stopped = [c for c in adapter.typing if c.get("metadata") == stop_marker]
    assert len(started) == 1
    assert len(stopped) == 1
@pytest.mark.asyncio
async def test_verbose_mode_does_not_truncate_args_by_default(monkeypatch, tmp_path):
"""Verbose mode with default tool_preview_length (0) should NOT truncate args.

View file

@ -184,8 +184,15 @@ async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_p
async def stop(self):
return None
monkeypatch.setattr("gateway.status.get_running_pid", lambda: 42)
monkeypatch.setattr("gateway.status.remove_pid_file", lambda: None)
# get_running_pid returns 42 before we kill the old gateway, then None
# after remove_pid_file() clears the record (reflects real behavior).
_pid_state = {"alive": True}
def _mock_get_running_pid():
return 42 if _pid_state["alive"] else None
def _mock_remove_pid_file():
_pid_state["alive"] = False
monkeypatch.setattr("gateway.status.get_running_pid", _mock_get_running_pid)
monkeypatch.setattr("gateway.status.remove_pid_file", _mock_remove_pid_file)
monkeypatch.setattr("gateway.status.release_all_scoped_locks", lambda: 0)
monkeypatch.setattr("gateway.status.terminate_pid", lambda pid, force=False: calls.append((pid, force)))
monkeypatch.setattr("gateway.run.os.getpid", lambda: 100)
@ -253,8 +260,13 @@ async def test_start_gateway_replace_writes_takeover_marker_before_sigterm(
async def stop(self):
return None
monkeypatch.setattr("gateway.status.get_running_pid", lambda: 42)
monkeypatch.setattr("gateway.status.remove_pid_file", lambda: None)
_pid_state = {"alive": True}
def _mock_get_running_pid():
return 42 if _pid_state["alive"] else None
def _mock_remove_pid_file():
_pid_state["alive"] = False
monkeypatch.setattr("gateway.status.get_running_pid", _mock_get_running_pid)
monkeypatch.setattr("gateway.status.remove_pid_file", _mock_remove_pid_file)
monkeypatch.setattr("gateway.status.release_all_scoped_locks", lambda: 0)
monkeypatch.setattr("gateway.status.write_takeover_marker", record_write_marker)
monkeypatch.setattr("gateway.status.terminate_pid", record_terminate)
@ -319,3 +331,23 @@ async def test_start_gateway_replace_clears_marker_on_permission_denied(
assert ok is False
# Marker must NOT be left behind
assert not (tmp_path / ".gateway-takeover.json").exists()
def test_runner_warns_when_docker_gateway_lacks_explicit_output_mount(monkeypatch, tmp_path, caplog):
    """A docker terminal env whose volume list has no output mount should
    produce a WARNING when the runner is constructed."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    monkeypatch.setenv("TERMINAL_ENV", "docker")
    # Only a read-only localtime mount — no host-visible output directory.
    monkeypatch.setenv("TERMINAL_DOCKER_VOLUMES", '["/etc/localtime:/etc/localtime:ro"]')
    cfg = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")},
        sessions_dir=tmp_path / "sessions",
    )
    with caplog.at_level("WARNING"):
        GatewayRunner(cfg)
    warned = [r for r in caplog.records if "host-visible output mount" in r.message]
    assert warned

View file

@ -0,0 +1,167 @@
"""Regression tests: /yolo and /verbose dispatch mid-agent-run.
When an agent is running, the gateway's running-agent guard rejects most
slash commands with "⏳ Agent is running — /{cmd} can't run mid-turn"
(PR #12334). A small allowlist bypasses that and actually dispatches:
* /yolo toggles the session yolo flag; useful to pre-approve a
pending approval prompt without waiting for the agent to finish.
* /verbose cycles the per-platform tool-progress display mode;
affects the ongoing stream.
Commands whose handlers say "takes effect on next message" stay on the
catch-all by design:
* /fast writes config.yaml only
* /reasoning writes config.yaml only
These tests lock in both behaviors so the allowlist doesn't silently
grow or shrink.
"""
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent
from gateway.session import SessionEntry, SessionSource, build_session_key
def _make_source() -> SessionSource:
    """Telegram DM source shared by all tests in this module."""
    fields = {
        "platform": Platform.TELEGRAM,
        "user_id": "u1",
        "chat_id": "c1",
        "user_name": "tester",
        "chat_type": "dm",
    }
    return SessionSource(**fields)
def _make_event(text: str) -> MessageEvent:
    """Wrap *text* in a MessageEvent from the canonical test source."""
    src = _make_source()
    return MessageEvent(text=text, source=src, message_id="m1")
def _make_runner():
    """Minimal GatewayRunner with an active running agent for this session."""
    from gateway.run import GatewayRunner
    # Bypass __init__ — only the attributes _handle_message touches are wired.
    runner = object.__new__(GatewayRunner)
    runner.config = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
    )
    adapter = MagicMock()
    adapter.send = AsyncMock()
    runner.adapters = {Platform.TELEGRAM: adapter}
    runner._voice_mode = {}
    runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
    session_entry = SessionEntry(
        session_key=build_session_key(_make_source()),
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )
    # The session store always answers with the same canned entry/transcript.
    runner.session_store = MagicMock()
    runner.session_store.get_or_create_session.return_value = session_entry
    runner.session_store.load_transcript.return_value = []
    runner.session_store.has_any_sessions.return_value = True
    runner.session_store.append_to_transcript = MagicMock()
    runner.session_store.rewrite_transcript = MagicMock()
    runner.session_store.update_session = MagicMock()
    runner._running_agents = {}
    runner._running_agents_ts = {}
    runner._pending_messages = {}
    runner._pending_approvals = {}
    runner._session_db = None
    runner._reasoning_config = None
    runner._provider_routing = {}
    runner._fallback_model = None
    runner._show_reasoning = False
    runner._service_tier = None
    # Stub out authorization and per-session side effects as no-ops.
    runner._is_user_authorized = lambda _source: True
    runner._set_session_env = lambda _context: None
    runner._should_send_voice_reply = lambda *_args, **_kwargs: False
    runner._send_voice_reply = AsyncMock()
    runner._capture_gateway_honcho_if_configured = lambda *args, **kwargs: None
    runner._emit_gateway_run_progress = AsyncMock()
    # Simulate agent actively running for this session so the guard fires.
    # Note: the stale-eviction branch calls agent.get_activity_summary() and
    # compares seconds_since_activity against HERMES_AGENT_TIMEOUT. Return a
    # dict with recent activity so the eviction path doesn't clear our
    # fake running agent before the toggle guard runs.
    import time
    sk = build_session_key(_make_source())
    agent_mock = MagicMock()
    agent_mock.get_activity_summary.return_value = {
        "seconds_since_activity": 0.0,
        "last_activity_desc": "api_call",
        "api_call_count": 1,
        "max_iterations": 60,
    }
    runner._running_agents[sk] = agent_mock
    runner._running_agents_ts[sk] = time.time()
    return runner
@pytest.mark.asyncio
async def test_yolo_dispatches_mid_run(monkeypatch):
    """/yolo mid-run must dispatch to its handler, not hit the catch-all."""
    runner = _make_runner()
    reply = "⚡ YOLO mode **ON** for this session"
    runner._handle_yolo_command = AsyncMock(return_value=reply)
    outcome = await runner._handle_message(_make_event("/yolo"))
    runner._handle_yolo_command.assert_awaited_once()
    assert outcome == reply
    assert "can't run mid-turn" not in (outcome or "")
@pytest.mark.asyncio
async def test_verbose_dispatches_mid_run(monkeypatch):
    """/verbose mid-run must dispatch to its handler, not hit the catch-all."""
    runner = _make_runner()
    reply = "tool progress: new"
    runner._handle_verbose_command = AsyncMock(return_value=reply)
    outcome = await runner._handle_message(_make_event("/verbose"))
    runner._handle_verbose_command.assert_awaited_once()
    assert outcome == reply
    assert "can't run mid-turn" not in (outcome or "")
@pytest.mark.asyncio
async def test_fast_rejected_mid_run():
    """/fast mid-run must hit the busy catch-all — config-only, next message."""
    runner = _make_runner()
    # If the handler were dispatched, the side_effect would fail the test.
    runner._handle_fast_command = AsyncMock(
        side_effect=AssertionError("/fast should not dispatch mid-run")
    )
    outcome = await runner._handle_message(_make_event("/fast"))
    runner._handle_fast_command.assert_not_awaited()
    assert outcome is not None
    for needle in ("can't run mid-turn", "/fast"):
        assert needle in outcome
@pytest.mark.asyncio
async def test_reasoning_rejected_mid_run():
    """/reasoning mid-run must hit the busy catch-all — config-only, next message."""
    runner = _make_runner()
    # If the handler were dispatched, the side_effect would fail the test.
    runner._handle_reasoning_command = AsyncMock(
        side_effect=AssertionError("/reasoning should not dispatch mid-run")
    )
    outcome = await runner._handle_message(_make_event("/reasoning high"))
    runner._handle_reasoning_command.assert_not_awaited()
    assert outcome is not None
    for needle in ("can't run mid-turn", "/reasoning"):
        assert needle in outcome

View file

@ -356,6 +356,28 @@ class TestBuildSessionContextPrompt:
assert "**User:** Alice" in prompt
assert "Multi-user thread" not in prompt
def test_shared_non_thread_group_prompt_hides_single_user(self):
    """Shared non-thread group sessions should avoid pinning one user."""
    config = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake")},
        group_sessions_per_user=False,
    )
    sender = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-1002285219667",
        chat_name="Test Group",
        chat_type="group",
        user_name="Alice",
    )
    rendered = build_session_context_prompt(build_session_context(sender, config))
    # Multi-user framing, sender placeholder present; no pinned single user.
    assert "Multi-user session" in rendered
    assert "[sender name]" in rendered
    assert "**User:** Alice" not in rendered
def test_dm_thread_shows_user_not_multi(self):
"""DM threads are single-user and should show User, not multi-user note."""
config = GatewayConfig(
@ -1037,6 +1059,7 @@ class TestRewriteTranscriptPreservesReasoning:
role="assistant",
content="The answer is 42.",
reasoning="I need to think step by step.",
reasoning_content="provider scratchpad",
reasoning_details=[{"type": "summary", "text": "step by step"}],
codex_reasoning_items=[{"id": "r1", "type": "reasoning"}],
)
@ -1044,6 +1067,7 @@ class TestRewriteTranscriptPreservesReasoning:
# Verify all three were stored
before = db.get_messages_as_conversation(session_id)
assert before[0].get("reasoning") == "I need to think step by step."
assert before[0].get("reasoning_content") == "provider scratchpad"
assert before[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}]
assert before[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}]
@ -1060,5 +1084,6 @@ class TestRewriteTranscriptPreservesReasoning:
# Load again — all three reasoning fields must survive
after = db.get_messages_as_conversation(session_id)
assert after[0].get("reasoning") == "I need to think step by step."
assert after[0].get("reasoning_content") == "provider scratchpad"
assert after[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}]
assert after[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}]

View file

@ -0,0 +1,76 @@
"""Regression tests for the TUI gateway's ``session.list`` handler.
Reported during TUI v2 blitz retest: the ``/resume`` modal inside a TUI
session only surfaced ``tui``/``cli`` rows, hiding telegram sessions users
could still resume directly via ``hermes --tui --resume <id>``.
The fix widens the picker to a curated allowlist of user-facing sources
(tui/cli + chat adapters) while still filtering internal/system sources.
"""
from __future__ import annotations
from tui_gateway import server
class _StubDB:
def __init__(self, rows):
self.rows = rows
self.calls: list[dict] = []
def list_sessions_rich(self, **kwargs):
self.calls.append(kwargs)
return list(self.rows)
def _call(limit: int = 20):
    """Issue a session.list JSON-RPC request and return the raw response."""
    request = {
        "id": "1",
        "method": "session.list",
        "params": {"limit": limit},
    }
    return server.handle_request(request)
def test_session_list_includes_telegram_but_filters_internal_sources(monkeypatch):
    """User-facing sources (tui/cli + chat adapters) survive; internal ones don't."""
    rows = [
        {"id": "tui-1", "source": "tui", "started_at": 9},
        {"id": "tool-1", "source": "tool", "started_at": 8},
        {"id": "tg-1", "source": "telegram", "started_at": 7},
        {"id": "acp-1", "source": "acp", "started_at": 6},
        {"id": "cli-1", "source": "cli", "started_at": 5},
    ]
    db = _StubDB(rows)
    monkeypatch.setattr(server, "_get_db", lambda: db)
    listed = [s["id"] for s in _call(limit=10)["result"]["sessions"]]
    for visible in ("tg-1", "tui-1", "cli-1"):
        assert visible in listed, listed
    for hidden in ("tool-1", "acp-1"):
        assert hidden not in listed, listed
def test_session_list_fetches_wider_window_before_filtering(monkeypatch):
    """The handler over-fetches (limit=100, no source filter) then filters."""
    db = _StubDB([{"id": "x", "source": "cli", "started_at": 1}])
    monkeypatch.setattr(server, "_get_db", lambda: db)
    _call(limit=10)
    assert len(db.calls) == 1
    recorded = db.calls[0]
    assert recorded.get("source") is None, recorded
    assert recorded.get("limit") == 100, recorded
def test_session_list_preserves_ordering_after_filter(monkeypatch):
    """Filtering internal sources must not reorder the remaining rows."""
    rows = [
        {"id": "newest", "source": "telegram", "started_at": 5},
        {"id": "internal", "source": "tool", "started_at": 4},
        {"id": "middle", "source": "tui", "started_at": 3},
        {"id": "oldest", "source": "discord", "started_at": 1},
    ]
    monkeypatch.setattr(server, "_get_db", lambda: _StubDB(rows))
    listed = [s["id"] for s in _call()["result"]["sessions"]]
    assert listed == ["newest", "middle", "oldest"]

View file

@ -24,10 +24,18 @@ class _FakeAdapter:
def __init__(self):
    # Messages queued per session while an agent run is in flight.
    self._pending_messages = {}
    # session_key -> waiter event (presumably asyncio.Event; set by
    # interrupt_session_activity — confirm against the real adapter).
    self._active_sessions = {}
    # (session_key, chat_id) pairs recorded by interrupt_session_activity.
    self.interrupted_sessions = []
async def send(self, chat_id, text, **kwargs):
    # Outbound messages are discarded; these tests don't inspect them.
    pass
async def interrupt_session_activity(self, session_key, chat_id):
    # Record the interrupt and wake any waiter registered for the session.
    self.interrupted_sessions.append((session_key, chat_id))
    event = self._active_sessions.get(session_key)
    if event is not None:
        event.set()
def _make_runner():
runner = object.__new__(GatewayRunner)
@ -37,6 +45,7 @@ def _make_runner():
runner.adapters = {Platform.TELEGRAM: _FakeAdapter()}
runner._running_agents = {}
runner._running_agents_ts = {}
runner._session_run_generation = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._voice_mode = {}
@ -81,7 +90,7 @@ async def test_sentinel_placed_before_agent_setup():
# Patch _handle_message_with_agent to capture state at entry
sentinel_was_set = False
async def mock_inner(self_inner, ev, src, qk):
async def mock_inner(self_inner, ev, src, qk, generation):
nonlocal sentinel_was_set
sentinel_was_set = runner._running_agents.get(qk) is _AGENT_PENDING_SENTINEL
return "ok"
@ -105,7 +114,7 @@ async def test_sentinel_cleaned_up_after_handler_returns():
event = _make_event()
session_key = build_session_key(event.source)
async def mock_inner(self_inner, ev, src, qk):
async def mock_inner(self_inner, ev, src, qk, generation):
return "ok"
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
@ -127,7 +136,7 @@ async def test_sentinel_cleaned_up_on_exception():
event = _make_event()
session_key = build_session_key(event.source)
async def mock_inner(self_inner, ev, src, qk):
async def mock_inner(self_inner, ev, src, qk, generation):
raise RuntimeError("boom")
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
@ -154,7 +163,7 @@ async def test_second_message_during_sentinel_queued_not_duplicate():
barrier = asyncio.Event()
async def slow_inner(self_inner, ev, src, qk):
async def slow_inner(self_inner, ev, src, qk, generation):
# Simulate slow setup — wait until test tells us to proceed
await barrier.wait()
return "ok"
@ -333,7 +342,7 @@ async def test_stop_during_sentinel_force_cleans_session():
barrier = asyncio.Event()
async def slow_inner(self_inner, ev, src, qk):
async def slow_inner(self_inner, ev, src, qk, generation):
await barrier.wait()
return "ok"
@ -381,6 +390,7 @@ async def test_stop_hard_kills_running_agent():
fake_agent = MagicMock()
fake_agent.get_activity_summary.return_value = {"seconds_since_activity": 0}
runner._running_agents[session_key] = fake_agent
runner.adapters[Platform.TELEGRAM]._active_sessions[session_key] = asyncio.Event()
# Send /stop
stop_event = _make_event(text="/stop")
@ -393,6 +403,10 @@ async def test_stop_hard_kills_running_agent():
assert session_key not in runner._running_agents, (
"/stop must remove the agent from _running_agents so the session is unlocked"
)
assert runner.adapters[Platform.TELEGRAM].interrupted_sessions == [
(session_key, "12345")
]
assert runner.adapters[Platform.TELEGRAM]._active_sessions[session_key].is_set()
# Must return a confirmation
assert result is not None

View file

@ -117,11 +117,20 @@ class TestPruneBasics:
assert "idle" not in store._entries
def test_prune_skips_entries_with_active_processes(self, tmp_path):
"""Sessions with active bg processes aren't pruned even if old."""
active_session_ids = {"sid_active"}
"""Sessions with active bg processes aren't pruned even if old.
def _has_active(session_id: str) -> bool:
return session_id in active_session_ids
The callback is keyed by session_key matching what
process_registry.has_active_for_session() actually consumes in
gateway/run.py. Prior to the fix this test passed the callback a
session_id, which silently matched an implementation bug where
prune_old_entries was also passing session_id; real-world usage
(via process_registry) takes a session_key and never matched, so
active sessions were still being pruned.
"""
active_session_keys = {"active"}
def _has_active(session_key: str) -> bool:
return session_key in active_session_keys
store = _make_store(tmp_path, has_active_processes_fn=_has_active)
store._entries["active"] = _entry(
@ -137,6 +146,26 @@ class TestPruneBasics:
assert "active" in store._entries
assert "idle" not in store._entries
def test_prune_active_check_uses_session_key_not_session_id(self, tmp_path):
    """Regression guard: a callback that only recognises session_ids must
    NOT protect entries during prune. This pins the fix so a future
    refactor can't silently revert to passing session_id again.
    """
    store = _make_store(
        tmp_path,
        has_active_processes_fn=lambda identifier: identifier.startswith("sid_"),
    )
    store._entries["active"] = _entry(
        "active", age_days=1000, session_id="sid_active"
    )
    removed = store.prune_old_entries(max_age_days=90)
    # The callback receives "active" (session_key), not "sid_active"
    # (session_id), so it returns False and the stale entry is pruned.
    assert removed == 1
    assert "active" not in store._entries
def test_prune_does_not_write_disk_when_no_removals(self, tmp_path):
"""If nothing is evictable, _save() should NOT be called."""
store = _make_store(tmp_path)

View file

@ -0,0 +1,70 @@
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent
from gateway.run import GatewayRunner
from gateway.session import SessionSource
def _make_runner(config: GatewayConfig) -> GatewayRunner:
    """Bare GatewayRunner (no __init__) carrying just the given config."""
    gw = object.__new__(GatewayRunner)
    gw.config = config
    gw.adapters = {}
    gw._model = "openai/gpt-4.1-mini"
    gw._base_url = None
    return gw
@pytest.mark.asyncio
async def test_preprocess_prefixes_sender_for_shared_non_thread_group_session():
    """Shared group sessions prefix the sender name onto inbound text."""
    cfg = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake")},
        group_sessions_per_user=False,
    )
    runner = _make_runner(cfg)
    sender = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-1002285219667",
        chat_name="Test Group",
        chat_type="group",
        user_name="Alice",
    )
    prepared = await runner._prepare_inbound_message_text(
        event=MessageEvent(text="hello", source=sender),
        source=sender,
        history=[],
    )
    assert prepared == "[Alice] hello"
@pytest.mark.asyncio
async def test_preprocess_keeps_plain_text_for_default_group_sessions():
    """Default (per-user) group sessions leave inbound text unprefixed."""
    cfg = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake")},
    )
    runner = _make_runner(cfg)
    sender = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-1002285219667",
        chat_name="Test Group",
        chat_type="group",
        user_name="Alice",
    )
    prepared = await runner._prepare_inbound_message_text(
        event=MessageEvent(text="hello", source=sender),
        source=sender,
        history=[],
    )
    assert prepared == "hello"

View file

@ -91,6 +91,29 @@ class TestSignalAdapterInit:
assert adapter._account_normalized == "+15551234567"
class TestSignalConnectCleanup:
"""Regression coverage for failed connect() cleanup."""
@pytest.mark.asyncio
async def test_releases_lock_and_closes_client_on_healthcheck_failure(self, monkeypatch):
    """A failing healthcheck must close the HTTP client, release the
    signal-phone scoped lock, and leave the adapter fully reset."""
    adapter = _make_signal_adapter(monkeypatch)
    failing_client = AsyncMock()
    failing_client.get = AsyncMock(return_value=MagicMock(status_code=503))
    failing_client.aclose = AsyncMock()
    with patch("gateway.platforms.signal.httpx.AsyncClient", return_value=failing_client), \
            patch("gateway.status.acquire_scoped_lock", return_value=(True, None)), \
            patch("gateway.status.release_scoped_lock") as mock_release:
        connected = await adapter.connect()
    assert connected is False
    failing_client.aclose.assert_awaited_once()
    mock_release.assert_called_once_with("signal-phone", "+15551234567")
    assert adapter.client is None
    assert adapter._platform_lock_identity is None
class TestSignalHelpers:
def test_redact_phone_long(self):
from gateway.platforms.helpers import redact_phone
@ -283,7 +306,13 @@ class TestSignalSessionSource:
class TestSignalPhoneRedaction:
@pytest.fixture(autouse=True)
def _ensure_redaction_enabled(self, monkeypatch):
    """Force secret redaction on for every test in this class."""
    # agent.redact snapshots _REDACT_ENABLED at import time from the
    # HERMES_REDACT_SECRETS env var. monkeypatch.delenv is too late —
    # the module was already imported during test collection with
    # whatever value was in the env then. Force the flag directly.
    # See skill: xdist-cross-test-pollution Pattern 5.
    monkeypatch.delenv("HERMES_REDACT_SECRETS", raising=False)
    monkeypatch.setattr("agent.redact._REDACT_ENABLED", True)
def test_us_number(self):
from agent.redact import redact_sensitive_text
@ -438,6 +467,97 @@ class TestSignalSendImageFile:
assert "failed" in result.error.lower()
class TestSignalRecipientResolution:
@pytest.mark.asyncio
async def test_send_prefers_cached_uuid_for_direct_messages(self, monkeypatch):
    """A previously-learned UUID is used as the recipient, not the phone number."""
    adapter = _make_signal_adapter(monkeypatch)
    adapter._stop_typing_indicator = AsyncMock()
    uuid = "68680952-6d86-45bc-85e0-1a4d186d53ee"
    adapter._remember_recipient_identifiers("+15551230000", uuid)
    rpc_calls = []

    async def mock_rpc(method, params, rpc_id=None, **kwargs):
        rpc_calls.append({"method": method, "params": dict(params)})
        return {"timestamp": 1234567890}

    adapter._rpc = mock_rpc
    outcome = await adapter.send(chat_id="+15551230000", content="hello")
    assert outcome.success is True
    first = rpc_calls[0]
    assert first["method"] == "send"
    assert first["params"]["recipient"] == [uuid]
@pytest.mark.asyncio
async def test_send_looks_up_uuid_via_list_contacts(self, monkeypatch):
adapter = _make_signal_adapter(monkeypatch)
adapter._stop_typing_indicator = AsyncMock()
captured = []
async def mock_rpc(method, params, rpc_id=None, **kwargs):
captured.append({"method": method, "params": dict(params)})
if method == "listContacts":
return [{
"recipient": "351935789098",
"number": "+15551230000",
"uuid": "68680952-6d86-45bc-85e0-1a4d186d53ee",
"isRegistered": True,
}]
if method == "send":
return {"timestamp": 1234567890}
return None
adapter._rpc = mock_rpc
result = await adapter.send(chat_id="+15551230000", content="hello")
assert result.success is True
assert captured[0]["method"] == "listContacts"
assert captured[1]["method"] == "send"
assert captured[1]["params"]["recipient"] == ["68680952-6d86-45bc-85e0-1a4d186d53ee"]
@pytest.mark.asyncio
async def test_send_falls_back_to_phone_when_no_uuid_found(self, monkeypatch):
adapter = _make_signal_adapter(monkeypatch)
adapter._stop_typing_indicator = AsyncMock()
captured = []
async def mock_rpc(method, params, rpc_id=None, **kwargs):
captured.append({"method": method, "params": dict(params)})
if method == "listContacts":
return []
if method == "send":
return {"timestamp": 1234567890}
return None
adapter._rpc = mock_rpc
result = await adapter.send(chat_id="+15551230000", content="hello")
assert result.success is True
assert captured[1]["params"]["recipient"] == ["+15551230000"]
@pytest.mark.asyncio
async def test_send_typing_uses_cached_uuid(self, monkeypatch):
adapter = _make_signal_adapter(monkeypatch)
adapter._remember_recipient_identifiers("+15551230000", "68680952-6d86-45bc-85e0-1a4d186d53ee")
captured = []
async def mock_rpc(method, params, rpc_id=None, **kwargs):
captured.append({"method": method, "params": dict(params), "rpc_id": rpc_id})
return {}
adapter._rpc = mock_rpc
await adapter.send_typing("+15551230000")
assert captured[0]["method"] == "sendTyping"
assert captured[0]["params"]["recipient"] == ["68680952-6d86-45bc-85e0-1a4d186d53ee"]
# ---------------------------------------------------------------------------
# send_voice method (#5105)
# ---------------------------------------------------------------------------

View file

@ -150,6 +150,31 @@ class TestAppMentionHandler:
assert "/hermes" in registered_commands
class TestSlackConnectCleanup:
    """Regression coverage for failed connect() cleanup."""

    @pytest.mark.asyncio
    async def test_releases_platform_lock_when_auth_fails(self):
        """auth_test raising must release the scoped app-token lock and
        leave the adapter without a recorded lock identity."""
        config = PlatformConfig(enabled=True, token="xoxb-fake")
        adapter = SlackAdapter(config)
        mock_app = MagicMock()
        mock_web_client = AsyncMock()
        # Simulate Slack auth blowing up during connect().
        mock_web_client.auth_test = AsyncMock(side_effect=RuntimeError("boom"))
        with patch.object(_slack_mod, "AsyncApp", return_value=mock_app), \
             patch.object(_slack_mod, "AsyncWebClient", return_value=mock_web_client), \
             patch.object(_slack_mod, "AsyncSocketModeHandler", return_value=MagicMock()), \
             patch.dict(os.environ, {"SLACK_APP_TOKEN": "xapp-fake"}), \
             patch("gateway.status.acquire_scoped_lock", return_value=(True, None)), \
             patch("gateway.status.release_scoped_lock") as mock_release:
            result = await adapter.connect()
        assert result is False
        # The lock acquired for the app token must be released on failure.
        mock_release.assert_called_once_with("slack-app-token", "xapp-fake")
        assert adapter._platform_lock_identity is None
# ---------------------------------------------------------------------------
# TestSendDocument
# ---------------------------------------------------------------------------
@ -1006,7 +1031,7 @@ class TestReactions:
@pytest.mark.asyncio
async def test_reactions_in_message_flow(self, adapter):
"""Reactions should be added on receipt and swapped on completion."""
"""Reactions should be bracketed around actual processing via hooks."""
adapter._app.client.reactions_add = AsyncMock()
adapter._app.client.reactions_remove = AsyncMock()
adapter._app.client.users_info = AsyncMock(return_value={
@ -1022,15 +1047,147 @@ class TestReactions:
}
await adapter._handle_slack_message(event)
# Should have added 👀, then removed 👀, then added ✅
# _handle_slack_message should register the message for reactions
assert "1234567890.000001" in adapter._reacting_message_ids
# Simulate the base class calling on_processing_start
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
from gateway.config import Platform
source = SessionSource(
platform=Platform.SLACK,
chat_id="C123",
chat_type="dm",
user_id="U_USER",
)
msg_event = MessageEvent(
text="hello",
message_type=MessageType.TEXT,
source=source,
message_id="1234567890.000001",
)
await adapter.on_processing_start(msg_event)
add_calls = adapter._app.client.reactions_add.call_args_list
assert len(add_calls) == 1
assert add_calls[0].kwargs["name"] == "eyes"
# Simulate the base class calling on_processing_complete
from gateway.platforms.base import ProcessingOutcome
await adapter.on_processing_complete(msg_event, ProcessingOutcome.SUCCESS)
add_calls = adapter._app.client.reactions_add.call_args_list
remove_calls = adapter._app.client.reactions_remove.call_args_list
assert len(add_calls) == 2
assert add_calls[0].kwargs["name"] == "eyes"
assert add_calls[1].kwargs["name"] == "white_check_mark"
assert len(remove_calls) == 1
assert remove_calls[0].kwargs["name"] == "eyes"
# Message ID should be cleaned up
assert "1234567890.000001" not in adapter._reacting_message_ids
@pytest.mark.asyncio
async def test_reactions_failure_outcome(self, adapter):
    """Failed processing should add :x: instead of :white_check_mark:."""
    adapter._app.client.reactions_add = AsyncMock()
    adapter._app.client.reactions_remove = AsyncMock()
    from gateway.platforms.base import MessageEvent, MessageType, SessionSource, ProcessingOutcome
    from gateway.config import Platform
    source = SessionSource(
        platform=Platform.SLACK,
        chat_id="C123",
        chat_type="dm",
        user_id="U_USER",
    )
    # Pretend the message was registered for reactions on receipt.
    adapter._reacting_message_ids.add("1234567890.000002")
    msg_event = MessageEvent(
        text="hello",
        message_type=MessageType.TEXT,
        source=source,
        message_id="1234567890.000002",
    )
    await adapter.on_processing_complete(msg_event, ProcessingOutcome.FAILURE)
    add_calls = adapter._app.client.reactions_add.call_args_list
    remove_calls = adapter._app.client.reactions_remove.call_args_list
    # FAILURE swaps the "eyes" placeholder for "x" (not "white_check_mark").
    assert len(add_calls) == 1
    assert add_calls[0].kwargs["name"] == "x"
    assert len(remove_calls) == 1
    assert remove_calls[0].kwargs["name"] == "eyes"
@pytest.mark.asyncio
async def test_reactions_skipped_for_non_dm_non_mention(self, adapter):
    """Non-DM, non-mention messages should not get reactions."""
    adapter._app.client.reactions_add = AsyncMock()
    adapter._app.client.reactions_remove = AsyncMock()
    adapter._app.client.users_info = AsyncMock(return_value={
        "user": {"profile": {"display_name": "Tyler"}}
    })
    # A plain channel message (channel_type "channel") with no bot mention.
    event = {
        "text": "hello",
        "user": "U_USER",
        "channel": "C123",
        "channel_type": "channel",
        "ts": "1234567890.000003",
    }
    await adapter._handle_slack_message(event)
    # Should NOT register for reactions when not mentioned in a channel
    assert "1234567890.000003" not in adapter._reacting_message_ids
    adapter._app.client.reactions_add.assert_not_called()
    adapter._app.client.reactions_remove.assert_not_called()
@pytest.mark.asyncio
async def test_reactions_disabled_via_env(self, adapter, monkeypatch):
    """SLACK_REACTIONS=false should suppress all reaction lifecycle."""
    monkeypatch.setenv("SLACK_REACTIONS", "false")
    adapter._app.client.reactions_add = AsyncMock()
    adapter._app.client.reactions_remove = AsyncMock()
    adapter._app.client.users_info = AsyncMock(return_value={
        "user": {"profile": {"display_name": "Tyler"}}
    })
    # A DM ("im") would normally get reactions — the toggle must win.
    event = {
        "text": "hello",
        "user": "U_USER",
        "channel": "C123",
        "channel_type": "im",
        "ts": "1234567890.000004",
    }
    await adapter._handle_slack_message(event)
    # Should NOT register for reactions when toggle is off
    assert "1234567890.000004" not in adapter._reacting_message_ids
    # Hooks should also be no-ops when disabled
    from gateway.platforms.base import MessageEvent, MessageType, SessionSource, ProcessingOutcome
    from gateway.config import Platform
    source = SessionSource(
        platform=Platform.SLACK,
        chat_id="C123",
        chat_type="dm",
        user_id="U_USER",
    )
    msg_event = MessageEvent(
        text="hello",
        message_type=MessageType.TEXT,
        source=source,
        message_id="1234567890.000004",
    )
    # Force-add to verify hooks respect the toggle independently
    adapter._reacting_message_ids.add("1234567890.000004")
    await adapter.on_processing_start(msg_event)
    await adapter.on_processing_complete(msg_event, ProcessingOutcome.SUCCESS)
    adapter._app.client.reactions_add.assert_not_called()
    adapter._app.client.reactions_remove.assert_not_called()
@pytest.mark.asyncio
async def test_reactions_enabled_by_default(self, adapter):
    """SLACK_REACTIONS defaults to true (matches existing behavior)."""
    # No env var is set here — the fixture adapter sees the default config.
    assert adapter._reactions_enabled() is True
# ---------------------------------------------------------------------------
# TestThreadReplyHandling

View file

@ -19,6 +19,30 @@ class TestGatewayPidState:
assert isinstance(payload["argv"], list)
assert payload["argv"]
def test_write_pid_file_is_atomic_against_concurrent_writers(self, tmp_path, monkeypatch):
    """Regression: two concurrent --replace invocations must not both win.

    Without O_CREAT|O_EXCL, two processes racing through start_gateway()'s
    termination-wait would both write to gateway.pid, silently overwriting
    each other and leaving multiple gateway instances alive (#11718).
    """
    import pytest

    # Redirect the gateway home so the pid file lands in tmp_path.
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # First write wins.
    status.write_pid_file()
    assert (tmp_path / "gateway.pid").exists()
    # Second write (simulating a racing --replace that missed the earlier
    # guards) must raise FileExistsError rather than clobber the record.
    with pytest.raises(FileExistsError):
        status.write_pid_file()
    # Original record is preserved.
    payload = json.loads((tmp_path / "gateway.pid").read_text())
    assert payload["pid"] == os.getpid()
def test_get_running_pid_rejects_live_non_gateway_pid(self, tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
pid_path = tmp_path / "gateway.pid"

View file

@ -50,6 +50,7 @@ def _make_runner(session_entry: SessionEntry):
runner.session_store.rewrite_transcript = MagicMock()
runner.session_store.update_session = MagicMock()
runner._running_agents = {}
runner._session_run_generation = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._session_db = MagicMock()
@ -223,6 +224,121 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
)
@pytest.mark.asyncio
async def test_handle_message_discards_stale_result_after_session_invalidation(monkeypatch):
    """A run whose session generation is invalidated mid-flight must be
    discarded: no transcript append, no session update, and the adapter's
    post-delivery callback for the session is cleared."""
    import gateway.run as gateway_run

    session_entry = SessionEntry(
        session_key=build_session_key(_make_source()),
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )
    runner = _make_runner(session_entry)
    runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
    session_key = session_entry.session_key
    # Seed a callback that the stale-result path is expected to clear.
    runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks = {session_key: object()}

    async def _stale_result(**kwargs):
        # Invalidate the generation while the agent is "running", then hand
        # back a result that is now stale.
        runner._invalidate_session_run_generation(kwargs["session_key"], reason="test_stale_result")
        return {
            "final_response": "late reply",
            "messages": [],
            "tools": [],
            "history_offset": 0,
            "last_prompt_tokens": 80,
            "input_tokens": 120,
            "output_tokens": 45,
            "model": "openai/test-model",
        }

    runner._run_agent = AsyncMock(side_effect=_stale_result)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
    monkeypatch.setattr(
        "agent.model_metadata.get_model_context_length",
        lambda *_args, **_kwargs: 100000,
    )
    result = await runner._handle_message(_make_event("hello"))
    assert result is None
    # The stale result must leave no trace in the session store.
    runner.session_store.append_to_transcript.assert_not_called()
    runner.session_store.update_session.assert_not_called()
    assert session_key not in runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks
@pytest.mark.asyncio
async def test_handle_message_stale_result_keeps_newer_generation_callback(monkeypatch):
    """A stale run unwinding must not pop a post-delivery callback that a
    newer generation has since registered for the same session key."""
    import gateway.run as gateway_run

    class _Adapter:
        """Minimal adapter stub with generation-aware callback popping."""

        def __init__(self):
            self._post_delivery_callbacks = {}

        async def send(self, *args, **kwargs):
            return None

        def pop_post_delivery_callback(self, session_key, *, generation=None):
            entry = self._post_delivery_callbacks.get(session_key)
            if entry is None:
                return None
            if isinstance(entry, tuple):
                entry_generation, callback = entry
                # Generation mismatch: leave the newer callback in place.
                if generation is not None and entry_generation != generation:
                    return None
                self._post_delivery_callbacks.pop(session_key, None)
                return callback
            if generation is not None:
                return None
            return self._post_delivery_callbacks.pop(session_key, None)

    session_entry = SessionEntry(
        session_key=build_session_key(_make_source()),
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )
    runner = _make_runner(session_entry)
    runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
    session_key = session_entry.session_key
    adapter = _Adapter()
    runner.adapters[Platform.TELEGRAM] = adapter

    async def _stale_result(**kwargs):
        # Simulate a newer run claiming the callback slot before the stale run unwinds.
        runner._session_run_generation[session_key] = 2
        adapter._post_delivery_callbacks[session_key] = (2, lambda: None)
        return {
            "final_response": "late reply",
            "messages": [],
            "tools": [],
            "history_offset": 0,
            "last_prompt_tokens": 80,
            "input_tokens": 120,
            "output_tokens": 45,
            "model": "openai/test-model",
        }

    runner._run_agent = AsyncMock(side_effect=_stale_result)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
    monkeypatch.setattr(
        "agent.model_metadata.get_model_context_length",
        lambda *_args, **_kwargs: 100000,
    )
    result = await runner._handle_message(_make_event("hello"))
    assert result is None
    # The generation-2 callback registered by the newer run must survive.
    assert session_key in adapter._post_delivery_callbacks
    assert adapter._post_delivery_callbacks[session_key][0] == 2
@pytest.mark.asyncio
async def test_status_command_bypasses_active_session_guard():

View file

@ -133,6 +133,43 @@ class TestFinalizeCapabilityGate:
assert picky.edit_message.call_args[1]["finalize"] is True
class TestEditMessageFinalizeSignature:
    """Every concrete platform adapter must accept the ``finalize`` kwarg.

    stream_consumer._send_or_edit always passes ``finalize=`` to
    ``adapter.edit_message(...)`` (see gateway/stream_consumer.py). An
    adapter that overrides edit_message without accepting finalize raises
    TypeError the first time streaming hits a segment break or final edit.
    Guard the contract with an explicit signature check so it cannot
    silently regress — existing tests use MagicMock, which swallows any
    kwarg and cannot catch this.
    """

    @pytest.mark.parametrize(
        "module_path,class_name",
        [
            ("gateway.platforms.telegram", "TelegramAdapter"),
            ("gateway.platforms.discord", "DiscordAdapter"),
            ("gateway.platforms.slack", "SlackAdapter"),
            ("gateway.platforms.matrix", "MatrixAdapter"),
            ("gateway.platforms.mattermost", "MattermostAdapter"),
            ("gateway.platforms.feishu", "FeishuAdapter"),
            ("gateway.platforms.whatsapp", "WhatsAppAdapter"),
            ("gateway.platforms.dingtalk", "DingTalkAdapter"),
        ],
    )
    def test_edit_message_accepts_finalize(self, module_path, class_name):
        """Inspect the adapter's edit_message signature for ``finalize``."""
        import inspect

        # importorskip: adapters with missing optional deps are skipped,
        # not failed.
        module = pytest.importorskip(module_path)
        cls = getattr(module, class_name)
        params = inspect.signature(cls.edit_message).parameters
        assert "finalize" in params, (
            f"{class_name}.edit_message must accept 'finalize' kwarg; "
            f"stream_consumer._send_or_edit passes it unconditionally"
        )
class TestSendOrEditMediaStripping:
"""Verify _send_or_edit strips MEDIA: before sending to the platform."""
@ -502,11 +539,13 @@ class TestSegmentBreakOnToolBoundary:
@pytest.mark.asyncio
async def test_segment_break_clears_failed_edit_fallback_state(self):
"""A tool boundary after edit failure must not duplicate the next segment."""
"""A tool boundary after edit failure must flush the undelivered tail
without duplicating the prefix the user already saw (#8124)."""
adapter = MagicMock()
send_results = [
SimpleNamespace(success=True, message_id="msg_1"),
SimpleNamespace(success=True, message_id="msg_2"),
SimpleNamespace(success=True, message_id="msg_3"),
]
adapter.send = AsyncMock(side_effect=send_results)
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
@ -526,7 +565,60 @@ class TestSegmentBreakOnToolBoundary:
await task
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
assert sent_texts == ["Hello ▉", "Next segment"]
# The undelivered "world" tail must reach the user, and the next
# segment must not duplicate "Hello" that was already visible.
assert sent_texts == ["Hello ▉", "world", "Next segment"]
@pytest.mark.asyncio
async def test_segment_break_after_mid_stream_edit_failure_preserves_tail(self):
    """Regression for #8124: when an earlier edit succeeded but later edits
    fail (persistent flood control) and a tool boundary arrives before the
    fallback threshold is reached, the pre-boundary tail must still be
    delivered — not silently dropped by the segment reset."""
    adapter = MagicMock()
    # msg_1 for the initial partial, msg_2 for the flushed tail,
    # msg_3 for the post-boundary segment.
    send_results = [
        SimpleNamespace(success=True, message_id="msg_1"),
        SimpleNamespace(success=True, message_id="msg_2"),
        SimpleNamespace(success=True, message_id="msg_3"),
    ]
    adapter.send = AsyncMock(side_effect=send_results)
    # The first edit succeeds, everything after fails with flood control
    # — simulating Telegram's "edit once then get rate-limited" pattern.
    edit_results = [
        SimpleNamespace(success=True),  # "Hello world ▉" — succeeds
        SimpleNamespace(success=False, error="flood_control:6.0"),  # "Hello world more ▉" — flood triggered
        SimpleNamespace(success=False, error="flood_control:6.0"),  # finalize edit at segment break
        SimpleNamespace(success=False, error="flood_control:6.0"),  # cursor-strip attempt
    ]
    adapter.edit_message = AsyncMock(side_effect=edit_results + [edit_results[-1]] * 10)
    adapter.MAX_MESSAGE_LENGTH = 4096
    # NOTE(review): cursor="" may have lost its glyph in rendering — the
    # edit comments above reference a trailing "▉"; confirm cursor="▉"
    # against the original file.
    config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor="")
    consumer = GatewayStreamConsumer(adapter, "chat_123", config)
    consumer.on_delta("Hello")
    task = asyncio.create_task(consumer.run())
    await asyncio.sleep(0.08)
    consumer.on_delta(" world")
    await asyncio.sleep(0.08)
    consumer.on_delta(" more")
    await asyncio.sleep(0.08)
    consumer.on_delta(None)  # tool boundary
    consumer.on_delta("Here is the tool result.")
    consumer.finish()
    await task
    sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
    # "more" must have been delivered, not dropped.
    all_text = " ".join(sent_texts)
    assert "more" in all_text, (
        f"Pre-boundary tail 'more' was silently dropped: sends={sent_texts}"
    )
    # Post-boundary text must also reach the user.
    assert "Here is the tool result." in all_text
@pytest.mark.asyncio
async def test_no_message_id_enters_fallback_mode(self):
@ -1161,3 +1253,87 @@ class TestBufferOnlyMode:
# text, the consumer may send then edit, or just send once at got_done.
# The key assertion: this doesn't break.
assert adapter.send.call_count >= 1
# ── Cursor stripping on fallback (#7183) ────────────────────────────────────
class TestCursorStrippingOnFallback:
    """Regression: cursor must be stripped when fallback continuation is empty (#7183).

    When _send_fallback_final is called with nothing new to deliver (the visible
    partial already matches final_text), the last edit may still show the cursor
    character because fallback mode was entered after a failed edit. Before the
    fix this would leave the message permanently frozen with a visible cursor.
    """

    @pytest.mark.asyncio
    async def test_cursor_stripped_when_continuation_empty(self):
        """_send_fallback_final must attempt a final edit to strip the cursor."""
        adapter = MagicMock()
        adapter.MAX_MESSAGE_LENGTH = 4096
        adapter.edit_message = AsyncMock(
            return_value=SimpleNamespace(success=True, message_id="msg-1")
        )
        # NOTE(review): cursor="" looks like it lost its glyph in rendering —
        # this test strips a trailing "▉", so cursor="▉" is likely intended.
        # Confirm against the original file.
        consumer = GatewayStreamConsumer(
            adapter, "chat-1",
            config=StreamConsumerConfig(cursor=""),
        )
        consumer._message_id = "msg-1"
        consumer._last_sent_text = "Hello world ▉"
        consumer._fallback_final_send = False
        await consumer._send_fallback_final("Hello world")
        adapter.edit_message.assert_called_once()
        call_args = adapter.edit_message.call_args
        assert call_args.kwargs["content"] == "Hello world"
        assert consumer._already_sent is True
        # _last_sent_text should reflect the cleaned text after a successful strip
        assert consumer._last_sent_text == "Hello world"

    @pytest.mark.asyncio
    async def test_cursor_not_stripped_when_no_cursor_configured(self):
        """No edit attempted when cursor is not configured."""
        adapter = MagicMock()
        adapter.MAX_MESSAGE_LENGTH = 4096
        adapter.edit_message = AsyncMock()
        consumer = GatewayStreamConsumer(
            adapter, "chat-1",
            config=StreamConsumerConfig(cursor=""),
        )
        consumer._message_id = "msg-1"
        consumer._last_sent_text = "Hello world"
        consumer._fallback_final_send = False
        await consumer._send_fallback_final("Hello world")
        adapter.edit_message.assert_not_called()
        assert consumer._already_sent is True

    @pytest.mark.asyncio
    async def test_cursor_strip_edit_failure_handled(self):
        """If the cursor-stripping edit itself fails, it must not crash and
        must not corrupt _last_sent_text."""
        adapter = MagicMock()
        adapter.MAX_MESSAGE_LENGTH = 4096
        adapter.edit_message = AsyncMock(
            return_value=SimpleNamespace(success=False, error="flood_control")
        )
        # NOTE(review): same suspicion as above — cursor="▉" likely intended.
        consumer = GatewayStreamConsumer(
            adapter, "chat-1",
            config=StreamConsumerConfig(cursor=""),
        )
        consumer._message_id = "msg-1"
        consumer._last_sent_text = "Hello ▉"
        consumer._fallback_final_send = False
        await consumer._send_fallback_final("Hello")
        # Should still set already_sent despite the cursor-strip edit failure
        assert consumer._already_sent is True
        # _last_sent_text must NOT be updated when the edit failed
        assert consumer._last_sent_text == "Hello ▉"

View file

@ -23,6 +23,7 @@ from gateway.platforms.base import (
MessageType,
SendResult,
SUPPORTED_DOCUMENT_TYPES,
SUPPORTED_VIDEO_TYPES,
)
@ -117,6 +118,12 @@ def _make_update(msg):
return update
def _make_video(file_obj=None):
video = MagicMock()
video.get_file = AsyncMock(return_value=file_obj or _make_file_obj(b"video-bytes"))
return video
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@ -132,10 +139,13 @@ def adapter():
@pytest.fixture(autouse=True)
def _redirect_cache(tmp_path, monkeypatch):
    """Point document/video cache to tmp_path so tests don't touch ~/.hermes."""
    # NOTE: a stale duplicate of the old docstring ("Point document cache
    # ...") used to sit above this one as a no-op string statement — it has
    # been removed so the fixture carries a single, current docstring.
    monkeypatch.setattr(
        "gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
    )
    monkeypatch.setattr(
        "gateway.platforms.base.VIDEO_CACHE_DIR", tmp_path / "video_cache"
    )
# ---------------------------------------------------------------------------
@ -348,6 +358,37 @@ class TestDocumentDownloadBlock:
adapter.handle_message.assert_called_once()
class TestVideoDownloadBlock:
    """Inbound videos are downloaded to the local cache and typed as VIDEO."""

    @pytest.mark.asyncio
    async def test_native_video_is_cached(self, adapter):
        """A native `video` attachment is cached locally and typed VIDEO."""
        file_obj = _make_file_obj(b"fake-mp4")
        file_obj.file_path = "videos/clip.mp4"
        msg = _make_message()
        msg.video = _make_video(file_obj)
        update = _make_update(msg)
        await adapter._handle_media_message(update, MagicMock())
        event = adapter.handle_message.call_args[0][0]
        assert event.message_type == MessageType.VIDEO
        assert len(event.media_urls) == 1
        # media_urls must point at an existing local cache file, not a URL.
        assert os.path.exists(event.media_urls[0])
        assert event.media_types == [SUPPORTED_VIDEO_TYPES[".mp4"]]

    @pytest.mark.asyncio
    async def test_mp4_document_is_treated_as_video(self, adapter):
        """An .mp4 arriving as a generic document is reclassified as VIDEO."""
        file_obj = _make_file_obj(b"fake-mp4-doc")
        doc = _make_document(file_name="good.mp4", mime_type="video/mp4", file_size=1024, file_obj=file_obj)
        msg = _make_message(document=doc)
        update = _make_update(msg)
        await adapter._handle_media_message(update, MagicMock())
        event = adapter.handle_message.call_args[0][0]
        assert event.message_type == MessageType.VIDEO
        assert len(event.media_urls) == 1
        assert os.path.exists(event.media_urls[0])
        assert event.media_types == [SUPPORTED_VIDEO_TYPES[".mp4"]]
# ---------------------------------------------------------------------------
# TestMediaGroups — media group (album) buffering
# ---------------------------------------------------------------------------
@ -483,6 +524,32 @@ class TestSendDocument:
assert "not found" in result.error.lower()
connected_adapter._bot.send_document.assert_not_called()
@pytest.mark.asyncio
async def test_send_document_workspace_path_has_docker_hint(self, connected_adapter):
    """Container-local-looking paths get a more actionable Docker hint."""
    result = await connected_adapter.send_document(
        chat_id="12345",
        file_path="/workspace/report.txt",
    )
    assert result.success is False
    # The error should explain the file lives inside the Docker sandbox and
    # point the caller at a host-visible path instead.
    assert "docker sandbox" in result.error.lower()
    assert "host-visible path" in result.error.lower()
    # No send attempt should be made for an unreachable path.
    connected_adapter._bot.send_document.assert_not_called()
@pytest.mark.asyncio
async def test_send_document_outputs_path_has_docker_hint(self, connected_adapter):
    """Legacy /outputs paths also get the Docker hint."""
    result = await connected_adapter.send_document(
        chat_id="12345",
        file_path="/outputs/report.txt",
    )
    assert result.success is False
    assert "docker sandbox" in result.error.lower()
    assert "host-visible path" in result.error.lower()
    # No send attempt should be made for an unreachable path.
    connected_adapter._bot.send_document.assert_not_called()
@pytest.mark.asyncio
async def test_send_document_not_connected(self, adapter):
"""If bot is None, returns not connected error."""
@ -665,6 +732,17 @@ class TestSendVideo:
assert result.success is False
assert "not found" in result.error.lower()
@pytest.mark.asyncio
async def test_send_video_workspace_path_has_docker_hint(self, connected_adapter):
    """send_video mirrors send_document's Docker-sandbox path hint."""
    result = await connected_adapter.send_video(
        chat_id="12345",
        video_path="/workspace/video.mp4",
    )
    assert result.success is False
    assert "docker sandbox" in result.error.lower()
    assert "host-visible path" in result.error.lower()
@pytest.mark.asyncio
async def test_send_video_not_connected(self, adapter):
result = await adapter.send_video(

View file

@ -71,7 +71,17 @@ def test_group_messages_can_require_direct_trigger_via_config():
assert adapter._should_process_message(_group_message("hello everyone")) is False
assert adapter._should_process_message(_group_message("hi @hermes_bot", entities=[_mention_entity("hi @hermes_bot")])) is True
assert adapter._should_process_message(_group_message("replying", reply_to_bot=True)) is True
assert adapter._should_process_message(_group_message("/status"), is_command=True) is True
# Commands must also respect require_mention when it is enabled
assert adapter._should_process_message(_group_message("/status"), is_command=True) is False
# But commands with @mention still pass (Telegram emits a MENTION entity
# for /cmd@botname — the bot menu and python-telegram-bot's CommandHandler
# rely on this same mechanism)
assert adapter._should_process_message(
_group_message("/status@hermes_bot", entities=[_mention_entity("/status@hermes_bot")])
) is True
# And commands still pass unconditionally when require_mention is disabled
adapter_no_mention = _make_adapter(require_mention=False)
assert adapter_no_mention._should_process_message(_group_message("/status"), is_command=True) is True
def test_free_response_chats_bypass_mention_requirement():

View file

@ -0,0 +1,185 @@
"""Tests for Telegram bot mention detection (bug #12545).
The old implementation used a naive substring check
(`f"@{bot_username}" in text.lower()`), which incorrectly matched partial
substrings like 'foo@hermes_bot.example'.
Detection now relies entirely on the MessageEntity objects Telegram's server
emits for real mentions. A bare `@username` substring in message text without
a corresponding `MENTION` entity is NOT a mention — this correctly ignores
@handles that appear inside URLs, code blocks, email-like strings, or quoted
text, because Telegram's parser does not emit mention entities for any of
those contexts.
"""
from types import SimpleNamespace
from gateway.config import Platform, PlatformConfig
from gateway.platforms.telegram import TelegramAdapter
def _make_adapter():
    """Create a TelegramAdapter stub without invoking __init__ (no I/O)."""
    inst = object.__new__(TelegramAdapter)
    inst.platform = Platform.TELEGRAM
    inst._bot = SimpleNamespace(id=999, username="hermes_bot")
    inst.config = PlatformConfig(enabled=True, token="***", extra={})
    return inst
def _mention_entity(text, mention="@hermes_bot"):
"""Build a MENTION entity pointing at a literal `@username` in `text`."""
offset = text.index(mention)
return SimpleNamespace(type="mention", offset=offset, length=len(mention))
def _text_mention_entity(offset, length, user_id):
"""Build a TEXT_MENTION entity (used when the target user has no public @handle)."""
return SimpleNamespace(
type="text_mention",
offset=offset,
length=length,
user=SimpleNamespace(id=user_id),
)
def _message(text=None, caption=None, entities=None, caption_entities=None):
return SimpleNamespace(
text=text,
caption=caption,
entities=entities or [],
caption_entities=caption_entities or [],
message_thread_id=None,
chat=SimpleNamespace(id=-100, type="group"),
reply_to_message=None,
)
class TestRealMentionsAreDetected:
    """A real Telegram mention always comes with a MENTION entity — detect those."""

    def test_mention_at_start_of_message(self):
        adapter = _make_adapter()
        text = "@hermes_bot hello world"
        msg = _message(text=text, entities=[_mention_entity(text)])
        assert adapter._message_mentions_bot(msg) is True

    def test_mention_mid_sentence(self):
        adapter = _make_adapter()
        text = "hey @hermes_bot, can you help?"
        msg = _message(text=text, entities=[_mention_entity(text)])
        assert adapter._message_mentions_bot(msg) is True

    def test_mention_at_end_of_message(self):
        adapter = _make_adapter()
        text = "thanks for looking @hermes_bot"
        msg = _message(text=text, entities=[_mention_entity(text)])
        assert adapter._message_mentions_bot(msg) is True

    def test_mention_in_caption(self):
        # Captions carry their own entity list (caption_entities).
        adapter = _make_adapter()
        caption = "photo for @hermes_bot"
        msg = _message(caption=caption, caption_entities=[_mention_entity(caption)])
        assert adapter._message_mentions_bot(msg) is True

    def test_text_mention_entity_targets_bot(self):
        """TEXT_MENTION is Telegram's entity type for @FirstName -> user without a public handle."""
        adapter = _make_adapter()
        # user_id 999 matches the bot id set up by _make_adapter.
        msg = _message(text="hey you", entities=[_text_mention_entity(4, 3, user_id=999)])
        assert adapter._message_mentions_bot(msg) is True
class TestSubstringFalsePositivesAreRejected:
    """Bare `@bot_username` substrings without a MENTION entity must NOT match.

    These are all inputs where the OLD substring check returned True incorrectly.
    A word-boundary regex would still over-match some of these (code blocks,
    URLs). Entity-based detection handles them all correctly because Telegram's
    parser does not emit mention entities for non-mention contexts.
    """

    def test_email_like_substring(self):
        """bug #12545 exact repro: 'foo@hermes_bot.example'."""
        adapter = _make_adapter()
        msg = _message(text="email me at foo@hermes_bot.example")
        assert adapter._message_mentions_bot(msg) is False

    def test_hostname_substring(self):
        adapter = _make_adapter()
        msg = _message(text="contact user@hermes_bot.domain.com")
        assert adapter._message_mentions_bot(msg) is False

    def test_superstring_username(self):
        """`@hermes_botx` is a different username; Telegram would emit a mention
        entity for `@hermes_botx`, not `@hermes_bot`."""
        adapter = _make_adapter()
        msg = _message(text="@hermes_botx hello")
        assert adapter._message_mentions_bot(msg) is False

    def test_underscore_suffix_substring(self):
        adapter = _make_adapter()
        msg = _message(text="see @hermes_bot_admin for help")
        assert adapter._message_mentions_bot(msg) is False

    def test_substring_inside_url_without_entity(self):
        """@handle inside a URL produces a URL entity, not a MENTION entity."""
        adapter = _make_adapter()
        msg = _message(text="see https://example.com/@hermes_bot for details")
        assert adapter._message_mentions_bot(msg) is False

    def test_substring_inside_code_block_without_entity(self):
        """Telegram doesn't emit mention entities inside code/pre entities."""
        adapter = _make_adapter()
        msg = _message(text="use the string `@hermes_bot` in config")
        assert adapter._message_mentions_bot(msg) is False

    def test_plain_text_with_no_at_sign(self):
        adapter = _make_adapter()
        msg = _message(text="just a normal group message")
        assert adapter._message_mentions_bot(msg) is False

    def test_email_substring_in_caption(self):
        # Same repro as test_email_like_substring, but in a caption.
        adapter = _make_adapter()
        msg = _message(caption="foo@hermes_bot.example")
        assert adapter._message_mentions_bot(msg) is False
class TestEntityEdgeCases:
    """Mismatched or malformed entities must neither crash nor over-match."""

    def test_mention_entity_for_different_username(self):
        bot = _make_adapter()
        body = "@someone_else hi"
        event = _message(text=body, entities=[_mention_entity(body, mention="@someone_else")])
        assert bot._message_mentions_bot(event) is False

    def test_text_mention_entity_for_different_user(self):
        bot = _make_adapter()
        event = _message(text="hi there", entities=[_text_mention_entity(0, 2, user_id=12345)])
        assert bot._message_mentions_bot(event) is False

    def test_malformed_entity_with_negative_offset(self):
        bot = _make_adapter()
        # Negative offset can never describe a valid slice of the text.
        bogus = SimpleNamespace(type="mention", offset=-1, length=11)
        event = _message(text="@hermes_bot hi", entities=[bogus])
        assert bot._message_mentions_bot(event) is False

    def test_malformed_entity_with_zero_length(self):
        bot = _make_adapter()
        # Zero-length entity selects an empty slice, never the handle.
        bogus = SimpleNamespace(type="mention", offset=0, length=0)
        event = _message(text="@hermes_bot hi", entities=[bogus])
        assert bot._message_mentions_bot(event) is False
class TestCaseInsensitivity:
    """Telegram usernames compare case-insensitively; the slice-compare
    normalizes both sides before matching."""

    def test_uppercase_mention(self):
        bot = _make_adapter()
        body = "hi @HERMES_BOT"
        event = _message(text=body, entities=[_mention_entity(body, mention="@HERMES_BOT")])
        assert bot._message_mentions_bot(event) is True

    def test_mixed_case_mention(self):
        bot = _make_adapter()
        body = "hi @Hermes_Bot"
        event = _message(text=body, entities=[_mention_entity(body, mention="@Hermes_Bot")])
        assert bot._message_mentions_bot(event) is True

View file

@ -0,0 +1,100 @@
"""Tests for GHSA-3vpc-7q5r-276h — Telegram webhook secret required.
Previously, when TELEGRAM_WEBHOOK_URL was set but TELEGRAM_WEBHOOK_SECRET
was not, python-telegram-bot received secret_token=None and the webhook
endpoint accepted any HTTP POST.
The fix refuses to start the adapter in webhook mode without the secret.
"""
from __future__ import annotations
import re
import sys
from pathlib import Path
import pytest
_repo = str(Path(__file__).resolve().parents[2])
if _repo not in sys.path:
sys.path.insert(0, _repo)
class TestTelegramWebhookSecretRequired:
    """Direct source-level check of the webhook-secret guard.

    The guard is embedded in TelegramAdapter.connect() and hard to isolate
    via mocks (requires a full python-telegram-bot ApplicationBuilder
    chain). These tests exercise it via source inspection verifying the
    check exists, raises RuntimeError with the advisory link, and only
    fires in webhook mode. End-to-end validation is covered by CI +
    manual deployment tests.
    """

    def _get_source(self) -> str:
        # Read the adapter module's source text; every test below
        # pattern-matches against it.
        path = Path(_repo) / "gateway" / "platforms" / "telegram.py"
        return path.read_text(encoding="utf-8")

    def test_webhook_branch_checks_secret(self):
        """The webhook-mode branch of connect() must read
        TELEGRAM_WEBHOOK_SECRET and refuse when empty."""
        src = self._get_source()
        # The guard must appear after TELEGRAM_WEBHOOK_URL is set
        assert re.search(
            r'TELEGRAM_WEBHOOK_SECRET.*?\.strip\(\)\s*\n\s*if not webhook_secret:',
            src, re.DOTALL,
        ), (
            "TelegramAdapter.connect() must strip TELEGRAM_WEBHOOK_SECRET "
            "and raise when the secret is empty — see GHSA-3vpc-7q5r-276h"
        )

    def test_guard_raises_runtime_error(self):
        """The guard raises RuntimeError (not a silent log) so operators
        see the failure at startup."""
        src = self._get_source()
        # Between the "if not webhook_secret:" line and the next blank
        # line block, we should see a RuntimeError being raised
        guard_match = re.search(
            r'if not webhook_secret:\s*\n\s*raise\s+RuntimeError\(',
            src,
        )
        assert guard_match, (
            "Missing webhook secret must raise RuntimeError — silent "
            "fall-through was the original GHSA-3vpc-7q5r-276h bypass"
        )

    def test_guard_message_includes_advisory_link(self):
        """The RuntimeError message should reference the advisory so
        operators can read the full context."""
        src = self._get_source()
        assert "GHSA-3vpc-7q5r-276h" in src, (
            "Guard error message must cite the advisory for operator context"
        )

    def test_guard_message_explains_remediation(self):
        """The error should tell the operator how to fix it."""
        src = self._get_source()
        # Should mention how to generate a secret
        assert "openssl rand" in src or "TELEGRAM_WEBHOOK_SECRET=" in src, (
            "Guard error message should show operators how to set "
            "TELEGRAM_WEBHOOK_SECRET"
        )

    def test_polling_branch_has_no_secret_guard(self):
        """Polling mode (else-branch) must NOT require the webhook secret —
        polling authenticates via the bot token, not a webhook secret."""
        src = self._get_source()
        # The guard should appear inside the `if webhook_url:` branch,
        # not the `else:` polling branch. Rough check: the raise is
        # followed (within ~60 lines) by an `else:` that starts the
        # polling branch, and there's no secret-check in that polling
        # branch.
        webhook_block = re.search(
            r'if webhook_url:\s*\n(.*?)\n else:\s*\n(.*?)\n',
            src, re.DOTALL,
        )
        if webhook_block:
            webhook_body = webhook_block.group(1)
            polling_body = webhook_block.group(2)
            assert "TELEGRAM_WEBHOOK_SECRET" in webhook_body
            assert "TELEGRAM_WEBHOOK_SECRET" not in polling_body

View file

@ -148,6 +148,70 @@ class TestDiscordTextBatching:
await asyncio.sleep(0.25)
adapter.handle_message.assert_called_once()
@pytest.mark.asyncio
async def test_shield_protects_handle_message_from_cancel(self):
    """Regression guard: a follow-up chunk arriving while
    handle_message is mid-flight must NOT cancel the running
    dispatch. _enqueue_text_event fires prior_task.cancel() on
    every new chunk; without asyncio.shield around handle_message
    the cancel propagates into the agent's streaming request and
    aborts the response.
    """
    adapter = _make_discord_adapter()
    # Events coordinating this test with the slow handler below.
    handle_started = asyncio.Event()
    release_handle = asyncio.Event()
    first_handle_cancelled = asyncio.Event()
    first_handle_completed = asyncio.Event()
    call_count = [0]

    async def slow_handle(event):
        call_count[0] += 1
        # Only the first call (batch 1) is the one we're protecting.
        if call_count[0] == 1:
            handle_started.set()
            try:
                await release_handle.wait()
                first_handle_completed.set()
            except asyncio.CancelledError:
                first_handle_cancelled.set()
                raise
        # Second call (batch 2) returns immediately — not the subject
        # of this test.

    adapter.handle_message = slow_handle
    # Prime batch 1 and wait for it to land inside handle_message.
    adapter._enqueue_text_event(_make_event("batch 1", Platform.DISCORD))
    await asyncio.wait_for(handle_started.wait(), timeout=1.0)
    # A new chunk arrives — _enqueue_text_event fires
    # prior_task.cancel() on batch 1's flush task, which is
    # currently awaiting inside handle_message.
    adapter._enqueue_text_event(_make_event("batch 2 follow-up", Platform.DISCORD))
    # Let the cancel propagate.
    await asyncio.sleep(0.05)
    # CRITICAL ASSERTION: batch 1's handle_message must NOT have
    # been cancelled. Without asyncio.shield this assertion fails
    # because CancelledError propagates from the flush task's
    # `await self.handle_message(event)` into slow_handle.
    assert not first_handle_cancelled.is_set(), (
        "handle_message for batch 1 was cancelled by a follow-up "
        "chunk — asyncio.shield is missing or broken"
    )
    # Release batch 1's handle_message and let it complete.
    release_handle.set()
    await asyncio.wait_for(first_handle_completed.wait(), timeout=1.0)
    assert first_handle_completed.is_set()
    # Cleanup
    for task in list(adapter._pending_text_batch_tasks.values()):
        task.cancel()
    await asyncio.sleep(0.01)
# =====================================================================
# Matrix text batching

View file

@ -63,6 +63,12 @@ def _make_runner(platform: Platform, config: GatewayConfig):
runner.pairing_store = MagicMock()
runner.pairing_store.is_approved.return_value = False
runner.pairing_store._is_rate_limited.return_value = False
# Attributes required by _handle_message for the authorized-user path
runner._running_agents = {}
runner._running_agents_ts = {}
runner._update_prompts = {}
runner.hooks = SimpleNamespace(dispatch=AsyncMock(return_value=None))
runner._sessions = {}
return runner, adapter
@ -295,3 +301,172 @@ async def test_global_ignore_suppresses_pairing_reply(monkeypatch):
assert result is None
runner.pairing_store.generate_code.assert_not_called()
adapter.send.assert_not_awaited()
# ---------------------------------------------------------------------------
# Allowlist-configured platforms default to "ignore" for unauthorized users
# (#9337: Signal gateway sends pairing spam when allowlist is configured)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_signal_with_allowlist_ignores_unauthorized_dm(monkeypatch):
    """With SIGNAL_ALLOWED_USERS set, unauthorized DMs are dropped silently.

    Primary regression test for #9337: before the fix, Signal sent pairing
    codes to ANY sender even under a strict allowlist, spamming personal
    contacts with cryptic bot messages.
    """
    _clear_auth_env(monkeypatch)
    monkeypatch.setenv("SIGNAL_ALLOWED_USERS", "+15550000001")  # allowlist set
    cfg = GatewayConfig(
        platforms={Platform.SIGNAL: PlatformConfig(enabled=True)},
    )
    gw, signal_adapter = _make_runner(Platform.SIGNAL, cfg)
    stranger = _make_event(Platform.SIGNAL, "+15559999999", "+15559999999")  # not in allowlist
    outcome = await gw._handle_message(stranger)
    assert outcome is None
    gw.pairing_store.generate_code.assert_not_called()
    signal_adapter.send.assert_not_awaited()
@pytest.mark.asyncio
async def test_telegram_with_allowlist_ignores_unauthorized_dm(monkeypatch):
    """Telegram behaves the same: allowlist ⟹ ignore unauthorized DMs."""
    _clear_auth_env(monkeypatch)
    monkeypatch.setenv("TELEGRAM_ALLOWED_USERS", "111111111")
    cfg = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True)},
    )
    gw, tg_adapter = _make_runner(Platform.TELEGRAM, cfg)
    stranger = _make_event(Platform.TELEGRAM, "999999999", "999999999")
    outcome = await gw._handle_message(stranger)
    assert outcome is None
    gw.pairing_store.generate_code.assert_not_called()
    tg_adapter.send.assert_not_awaited()
@pytest.mark.asyncio
async def test_global_allowlist_ignores_unauthorized_dm(monkeypatch):
    """The global GATEWAY_ALLOWED_USERS variable also triggers 'ignore'."""
    _clear_auth_env(monkeypatch)
    monkeypatch.setenv("GATEWAY_ALLOWED_USERS", "111111111")
    cfg = GatewayConfig(
        platforms={Platform.SIGNAL: PlatformConfig(enabled=True)},
    )
    gw, signal_adapter = _make_runner(Platform.SIGNAL, cfg)
    stranger = _make_event(Platform.SIGNAL, "+15559999999", "+15559999999")
    outcome = await gw._handle_message(stranger)
    assert outcome is None
    gw.pairing_store.generate_code.assert_not_called()
    signal_adapter.send.assert_not_awaited()
@pytest.mark.asyncio
async def test_no_allowlist_still_pairs_by_default(monkeypatch):
    """With no allowlist anywhere, the open-gateway pairing flow is kept:
    strangers still receive a pairing code."""
    _clear_auth_env(monkeypatch)
    # No SIGNAL_ALLOWED_USERS, no GATEWAY_ALLOWED_USERS
    cfg = GatewayConfig(
        platforms={Platform.SIGNAL: PlatformConfig(enabled=True)},
    )
    gw, signal_adapter = _make_runner(Platform.SIGNAL, cfg)
    gw.pairing_store.generate_code.return_value = "PAIR1234"
    stranger = _make_event(Platform.SIGNAL, "+15559999999", "+15559999999")
    outcome = await gw._handle_message(stranger)
    assert outcome is None
    gw.pairing_store.generate_code.assert_called_once()
    signal_adapter.send.assert_awaited_once()
    # The pairing code must appear in the message sent back.
    assert "PAIR1234" in signal_adapter.send.await_args.args[1]
def test_explicit_pair_config_overrides_allowlist_default(monkeypatch):
    """Explicit unauthorized_dm_behavior='pair' beats the allowlist default.

    Operators can opt back in to pairing even with an allowlist by setting
    unauthorized_dm_behavior: pair in their platform config. The resolver
    is tested directly to avoid the full _handle_message pipeline, which
    needs extensive runner state.
    """
    _clear_auth_env(monkeypatch)
    monkeypatch.setenv("SIGNAL_ALLOWED_USERS", "+15550000001")
    signal_cfg = PlatformConfig(
        enabled=True,
        extra={"unauthorized_dm_behavior": "pair"},  # explicit override
    )
    cfg = GatewayConfig(platforms={Platform.SIGNAL: signal_cfg})
    gw, _unused_adapter = _make_runner(Platform.SIGNAL, cfg)
    # Per-platform explicit config wins over the allowlist-derived default.
    assert gw._get_unauthorized_dm_behavior(Platform.SIGNAL) == "pair"
def test_allowlist_authorized_user_returns_ignore_for_unauthorized(monkeypatch):
    """The resolver returns 'ignore' whenever an allowlist is configured.

    Tested directly; the full _handle_message path for authorized users is
    covered by this module's integration tests.
    """
    _clear_auth_env(monkeypatch)
    monkeypatch.setenv("SIGNAL_ALLOWED_USERS", "+15550000001")
    cfg = GatewayConfig(
        platforms={Platform.SIGNAL: PlatformConfig(enabled=True)},
    )
    gw, _unused_adapter = _make_runner(Platform.SIGNAL, cfg)
    assert gw._get_unauthorized_dm_behavior(Platform.SIGNAL) == "ignore"
def test_get_unauthorized_dm_behavior_no_allowlist_returns_pair(monkeypatch):
    """Absent any allowlist, the resolver keeps 'pair' as the default."""
    _clear_auth_env(monkeypatch)
    cfg = GatewayConfig(
        platforms={Platform.SIGNAL: PlatformConfig(enabled=True)},
    )
    gw, _unused_adapter = _make_runner(Platform.SIGNAL, cfg)
    assert gw._get_unauthorized_dm_behavior(Platform.SIGNAL) == "pair"
def test_qqbot_with_allowlist_ignores_unauthorized_dm(monkeypatch):
    """QQBOT is covered by the allowlist-aware default (QQ_ALLOWED_USERS).

    Regression guard: the initial #9337 fix omitted QQBOT from the env map
    inside _get_unauthorized_dm_behavior, even though _is_user_authorized
    mapped it to QQ_ALLOWED_USERS. Without QQBOT here, a QQ operator with a
    strict user allowlist would still get pairing codes sent to strangers.
    """
    _clear_auth_env(monkeypatch)
    monkeypatch.setenv("QQ_ALLOWED_USERS", "allowed-openid-1")
    cfg = GatewayConfig(
        platforms={Platform.QQBOT: PlatformConfig(enabled=True)},
    )
    gw, _unused_adapter = _make_runner(Platform.QQBOT, cfg)
    assert gw._get_unauthorized_dm_behavior(Platform.QQBOT) == "ignore"

View file

@ -175,3 +175,79 @@ class TestUsageCachedAgent:
result = await runner._handle_usage_command(event)
assert "Cost: included" in result
class TestUsageAccountSection:
    """Account-limits section appended to /usage output (PR #2486)."""

    @pytest.mark.asyncio
    async def test_usage_command_includes_account_section(self, monkeypatch):
        # A running agent advertises the Codex billing endpoint directly.
        agent = _make_mock_agent(provider="openai-codex")
        agent.base_url = "https://chatgpt.com/backend-api/codex"
        agent.api_key = "unused"
        runner = _make_runner(SK, cached_agent=agent)
        event = MagicMock()
        # Stub the network fetch and the renderer; the test only checks
        # that their output is stitched into the /usage reply.
        monkeypatch.setattr(
            "gateway.run.fetch_account_usage",
            lambda provider, base_url=None, api_key=None: object(),
        )
        monkeypatch.setattr(
            "gateway.run.render_account_usage_lines",
            lambda snapshot, markdown=False: [
                "📈 **Account limits**",
                "Provider: openai-codex (Pro)",
                "Session: 85% remaining (15% used)",
            ],
        )
        with patch("agent.rate_limit_tracker.format_rate_limit_compact", return_value="RPM: 50/60"), \
             patch("agent.usage_pricing.estimate_usage_cost") as mock_cost:
            mock_cost.return_value = MagicMock(amount_usd=None, status="included")
            result = await runner._handle_usage_command(event)
        assert "📊 **Session Token Usage**" in result
        assert "📈 **Account limits**" in result
        assert "Provider: openai-codex (Pro)" in result

    @pytest.mark.asyncio
    async def test_usage_command_uses_persisted_provider_when_agent_not_running(self, monkeypatch):
        runner = _make_runner(SK)
        # No cached agent: provider/base_url must come from the session DB row.
        runner._session_db = MagicMock()
        runner._session_db.get_session.return_value = {
            "billing_provider": "openai-codex",
            "billing_base_url": "https://chatgpt.com/backend-api/codex",
        }
        session_entry = MagicMock()
        session_entry.session_id = "sess-1"
        runner.session_store.get_or_create_session.return_value = session_entry
        runner.session_store.load_transcript.return_value = [
            {"role": "user", "content": "earlier"},
        ]
        calls = {}

        # Captures the args the runner forwards to fetch_account_usage
        # through asyncio.to_thread.
        async def _fake_to_thread(fn, *args, **kwargs):
            calls["args"] = args
            calls["kwargs"] = kwargs
            return fn(*args, **kwargs)

        monkeypatch.setattr("gateway.run.asyncio.to_thread", _fake_to_thread)
        monkeypatch.setattr(
            "gateway.run.fetch_account_usage",
            lambda provider, base_url=None, api_key=None: object(),
        )
        monkeypatch.setattr(
            "gateway.run.render_account_usage_lines",
            lambda snapshot, markdown=False: [
                "📈 **Account limits**",
                "Provider: openai-codex (Pro)",
            ],
        )
        event = MagicMock()
        result = await runner._handle_usage_command(event)
        assert calls["args"] == ("openai-codex",)
        assert calls["kwargs"]["base_url"] == "https://chatgpt.com/backend-api/codex"
        assert "📊 **Session Info**" in result
        assert "📈 **Account limits**" in result

View file

@ -99,22 +99,22 @@ class TestHandleVoiceCommand:
event = _make_event("/voice on")
result = await runner._handle_voice_command(event)
assert "enabled" in result.lower()
assert runner._voice_mode["123"] == "voice_only"
assert runner._voice_mode["telegram:123"] == "voice_only"
@pytest.mark.asyncio
async def test_voice_off(self, runner):
runner._voice_mode["123"] = "voice_only"
runner._voice_mode["telegram:123"] = "voice_only"
event = _make_event("/voice off")
result = await runner._handle_voice_command(event)
assert "disabled" in result.lower()
assert runner._voice_mode["123"] == "off"
assert runner._voice_mode["telegram:123"] == "off"
@pytest.mark.asyncio
async def test_voice_tts(self, runner):
event = _make_event("/voice tts")
result = await runner._handle_voice_command(event)
assert "tts" in result.lower()
assert runner._voice_mode["123"] == "all"
assert runner._voice_mode["telegram:123"] == "all"
@pytest.mark.asyncio
async def test_voice_status_off(self, runner):
@ -124,7 +124,7 @@ class TestHandleVoiceCommand:
@pytest.mark.asyncio
async def test_voice_status_on(self, runner):
runner._voice_mode["123"] = "voice_only"
runner._voice_mode["telegram:123"] = "voice_only"
event = _make_event("/voice status")
result = await runner._handle_voice_command(event)
assert "voice reply" in result.lower()
@ -134,15 +134,15 @@ class TestHandleVoiceCommand:
event = _make_event("/voice")
result = await runner._handle_voice_command(event)
assert "enabled" in result.lower()
assert runner._voice_mode["123"] == "voice_only"
assert runner._voice_mode["telegram:123"] == "voice_only"
@pytest.mark.asyncio
async def test_toggle_on_to_off(self, runner):
runner._voice_mode["123"] = "voice_only"
runner._voice_mode["telegram:123"] = "voice_only"
event = _make_event("/voice")
result = await runner._handle_voice_command(event)
assert "disabled" in result.lower()
assert runner._voice_mode["123"] == "off"
assert runner._voice_mode["telegram:123"] == "off"
@pytest.mark.asyncio
async def test_persistence_saved(self, runner):
@ -150,39 +150,47 @@ class TestHandleVoiceCommand:
await runner._handle_voice_command(event)
assert runner._VOICE_MODE_PATH.exists()
data = json.loads(runner._VOICE_MODE_PATH.read_text())
assert data["123"] == "voice_only"
assert data["telegram:123"] == "voice_only"
@pytest.mark.asyncio
async def test_persistence_loaded(self, runner):
runner._VOICE_MODE_PATH.write_text(json.dumps({"456": "all"}))
runner._VOICE_MODE_PATH.write_text(json.dumps({"telegram:456": "all"}))
loaded = runner._load_voice_modes()
assert loaded == {"456": "all"}
assert loaded == {"telegram:456": "all"}
@pytest.mark.asyncio
async def test_persistence_saved_for_off(self, runner):
event = _make_event("/voice off")
await runner._handle_voice_command(event)
data = json.loads(runner._VOICE_MODE_PATH.read_text())
assert data["123"] == "off"
assert data["telegram:123"] == "off"
def test_sync_voice_mode_state_to_adapter_restores_off_chats(self, runner):
runner._voice_mode = {"123": "off", "456": "all"}
adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
from gateway.config import Platform
runner._voice_mode = {"telegram:123": "off", "telegram:456": "all"}
adapter = SimpleNamespace(
_auto_tts_disabled_chats=set(),
platform=Platform.TELEGRAM,
)
runner._sync_voice_mode_state_to_adapter(adapter)
assert adapter._auto_tts_disabled_chats == {"123"}
def test_restart_restores_voice_off_state(self, runner, tmp_path):
runner._VOICE_MODE_PATH.write_text(json.dumps({"123": "off"}))
from gateway.config import Platform
runner._VOICE_MODE_PATH.write_text(json.dumps({"telegram:123": "off"}))
restored_runner = _make_runner(tmp_path)
restored_runner._voice_mode = restored_runner._load_voice_modes()
adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
adapter = SimpleNamespace(
_auto_tts_disabled_chats=set(),
platform=Platform.TELEGRAM,
)
restored_runner._sync_voice_mode_state_to_adapter(adapter)
assert restored_runner._voice_mode["123"] == "off"
assert restored_runner._voice_mode["telegram:123"] == "off"
assert adapter._auto_tts_disabled_chats == {"123"}
@pytest.mark.asyncio
@ -191,8 +199,21 @@ class TestHandleVoiceCommand:
e2 = _make_event("/voice tts", chat_id="bbb")
await runner._handle_voice_command(e1)
await runner._handle_voice_command(e2)
assert runner._voice_mode["aaa"] == "voice_only"
assert runner._voice_mode["bbb"] == "all"
assert runner._voice_mode["telegram:aaa"] == "voice_only"
assert runner._voice_mode["telegram:bbb"] == "all"
@pytest.mark.asyncio
async def test_platform_isolation(self, runner):
"""Same chat_id on different platforms must not collide (#12542)."""
telegram_event = _make_event("/voice on", chat_id="999")
slack_event = _make_event("/voice off", chat_id="999")
slack_event.source.platform.value = "slack"
await runner._handle_voice_command(telegram_event)
await runner._handle_voice_command(slack_event)
assert runner._voice_mode["telegram:999"] == "voice_only"
assert runner._voice_mode["slack:999"] == "off"
# =====================================================================
@ -223,9 +244,9 @@ class TestAutoVoiceReply:
"""Call real _should_send_voice_reply on a GatewayRunner instance."""
chat_id = "123"
if voice_mode != "off":
runner._voice_mode[chat_id] = voice_mode
runner._voice_mode["telegram:" + chat_id] = voice_mode
else:
runner._voice_mode.pop(chat_id, None)
runner._voice_mode.pop("telegram:" + chat_id, None)
event = _make_event(message_type=message_type)
@ -416,6 +437,7 @@ class TestDiscordPlayTtsSkip:
adapter.platform = Platform.DISCORD
adapter.config = config
adapter._voice_clients = {}
adapter._voice_locks = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._voice_timeout_tasks = {}
@ -712,7 +734,7 @@ class TestVoiceChannelCommands:
result = await runner._handle_voice_channel_join(event)
assert "joined" in result.lower()
assert "General" in result
assert runner._voice_mode["123"] == "all"
assert runner._voice_mode["discord:123"] == "all"
assert mock_adapter._voice_sources[111]["chat_id"] == "123"
assert mock_adapter._voice_sources[111]["chat_type"] == "group"
@ -790,10 +812,10 @@ class TestVoiceChannelCommands:
mock_adapter.leave_voice_channel = AsyncMock()
event = self._make_discord_event("/voice leave")
runner.adapters[event.source.platform] = mock_adapter
runner._voice_mode["123"] = "all"
runner._voice_mode["discord:123"] = "all"
result = await runner._handle_voice_channel_leave(event)
assert "left" in result.lower()
assert runner._voice_mode["123"] == "off"
assert runner._voice_mode["discord:123"] == "off"
mock_adapter.leave_voice_channel.assert_called_once_with(111)
# -- _handle_voice_channel_input --
@ -931,6 +953,7 @@ class TestDiscordVoiceChannelMethods:
adapter.config = config
adapter._client = MagicMock()
adapter._voice_clients = {}
adapter._voice_locks = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._voice_timeout_tasks = {}
@ -1296,11 +1319,11 @@ class TestLeaveExceptionHandling:
event = _make_event("/voice leave")
event.raw_message = SimpleNamespace(guild_id=111, guild=None)
runner.adapters[event.source.platform] = mock_adapter
runner._voice_mode["123"] = "all"
runner._voice_mode["telegram:123"] = "all"
result = await runner._handle_voice_channel_leave(event)
assert "left" in result.lower()
assert runner._voice_mode["123"] == "off"
assert runner._voice_mode["telegram:123"] == "off"
assert mock_adapter._voice_input_callback is None
@pytest.mark.asyncio
@ -1314,7 +1337,7 @@ class TestLeaveExceptionHandling:
event = _make_event("/voice leave")
event.raw_message = SimpleNamespace(guild_id=111, guild=None)
runner.adapters[event.source.platform] = mock_adapter
runner._voice_mode["123"] = "all"
runner._voice_mode["telegram:123"] = "all"
await runner._handle_voice_channel_leave(event)
assert mock_adapter._voice_input_callback is None
@ -1712,6 +1735,7 @@ class TestVoiceTimeoutCleansRunnerState:
adapter.platform = Platform.DISCORD
adapter.config = config
adapter._voice_clients = {}
adapter._voice_locks = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._voice_timeout_tasks = {}
@ -1760,11 +1784,11 @@ class TestVoiceTimeoutCleansRunnerState:
async def test_runner_cleanup_method_removes_voice_mode(self, tmp_path):
"""_handle_voice_timeout_cleanup removes voice_mode for chat."""
runner = _make_runner(tmp_path)
runner._voice_mode["999"] = "all"
runner._voice_mode["discord:999"] = "all"
runner._handle_voice_timeout_cleanup("999")
assert runner._voice_mode["999"] == "off", \
assert runner._voice_mode["discord:999"] == "off", \
"voice_mode must persist explicit off state after timeout cleanup"
@pytest.mark.asyncio
@ -1802,6 +1826,7 @@ class TestPlaybackTimeout:
adapter.platform = Platform.DISCORD
adapter.config = config
adapter._voice_clients = {}
adapter._voice_locks = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._voice_timeout_tasks = {}
@ -1983,6 +2008,7 @@ class TestVoiceChannelAwareness:
config.token = "fake-token"
adapter = object.__new__(DiscordAdapter)
adapter._voice_clients = {}
adapter._voice_locks = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._voice_receivers = {}
@ -2453,6 +2479,7 @@ class TestVoiceTTSPlayback:
adapter.platform = Platform.DISCORD
adapter.config = config
adapter._voice_clients = {}
adapter._voice_locks = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._voice_receivers = {}
@ -2518,7 +2545,7 @@ class TestVoiceTTSPlayback:
agent_msgs=None, already_sent=False):
from gateway.platforms.base import MessageType, MessageEvent, SessionSource
from gateway.config import Platform
runner._voice_mode["ch1"] = voice_mode
runner._voice_mode["discord:ch1"] = voice_mode
source = SessionSource(
platform=Platform.DISCORD, chat_id="ch1",
user_id="1", user_name="test", chat_type="channel",
@ -2633,6 +2660,7 @@ class TestUDPKeepalive:
adapter.platform = Platform.DISCORD
adapter.config = config
adapter._voice_clients = {}
adapter._voice_locks = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._voice_receivers = {}

View file

@ -0,0 +1,218 @@
"""Tests for voice mode platform isolation (bug #12542).
Voice mode state stored as {chat_id: mode} without a platform namespace
caused collisions: Telegram chat '123' and Slack chat '123' shared the
same key. The fix prefixes keys with platform value: 'telegram:123' vs
'slack:123'.
"""
import json
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from gateway.config import Platform
from gateway.run import GatewayRunner
class TestVoiceKeyHelper:
    """Behavior of the _voice_key helper method."""

    def test_voice_key_format(self):
        """_voice_key renders keys as '<platform>:<chat_id>'."""
        gw = _make_runner()
        cases = {
            (Platform.TELEGRAM, "123"): "telegram:123",
            (Platform.SLACK, "456"): "slack:456",
            (Platform.DISCORD, "789"): "discord:789",
        }
        for (platform, chat_id), expected in cases.items():
            assert gw._voice_key(platform, chat_id) == expected

    def test_voice_key_different_platforms_same_chat_id(self):
        """Identical chat ids on distinct platforms map to distinct keys."""
        gw = _make_runner()
        keys = [
            gw._voice_key(p, "123")
            for p in (Platform.TELEGRAM, Platform.SLACK, Platform.DISCORD)
        ]
        assert keys == ["telegram:123", "slack:123", "discord:123"]
        assert len(set(keys)) == 3
class TestVoiceModePlatformIsolation:
    """Voice mode state must be isolated by platform."""

    def test_telegram_and_slack_voice_mode_independent(self):
        """Changing Telegram chat '123' never touches Slack chat '123'."""
        gw = _make_runner()
        tg_key = gw._voice_key(Platform.TELEGRAM, "123")
        slack_key = gw._voice_key(Platform.SLACK, "123")
        # Same chat id, different platforms, different modes.
        gw._voice_mode[tg_key] = "all"
        gw._voice_mode[slack_key] = "voice_only"
        assert gw._voice_mode.get(tg_key) == "all"
        assert gw._voice_mode.get(slack_key) == "voice_only"
        # Turning Telegram off leaves the Slack entry untouched.
        gw._voice_mode[tg_key] = "off"
        assert gw._voice_mode.get(tg_key) == "off"
        assert gw._voice_mode.get(slack_key) == "voice_only"
class TestLegacyKeyMigration:
    """Test migration of legacy unprefixed keys in _load_voice_modes."""

    def test_load_voice_modes_skips_legacy_keys(self):
        """_load_voice_modes skips keys without ':' prefix and logs a warning."""
        runner = _make_runner()
        # Simulate legacy persisted data with unprefixed keys
        legacy_data = {
            "123": "all",
            "456": "voice_only",
            # Also includes a properly prefixed key (from after the fix)
            "telegram:789": "off",
        }
        with tempfile.TemporaryDirectory() as tmpdir:
            voice_path = Path(tmpdir) / "gateway_voice_mode.json"
            voice_path.write_text(json.dumps(legacy_data))
            # Point the runner at the temp file and capture its logger.
            with patch.object(runner, "_VOICE_MODE_PATH", voice_path):
                with patch("gateway.run.logger") as mock_logger:
                    result = runner._load_voice_modes()
        # Legacy keys without ':' should be skipped
        assert "123" not in result
        assert "456" not in result
        # Prefixed key should be preserved
        assert result.get("telegram:789") == "off"
        # Warning should be logged for each legacy key
        assert mock_logger.warning.called
        warning_calls = [str(call) for call in mock_logger.warning.call_args_list]
        assert any("Skipping legacy unprefixed voice mode key" in str(c) for c in warning_calls)

    def test_load_voice_modes_preserves_prefixed_keys(self):
        """_load_voice_modes correctly loads platform-prefixed keys."""
        runner = _make_runner()
        persisted_data = {
            "telegram:123": "all",
            "slack:456": "voice_only",
            "discord:789": "off",
        }
        with tempfile.TemporaryDirectory() as tmpdir:
            voice_path = Path(tmpdir) / "gateway_voice_mode.json"
            voice_path.write_text(json.dumps(persisted_data))
            with patch.object(runner, "_VOICE_MODE_PATH", voice_path):
                result = runner._load_voice_modes()
        assert result.get("telegram:123") == "all"
        assert result.get("slack:456") == "voice_only"
        assert result.get("discord:789") == "off"

    def test_load_voice_modes_invalid_modes_filtered(self):
        """_load_voice_modes filters out invalid mode values."""
        runner = _make_runner()
        data = {
            "telegram:123": "all",
            "telegram:456": "invalid_mode",
            "telegram:789": "voice_only",
        }
        with tempfile.TemporaryDirectory() as tmpdir:
            voice_path = Path(tmpdir) / "gateway_voice_mode.json"
            voice_path.write_text(json.dumps(data))
            with patch.object(runner, "_VOICE_MODE_PATH", voice_path):
                result = runner._load_voice_modes()
        assert result.get("telegram:123") == "all"
        assert "telegram:456" not in result
        assert result.get("telegram:789") == "voice_only"
class TestSyncVoiceModeStateToAdapter:
    """Test _sync_voice_mode_state_to_adapter filters by platform."""

    def test_sync_only_includes_platform_chats(self):
        """Only "off"-mode chats belonging to the adapter's own platform sync."""
        runner = _make_runner()
        # Mixed state: one matching platform+mode, and three that must be ignored.
        runner._voice_mode = {
            "telegram:123": "off",  # matches platform AND mode "off" → synced
            "telegram:456": "all",  # right platform, but mode is not "off"
            "slack:123": "off",  # wrong platform
            "discord:789": "off",  # wrong platform
        }
        tg_adapter = MagicMock()
        tg_adapter.platform = Platform.TELEGRAM
        tg_adapter._auto_tts_disabled_chats = set()
        runner._sync_voice_mode_state_to_adapter(tg_adapter)
        # The prefix is stripped: only the bare chat id lands in the set.
        assert tg_adapter._auto_tts_disabled_chats == {"123"}

    def test_sync_clears_existing_state(self):
        """Any pre-existing disabled-chat entries are replaced, not merged."""
        runner = _make_runner()
        runner._voice_mode = {"telegram:123": "off"}
        tg_adapter = MagicMock()
        tg_adapter.platform = Platform.TELEGRAM
        tg_adapter._auto_tts_disabled_chats = {"old_chat_id", "another_old"}
        runner._sync_voice_mode_state_to_adapter(tg_adapter)
        # Stale ids are gone; only the current state remains.
        assert tg_adapter._auto_tts_disabled_chats == {"123"}

    def test_sync_returns_early_without_platform(self):
        """An adapter with platform=None is left completely untouched."""
        runner = _make_runner()
        runner._voice_mode = {"telegram:123": "off"}
        bare_adapter = MagicMock()
        bare_adapter.platform = None
        bare_adapter._auto_tts_disabled_chats = {"old"}
        runner._sync_voice_mode_state_to_adapter(bare_adapter)
        assert bare_adapter._auto_tts_disabled_chats == {"old"}

    def test_sync_returns_early_without_auto_tts_disabled_chats(self):
        """Adapters lacking the _auto_tts_disabled_chats attribute are skipped."""
        runner = _make_runner()
        runner._voice_mode = {"telegram:123": "off"}
        # spec=[] makes attribute access raise, proving sync never touches it.
        minimal_adapter = MagicMock(spec=[])
        runner._sync_voice_mode_state_to_adapter(minimal_adapter)
# ---------------------------------------------------------------------------
# Helper
# ---------------------------------------------------------------------------
def _make_runner() -> GatewayRunner:
    """Create a minimal GatewayRunner for testing.

    Uses ``__new__`` to skip ``__init__`` (and any real gateway wiring it
    performs), and patches ``_load_voice_modes`` so construction cannot
    touch the persisted voice-mode file. Only the attributes the tests
    exercise (``_voice_mode``, ``adapters``) are initialised.
    """
    with patch("gateway.run.GatewayRunner._load_voice_modes", return_value={}):
        runner = GatewayRunner.__new__(GatewayRunner)
        runner._voice_mode = {}
        runner.adapters = {}
    return runner

View file

@ -0,0 +1,473 @@
"""Tests for the webhook adapter's ``deliver_only`` route mode.
``deliver_only`` lets external services (Supabase webhooks, monitoring
alerts, background jobs, other agents) push plain-text notifications to
a user's chat via the webhook adapter WITHOUT invoking the agent. The
rendered prompt template becomes the literal message body.
Covers:
- Agent is NOT invoked (``handle_message`` never called)
- Rendered content is delivered to the target platform adapter
- HTTP returns 200 OK on success, 502 on delivery failure
- Startup validation rejects ``deliver_only`` without a real delivery target
- HMAC auth, rate limiting, and idempotency still apply
"""
import asyncio
import hashlib
import hmac
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, SendResult
from gateway.platforms.webhook import WebhookAdapter, _INSECURE_NO_AUTH
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_adapter(routes, **extra_kw) -> WebhookAdapter:
    """Build a WebhookAdapter whose extra config holds *routes* plus any overrides."""
    settings = {"host": "0.0.0.0", "port": 0, "routes": routes, **extra_kw}
    return WebhookAdapter(PlatformConfig(enabled=True, extra=settings))
def _create_app(adapter: WebhookAdapter) -> web.Application:
    """Expose the adapter's health and webhook handlers on a fresh aiohttp app."""
    application = web.Application()
    application.router.add_get("/health", adapter._handle_health)
    application.router.add_post("/webhooks/{route_name}", adapter._handle_webhook)
    return application
def _wire_mock_target(adapter: WebhookAdapter, platform_name: str = "telegram"):
    """Attach a gateway_runner with a mocked target adapter; return the mock target."""
    target = AsyncMock()
    target.send = AsyncMock(return_value=SendResult(success=True))
    runner = MagicMock()
    runner.config.get_home_channel.return_value = None
    runner.adapters = {Platform(platform_name): target}
    adapter.gateway_runner = runner
    return target
# ===================================================================
# Core behaviour: agent bypass
# ===================================================================
class TestDeliverOnlyBypassesAgent:
    """The whole point of the feature — handle_message must not be called."""

    @pytest.mark.asyncio
    async def test_post_delivers_directly_without_agent(self):
        """A deliver_only POST goes straight to adapter.send(); the agent is bypassed."""
        routes = {
            "match-alert": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "12345"},
                "prompt": "{payload.user} matched with {payload.other}!",
            }
        }
        adapter = _make_adapter(routes)
        mock_target = _wire_mock_target(adapter)
        # Guard: handle_message must NOT be called in deliver_only mode
        handle_message_calls: list[MessageEvent] = []

        async def _capture(event):
            handle_message_calls.append(event)

        adapter.handle_message = _capture
        app = _create_app(adapter)
        body = json.dumps(
            {"payload": {"user": "alice", "other": "bob"}}
        ).encode()
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                "/webhooks/match-alert",
                data=body,
                headers={
                    "Content-Type": "application/json",
                    "X-GitHub-Delivery": "delivery-1",
                },
            )
            assert resp.status == 200
            data = await resp.json()
            assert data["status"] == "delivered"
            assert data["route"] == "match-alert"
            assert data["target"] == "telegram"
            # Let any background tasks settle before asserting no agent call
            await asyncio.sleep(0.05)
            # Agent was NOT invoked
            assert handle_message_calls == []
            # Target adapter.send() WAS called with the rendered template
            mock_target.send.assert_awaited_once()
            call_args = mock_target.send.await_args
            chat_id_arg, content_arg = call_args.args[0], call_args.args[1]
            assert chat_id_arg == "12345"
            assert content_arg == "alice matched with bob!"

    @pytest.mark.asyncio
    async def test_template_rendering_works(self):
        """Dot-notation template variables resolve in deliver_only mode."""
        routes = {
            "alert": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "chat-1"},
                "prompt": "Build {build.number} status: {build.status}",
            }
        }
        adapter = _make_adapter(routes)
        mock_target = _wire_mock_target(adapter)
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                "/webhooks/alert",
                json={"build": {"number": 77, "status": "FAILED"}},
                headers={"X-GitHub-Delivery": "d-render-1"},
            )
            assert resp.status == 200
            # The payload's nested values were substituted into the prompt.
            mock_target.send.assert_awaited_once()
            content_arg = mock_target.send.await_args.args[1]
            assert content_arg == "Build 77 status: FAILED"

    @pytest.mark.asyncio
    async def test_thread_id_passed_through(self):
        """deliver_extra.thread_id flows through to the target adapter."""
        routes = {
            "r": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1", "thread_id": "topic-42"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(routes)
        mock_target = _wire_mock_target(adapter)
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "d-thread-1"},
            )
            assert resp.status == 200
            # thread_id is forwarded as send() metadata, not as part of the body.
            assert mock_target.send.await_args.kwargs["metadata"] == {
                "thread_id": "topic-42"
            }
# ===================================================================
# HTTP status codes
# ===================================================================
class TestDeliverOnlyStatusCodes:
    """HTTP status mapping for deliver_only: 200 on success, 502 on any delivery failure."""

    @pytest.mark.asyncio
    async def test_delivery_failure_returns_502(self):
        """If the target adapter returns SendResult(success=False), 502."""
        routes = {
            "r": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(routes)
        mock_target = _wire_mock_target(adapter)
        # Replace the default success stub with an explicit failure result.
        mock_target.send = AsyncMock(
            return_value=SendResult(success=False, error="rate limited by tg")
        )
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "d-fail-1"},
            )
            assert resp.status == 502
            data = await resp.json()
            # Generic error — no adapter-level detail leaks
            assert data["error"] == "Delivery failed"
            assert "rate limited" not in json.dumps(data)

    @pytest.mark.asyncio
    async def test_delivery_exception_returns_502(self):
        """If adapter.send() raises, we return 502 (not 500)."""
        routes = {
            "r": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(routes)
        mock_target = _wire_mock_target(adapter)
        mock_target.send = AsyncMock(side_effect=RuntimeError("tg exploded"))
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "d-exc-1"},
            )
            assert resp.status == 502
            data = await resp.json()
            assert data["error"] == "Delivery failed"
            # Exception message must not leak
            assert "exploded" not in json.dumps(data)

    @pytest.mark.asyncio
    async def test_target_platform_not_connected_returns_502(self):
        """deliver_only to a platform the gateway doesn't have → 502."""
        routes = {
            "r": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "discord",  # not configured in mock runner
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(routes)
        _wire_mock_target(adapter, platform_name="telegram")  # only TG wired
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "d-no-platform-1"},
            )
            assert resp.status == 502
# ===================================================================
# Startup validation
# ===================================================================
class TestDeliverOnlyStartupValidation:
    """connect() must reject deliver_only routes without a real delivery target."""

    @pytest.mark.asyncio
    async def test_deliver_only_with_log_deliver_rejected(self):
        """deliver_only=true + deliver=log is nonsense — reject at connect()."""
        routes = {
            "bad": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "log",
                "deliver_only": True,
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(routes)
        with pytest.raises(ValueError, match="deliver_only=true but deliver is 'log'"):
            await adapter.connect()

    @pytest.mark.asyncio
    async def test_deliver_only_with_missing_deliver_rejected(self):
        """deliver_only=true with no deliver field defaults to 'log' → reject."""
        routes = {
            "bad": {
                "secret": _INSECURE_NO_AUTH,
                # no deliver field
                "deliver_only": True,
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(routes)
        with pytest.raises(ValueError, match="deliver_only=true"):
            await adapter.connect()

    @pytest.mark.asyncio
    async def test_deliver_only_with_real_target_accepted(self):
        """Sanity check — a valid deliver_only config passes validation."""
        routes = {
            "good": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(routes)
        # connect() does more than validation (binds a socket) — we just
        # want to verify the validation doesn't raise. Call it and tear
        # down immediately.
        try:
            started = await adapter.connect()
            if started:
                await adapter.disconnect()
        except ValueError:
            pytest.fail("valid deliver_only config should not raise ValueError")
# ===================================================================
# Security + reliability invariants still hold
# ===================================================================
class TestDeliverOnlySecurityInvariants:
    """deliver_only must not weaken auth, idempotency, or rate limiting."""

    @pytest.mark.asyncio
    async def test_hmac_still_enforced(self):
        """deliver_only does NOT bypass HMAC validation."""
        secret = "real-secret-123"
        routes = {
            "r": {
                "secret": secret,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(routes)
        mock_target = _wire_mock_target(adapter)
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            # No signature header → reject
            resp = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "d-noauth-1"},
            )
            assert resp.status == 401
            # Target never called
            mock_target.send.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_idempotency_still_applies(self):
        """Same delivery_id posted twice → second is suppressed."""
        routes = {
            "r": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(routes)
        mock_target = _wire_mock_target(adapter)
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            r1 = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "dup-1"},
            )
            assert r1.status == 200
            # Replay with the identical delivery id.
            r2 = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "dup-1"},
            )
            # Existing webhook adapter treats duplicates as 200 + status=duplicate
            assert r2.status == 200
            data = await r2.json()
            assert data["status"] == "duplicate"
            # Target was called exactly once
            assert mock_target.send.await_count == 1

    @pytest.mark.asyncio
    async def test_rate_limit_still_applies(self):
        """Route-level rate limit caps deliver_only POSTs too."""
        routes = {
            "r": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(routes, rate_limit=2)
        _wire_mock_target(adapter)
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            for i in range(2):
                r = await cli.post(
                    "/webhooks/r",
                    json={},
                    headers={"X-GitHub-Delivery": f"rl-{i}"},
                )
                assert r.status == 200
            # Third within the window → 429
            r3 = await cli.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "rl-3"},
            )
            assert r3.status == 429
# ===================================================================
# Unit: _direct_deliver dispatch
# ===================================================================
class TestDirectDeliverUnit:
    """Unit tests for _direct_deliver's dispatch on the route's deliver target."""

    @pytest.mark.asyncio
    async def test_dispatches_to_cross_platform_for_messaging_targets(self):
        """Messaging targets (e.g. telegram) go through the platform adapter's send()."""
        adapter = _make_adapter({})
        mock_target = _wire_mock_target(adapter, "telegram")
        result = await adapter._direct_deliver(
            "hello",
            {"deliver": "telegram", "deliver_extra": {"chat_id": "c-1"}},
        )
        assert result.success is True
        # Content and chat id are passed positionally; no metadata here.
        mock_target.send.assert_awaited_once_with(
            "c-1", "hello", metadata=None
        )

    @pytest.mark.asyncio
    async def test_dispatches_to_github_comment(self):
        """The github_comment target routes to the dedicated _deliver_github_comment helper."""
        adapter = _make_adapter({})
        with patch.object(
            adapter, "_deliver_github_comment",
            new=AsyncMock(return_value=SendResult(success=True)),
        ) as mock_gh:
            result = await adapter._direct_deliver(
                "review body",
                {
                    "deliver": "github_comment",
                    "deliver_extra": {"repo": "org/r", "pr_number": "1"},
                },
            )
            assert result.success is True
            mock_gh.assert_awaited_once()

View file

@ -0,0 +1,289 @@
"""Test that HMAC signature validation happens BEFORE rate limiting.
This verifies the fix for bug #12544: invalid signature requests must NOT
consume rate-limit quota. Before the fix, rate limiting was applied before
signature validation, so an attacker could exhaust a victim's rate limit
with invalidly-signed requests and then make valid requests that get rejected
with 429.
The correct order is:
1. Read body
2. Validate HMAC signature (reject 401 if invalid)
3. Rate limit check (reject 429 if over limit)
4. Process the webhook
"""
import hashlib
import hmac
import json
import pytest
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
from gateway.platforms.webhook import WebhookAdapter
from gateway.config import PlatformConfig
def _make_adapter(routes, rate_limit=5, **extra_kw) -> WebhookAdapter:
    """Create a WebhookAdapter with the given routes."""
    settings = {
        "host": "0.0.0.0",
        "port": 0,
        "routes": routes,
        "rate_limit": rate_limit,
        **extra_kw,
    }
    return WebhookAdapter(PlatformConfig(enabled=True, extra=settings))
def _create_app(adapter: WebhookAdapter) -> web.Application:
    """Build the aiohttp Application from the adapter."""
    application = web.Application()
    router = application.router
    router.add_get("/health", adapter._handle_health)
    router.add_post("/webhooks/{route_name}", adapter._handle_webhook)
    return application
def _github_signature(body: bytes, secret: str) -> str:
"""Compute X-Hub-Signature-256 for *body* using *secret*."""
return "sha256=" + hmac.new(
secret.encode(), body, hashlib.sha256
).hexdigest()
# Minimal well-formed webhook body shared by every test in this module.
SIMPLE_PAYLOAD = {"event": "test", "data": "hello"}
class TestSignatureBeforeRateLimit:
    """Verify that invalid signatures do NOT consume rate limit quota."""

    @pytest.mark.asyncio
    async def test_invalid_signature_does_not_consume_rate_limit(self):
        """Send requests with invalid signatures up to the rate limit, then
        send a valid-signed request and verify it succeeds.
        BEFORE FIX: Invalid signatures consume the rate limit bucket, so
        after 'rate_limit' bad requests the valid one would get 429.
        AFTER FIX: Invalid signatures are rejected with 401 first (before
        rate limiting), so the rate limit bucket is untouched. The valid
        request after many bad ones still succeeds.
        """
        secret = "test-secret-key"
        route_name = "test-route"
        routes = {
            route_name: {
                "secret": secret,
                "events": ["push"],
                "prompt": "Event: {event}",
                "deliver": "log",
            }
        }
        rate_limit = 5
        adapter = _make_adapter(routes, rate_limit=rate_limit)
        # Capture agent invocations so we can assert exactly one valid event got through.
        captured_events = []

        async def _capture(event):
            captured_events.append(event)

        adapter.handle_message = _capture
        app = _create_app(adapter)
        body = json.dumps(SIMPLE_PAYLOAD).encode()
        async with TestClient(TestServer(app)) as cli:
            # First exhaust the rate limit with invalid signatures
            for i in range(rate_limit):
                resp = await cli.post(
                    f"/webhooks/{route_name}",
                    data=body,
                    headers={
                        "Content-Type": "application/json",
                        "X-GitHub-Event": "push",
                        "X-Hub-Signature-256": "sha256=invalid",  # bad sig
                        "X-GitHub-Delivery": f"bad-{i}",
                    },
                )
                # Each invalid signature should be rejected with 401
                assert resp.status == 401, (
                    f"Expected 401 for invalid signature, got {resp.status}"
                )
            # Now send a valid-signed request — it MUST succeed (202)
            # BEFORE FIX: This would return 429 because the 5 bad requests
            # consumed the rate limit bucket.
            # AFTER FIX: Bad requests don't touch rate limiting, so valid
            # request succeeds.
            valid_sig = _github_signature(body, secret)
            resp = await cli.post(
                f"/webhooks/{route_name}",
                data=body,
                headers={
                    "Content-Type": "application/json",
                    "X-GitHub-Event": "push",
                    "X-Hub-Signature-256": valid_sig,
                    "X-GitHub-Delivery": "good-001",
                },
            )
            assert resp.status == 202, (
                f"Expected 202 for valid request after invalid signatures, "
                f"got {resp.status}. Rate limit may have been consumed by "
                f"invalid requests (bug #12544 not fixed)."
            )
            data = await resp.json()
            assert data["status"] == "accepted"
            # The valid event should have been captured
            assert len(captured_events) == 1

    @pytest.mark.asyncio
    async def test_valid_signature_still_rate_limited(self):
        """Verify that VALID requests still respect rate limiting normally."""
        secret = "test-secret-key"
        route_name = "test-route"
        routes = {
            route_name: {
                "secret": secret,
                "events": ["push"],
                "prompt": "Event: {event}",
                "deliver": "log",
            }
        }
        rate_limit = 3
        adapter = _make_adapter(routes, rate_limit=rate_limit)
        captured_events = []

        async def _capture(event):
            captured_events.append(event)

        adapter.handle_message = _capture
        app = _create_app(adapter)
        body = json.dumps(SIMPLE_PAYLOAD).encode()
        async with TestClient(TestServer(app)) as cli:
            # Send 'rate_limit' valid requests — all should succeed
            for i in range(rate_limit):
                valid_sig = _github_signature(body, secret)
                resp = await cli.post(
                    f"/webhooks/{route_name}",
                    data=body,
                    headers={
                        "Content-Type": "application/json",
                        "X-GitHub-Event": "push",
                        "X-Hub-Signature-256": valid_sig,
                        "X-GitHub-Delivery": f"good-{i}",
                    },
                )
                assert resp.status == 202
            # The next valid request SHOULD be rate-limited
            valid_sig = _github_signature(body, secret)
            resp = await cli.post(
                f"/webhooks/{route_name}",
                data=body,
                headers={
                    "Content-Type": "application/json",
                    "X-GitHub-Event": "push",
                    "X-Hub-Signature-256": valid_sig,
                    "X-GitHub-Delivery": "good-over-limit",
                },
            )
            assert resp.status == 429, (
                f"Expected 429 when exceeding rate limit with valid requests, "
                f"got {resp.status}"
            )

    @pytest.mark.asyncio
    async def test_mixed_valid_and_invalid_signatures(self):
        """Interleave invalid and valid requests. Only valid ones count
        against the rate limit."""
        secret = "test-secret-key"
        route_name = "test-route"
        routes = {
            route_name: {
                "secret": secret,
                "events": ["push"],
                "prompt": "Event: {event}",
                "deliver": "log",
            }
        }
        rate_limit = 3
        adapter = _make_adapter(routes, rate_limit=rate_limit)
        captured_events = []

        async def _capture(event):
            captured_events.append(event)

        adapter.handle_message = _capture
        app = _create_app(adapter)
        body = json.dumps(SIMPLE_PAYLOAD).encode()
        async with TestClient(TestServer(app)) as cli:
            # Send 2 valid requests (should succeed)
            for i in range(2):
                valid_sig = _github_signature(body, secret)
                resp = await cli.post(
                    f"/webhooks/{route_name}",
                    data=body,
                    headers={
                        "Content-Type": "application/json",
                        "X-GitHub-Event": "push",
                        "X-Hub-Signature-256": valid_sig,
                        "X-GitHub-Delivery": f"good-{i}",
                    },
                )
                assert resp.status == 202
            # Send 10 invalid requests (should all get 401, not consume quota)
            for i in range(10):
                resp = await cli.post(
                    f"/webhooks/{route_name}",
                    data=body,
                    headers={
                        "Content-Type": "application/json",
                        "X-GitHub-Event": "push",
                        "X-Hub-Signature-256": "sha256=invalid",
                        "X-GitHub-Delivery": f"bad-{i}",
                    },
                )
                assert resp.status == 401
            # One more valid request should STILL succeed (only 2 consumed)
            valid_sig = _github_signature(body, secret)
            resp = await cli.post(
                f"/webhooks/{route_name}",
                data=body,
                headers={
                    "Content-Type": "application/json",
                    "X-GitHub-Event": "push",
                    "X-Hub-Signature-256": valid_sig,
                    "X-GitHub-Delivery": "good-3",
                },
            )
            assert resp.status == 202, (
                f"Expected 202 for 3rd valid request after many invalid ones, "
                f"got {resp.status}"
            )
            # The 4th valid request should be rate-limited (2 + 2 = 4 = limit)
            valid_sig = _github_signature(body, secret)
            resp = await cli.post(
                f"/webhooks/{route_name}",
                data=body,
                headers={
                    "Content-Type": "application/json",
                    "X-GitHub-Event": "push",
                    "X-Hub-Signature-256": valid_sig,
                    "X-GitHub-Delivery": "good-4",
                },
            )
            assert resp.status == 429
            # Exactly the 3 valid-and-within-limit requests reached the agent.
            assert len(captured_events) == 3

View file

@ -211,6 +211,30 @@ class TestFileHandleClosedOnError:
assert adapter._bridge_log_fh is None
class TestConnectCleanup:
    """Verify failure paths release the scoped session lock."""

    @pytest.mark.asyncio
    async def test_releases_lock_when_npm_install_fails(self):
        """A failed npm install (returncode=1) must release the whatsapp-session lock."""
        adapter = _make_adapter()

        # Pretend node_modules is missing so connect() attempts the install.
        def _path_exists(path_obj):
            return not str(path_obj).endswith("node_modules")

        install_result = MagicMock(returncode=1, stderr="install failed")
        with patch("gateway.platforms.whatsapp.check_whatsapp_requirements", return_value=True), \
                patch.object(Path, "exists", autospec=True, side_effect=_path_exists), \
                patch("subprocess.run", return_value=install_result), \
                patch("gateway.status.acquire_scoped_lock", return_value=(True, None)), \
                patch("gateway.status.release_scoped_lock") as mock_release:
            result = await adapter.connect()
            assert result is False
            # Lock released with the same scope/identity it was acquired under.
            mock_release.assert_called_once_with("whatsapp-session", str(adapter._session_path))
            assert adapter._platform_lock_identity is None
class TestBridgeRuntimeFailure:
"""Verify runtime bridge death is surfaced as a fatal adapter error."""
@ -429,6 +453,33 @@ class TestKillPortProcess:
class TestHttpSessionLifecycle:
"""Verify persistent aiohttp.ClientSession is created and cleaned up."""
    @pytest.mark.asyncio
    async def test_disconnect_uses_taskkill_tree_on_windows(self):
        """Windows disconnect should target the bridge process tree, not just the parent PID."""
        adapter = _make_adapter()
        mock_proc = MagicMock()
        mock_proc.pid = 12345
        # poll() reports the process as already exited after taskkill runs.
        mock_proc.poll.side_effect = [0]
        adapter._bridge_process = mock_proc
        adapter._poll_task = None
        adapter._http_session = None
        adapter._running = True
        adapter._session_lock_identity = None
        with patch("gateway.platforms.whatsapp._IS_WINDOWS", True), \
                patch("gateway.platforms.whatsapp.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run, \
                patch("gateway.platforms.whatsapp.asyncio.sleep", new_callable=AsyncMock):
            await adapter.disconnect()
        # /T kills the whole child tree rooted at the bridge PID.
        mock_run.assert_called_once_with(
            ["taskkill", "/PID", "12345", "/T"],
            capture_output=True,
            text=True,
            timeout=10,
        )
        # POSIX-style termination must not be attempted on the Windows path.
        mock_proc.terminate.assert_not_called()
        mock_proc.kill.assert_not_called()
@pytest.mark.asyncio
async def test_session_closed_on_disconnect(self):
"""disconnect() should close self._http_session."""

Some files were not shown because too many files have changed in this diff Show more