mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Merge branch 'main' into fix/tui-provider-resolution
This commit is contained in:
commit
ec374c0599
625 changed files with 68938 additions and 11055 deletions
170
tests/acp/test_approval_isolation.py
Normal file
170
tests/acp/test_approval_isolation.py
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
"""Tests for GHSA-96vc-wcxf-jjff and GHSA-qg5c-hvr5-hjgr.
|
||||
|
||||
Two related ACP approval-flow issues:
|
||||
- 96vc: ACP didn't set HERMES_EXEC_ASK, so `check_all_command_guards`
|
||||
took the non-interactive auto-approve path and never consulted the
|
||||
ACP-supplied callback.
|
||||
- qg5c: `_approval_callback` was a module-global in terminal_tool;
|
||||
overlapping ACP sessions overwrote each other's callback slot.
|
||||
|
||||
Both fixed together by:
|
||||
1. Setting HERMES_EXEC_ASK inside _run_agent (wraps the agent call).
|
||||
2. Storing the callback in thread-local state so concurrent executor
|
||||
threads don't collide.
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestThreadLocalApprovalCallback:
    """GHSA-qg5c-hvr5-hjgr: set_approval_callback must be per-thread so
    concurrent ACP sessions don't stomp on each other's handlers."""

    def test_set_and_get_in_same_thread(self):
        from tools.terminal_tool import (
            set_approval_callback,
            _get_approval_callback,
        )

        def cb1(cmd, desc):
            return "once"

        set_approval_callback(cb1)
        # Same thread reads back exactly what it registered.
        assert _get_approval_callback() is cb1

    def test_callback_not_visible_in_different_thread(self):
        """Thread A's callback is NOT visible to Thread B."""
        import time

        from tools.terminal_tool import (
            set_approval_callback,
            _get_approval_callback,
        )

        def cb_a(cmd, desc):
            return "thread_a"

        def cb_b(cmd, desc):
            return "thread_b"

        observed = {}

        def run(label, cb):
            set_approval_callback(cb)
            # Pause so the sibling thread has time to install its own callback.
            time.sleep(0.05)
            observed[label] = _get_approval_callback()

        workers = [
            threading.Thread(target=run, args=("a", cb_a)),
            threading.Thread(target=run, args=("b", cb_b)),
        ]
        for w in workers:
            w.start()
        for w in workers:
            w.join()

        # Each thread must see ONLY its own callback — not the other's.
        assert observed["a"] is cb_a
        assert observed["b"] is cb_b

    def test_main_thread_callback_not_leaked_to_worker(self):
        """A callback set in the main thread does NOT leak into a
        freshly-spawned worker thread."""
        from tools.terminal_tool import (
            set_approval_callback,
            _get_approval_callback,
        )

        def cb_main(cmd, desc):
            return "main"

        set_approval_callback(cb_main)

        worker_saw = []
        worker = threading.Thread(
            target=lambda: worker_saw.append(_get_approval_callback())
        )
        worker.start()
        worker.join()

        # Worker thread has no callback set — TLS is empty for it.
        assert worker_saw == [None]
        # Main thread still has its callback.
        assert _get_approval_callback() is cb_main

    def test_sudo_password_callback_also_thread_local(self):
        """Same protection applies to the sudo password callback."""
        from tools.terminal_tool import (
            set_sudo_password_callback,
            _get_sudo_password_callback,
        )

        def cb_main():
            return "main-password"

        set_sudo_password_callback(cb_main)

        worker_saw = []
        worker = threading.Thread(
            target=lambda: worker_saw.append(_get_sudo_password_callback())
        )
        worker.start()
        worker.join()

        # Fresh worker thread sees no callback; main thread keeps its own.
        assert worker_saw == [None]
        assert _get_sudo_password_callback() is cb_main
|
||||
|
||||
|
||||
class TestAcpExecAskGate:
    """GHSA-96vc-wcxf-jjff: ACP's _run_agent must set HERMES_INTERACTIVE so
    that tools.approval.check_all_command_guards takes the CLI-interactive
    path (consults the registered callback via prompt_dangerous_approval)
    instead of the non-interactive auto-approve shortcut.

    (HERMES_EXEC_ASK takes the gateway-queue path which requires a
    notify_cb registered in _gateway_notify_cbs — not applicable to ACP,
    which uses a direct callback shape.)"""

    def test_interactive_env_var_routes_to_callback(self, monkeypatch):
        """When HERMES_INTERACTIVE is set and an approval callback is
        registered, a dangerous command must route through the callback."""
        # Start from a clean slate: none of the routing env vars present.
        for var in (
            "HERMES_INTERACTIVE",
            "HERMES_GATEWAY_SESSION",
            "HERMES_EXEC_ASK",
            "HERMES_YOLO_MODE",
        ):
            monkeypatch.delenv(var, raising=False)

        from tools.approval import check_all_command_guards

        called_with = []

        def fake_cb(command, description, *, allow_permanent=True):
            called_with.append((command, description))
            return "once"

        # Without HERMES_INTERACTIVE: takes auto-approve path, callback NOT called
        result = check_all_command_guards(
            "rm -rf /tmp/test-exec-ask", "local", approval_callback=fake_cb,
        )
        assert result["approved"] is True
        assert called_with == [], (
            "without HERMES_INTERACTIVE the non-interactive auto-approve "
            "path should fire without consulting the callback"
        )

        # With HERMES_INTERACTIVE: callback IS called, approval flows through it
        monkeypatch.setenv("HERMES_INTERACTIVE", "1")
        called_with.clear()
        result = check_all_command_guards(
            "rm -rf /tmp/test-exec-ask", "local", approval_callback=fake_cb,
        )
        assert called_with, (
            "with HERMES_INTERACTIVE the approval path should consult the "
            "registered callback — this was the ACP bypass in "
            "GHSA-96vc-wcxf-jjff"
        )
        assert result["approved"] is True
|
||||
|
|
@ -73,3 +73,17 @@ class TestApprovalMapping:
|
|||
result = cb("rm -rf /", "dangerous")
|
||||
|
||||
assert result == "deny"
|
||||
|
||||
def test_approval_none_response_returns_deny(self):
|
||||
"""When request_permission resolves to None, the callback should return 'deny'."""
|
||||
loop = MagicMock(spec=asyncio.AbstractEventLoop)
|
||||
mock_rp = MagicMock(name="request_permission")
|
||||
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
|
||||
with patch("acp_adapter.permissions.asyncio.run_coroutine_threadsafe", return_value=future):
|
||||
cb = make_approval_callback(mock_rp, loop, session_id="s1", timeout=1.0)
|
||||
result = cb("echo hi", "demo")
|
||||
|
||||
assert result == "deny"
|
||||
|
|
|
|||
210
tests/acp/test_ping_suppression.py
Normal file
210
tests/acp/test_ping_suppression.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
"""Tests for acp_adapter.entry._BenignProbeMethodFilter.
|
||||
|
||||
Covers both the isolated filter logic and the full end-to-end path where a
|
||||
client sends a bare JSON-RPC ``ping`` request over stdio and the acp runtime
|
||||
surfaces the resulting ``RequestError`` via ``logging.exception("Background
|
||||
task failed", ...)``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from io import StringIO
|
||||
|
||||
import pytest
|
||||
|
||||
from acp.exceptions import RequestError
|
||||
|
||||
from acp_adapter.entry import _BenignProbeMethodFilter
|
||||
|
||||
|
||||
# -- Unit tests on the filter itself ----------------------------------------
|
||||
|
||||
|
||||
def _make_record(msg: str, exc: BaseException | None) -> logging.LogRecord:
|
||||
record = logging.LogRecord(
|
||||
name="root",
|
||||
level=logging.ERROR,
|
||||
pathname=__file__,
|
||||
lineno=0,
|
||||
msg=msg,
|
||||
args=(),
|
||||
exc_info=(type(exc), exc, exc.__traceback__) if exc else None,
|
||||
)
|
||||
return record
|
||||
|
||||
|
||||
def _bake_tb(exc: BaseException) -> BaseException:
|
||||
try:
|
||||
raise exc
|
||||
except BaseException as e: # noqa: BLE001
|
||||
return e
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", ["ping", "health", "healthcheck"])
def test_filter_suppresses_benign_probe(method: str) -> None:
    """Known benign probe methods are silenced by the filter."""
    probe_error = _bake_tb(RequestError.method_not_found(method))
    rec = _make_record("Background task failed", probe_error)
    assert _BenignProbeMethodFilter().filter(rec) is False
|
||||
|
||||
|
||||
def test_filter_allows_real_method_not_found() -> None:
    """A genuinely unknown method is NOT suppressed."""
    err = _bake_tb(RequestError.method_not_found("session/custom"))
    rec = _make_record("Background task failed", err)
    assert _BenignProbeMethodFilter().filter(rec) is True
|
||||
|
||||
|
||||
def test_filter_allows_non_request_error() -> None:
    """Exceptions other than RequestError pass through untouched."""
    err = _bake_tb(RuntimeError("boom"))
    rec = _make_record("Background task failed", err)
    assert _BenignProbeMethodFilter().filter(rec) is True
|
||||
|
||||
|
||||
def test_filter_allows_different_message_even_for_ping() -> None:
    """Only 'Background task failed' is muted — other messages pass through."""
    err = _bake_tb(RequestError.method_not_found("ping"))
    rec = _make_record("Some other context", err)
    assert _BenignProbeMethodFilter().filter(rec) is True
|
||||
|
||||
|
||||
def test_filter_allows_request_error_with_different_code() -> None:
    """RequestErrors that are not method-not-found are kept."""
    err = _bake_tb(RequestError.invalid_params({"method": "ping"}))
    rec = _make_record("Background task failed", err)
    assert _BenignProbeMethodFilter().filter(rec) is True
|
||||
|
||||
|
||||
def test_filter_allows_log_without_exc_info() -> None:
    """Records carrying no exception info are never suppressed."""
    rec = _make_record("Background task failed", None)
    assert _BenignProbeMethodFilter().filter(rec) is True
|
||||
|
||||
|
||||
# -- End-to-end: drive a real JSON-RPC `ping` through acp.run_agent ---------
|
||||
|
||||
|
||||
class _FakeAgent:
    """Minimal acp.Agent stub — we only need the router to build.

    Every handler returns the smallest valid response (or nothing) so the
    acp runtime can dispatch requests without touching real agent logic.
    """

    async def initialize(self, **kwargs):  # noqa: ANN003
        # Advertise protocol v1 with default (empty) capabilities.
        from acp.schema import AgentCapabilities, InitializeResponse

        return InitializeResponse(protocol_version=1, agent_capabilities=AgentCapabilities())

    async def new_session(self, cwd, mcp_servers=None, **kwargs):  # noqa: ANN001, ANN003
        # Fixed session id — tests never create more than one session.
        from acp.schema import NewSessionResponse

        return NewSessionResponse(session_id="test")

    async def prompt(self, session_id, prompt, **kwargs):  # noqa: ANN001, ANN003
        # Immediately end the turn; no model call is made.
        from acp.schema import PromptResponse

        return PromptResponse(stop_reason="end_turn")

    async def cancel(self, session_id, **kwargs):  # noqa: ANN001, ANN003
        pass

    async def authenticate(self, **kwargs):  # noqa: ANN003
        pass

    def on_connect(self, conn):  # noqa: ANN001
        pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_bare_ping_request_produces_proper_response_and_no_stderr_noise(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """A bare ``ping`` must get a JSON-RPC -32601 back AND leave stderr clean
    when the filter is installed on the handler.
    """
    import acp

    # Attach the filter to a fresh stream handler that mirrors entry._setup_logging.
    stream = StringIO()
    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter("%(name)s|%(levelname)s|%(message)s"))
    handler.addFilter(_BenignProbeMethodFilter())
    root = logging.getLogger()
    # Save root logger state so it can be restored in the finally block.
    prior_handlers = root.handlers[:]
    prior_level = root.level
    root.handlers = [handler]
    root.setLevel(logging.INFO)
    # Also suppress propagation of caplog's default handler interfering with
    # our stream (caplog still captures via its own propagation hook).
    try:
        loop = asyncio.get_running_loop()

        # Pipe client -> agent
        client_to_agent_r, client_to_agent_w = os.pipe()
        # Pipe agent -> client
        agent_to_client_r, agent_to_client_w = os.pipe()

        # Unbuffered binary wrappers over the raw fds.
        in_read_file = os.fdopen(client_to_agent_r, "rb", buffering=0)
        in_write_file = os.fdopen(client_to_agent_w, "wb", buffering=0)
        out_read_file = os.fdopen(agent_to_client_r, "rb", buffering=0)
        out_write_file = os.fdopen(agent_to_client_w, "wb", buffering=0)

        # Agent reads its input from this StreamReader:
        agent_input = asyncio.StreamReader(limit=1024 * 1024, loop=loop)
        agent_input_proto = asyncio.StreamReaderProtocol(agent_input, loop=loop)
        await loop.connect_read_pipe(lambda: agent_input_proto, in_read_file)

        # Agent writes its output via this StreamWriter:
        out_transport, out_protocol = await loop.connect_write_pipe(
            asyncio.streams.FlowControlMixin, out_write_file
        )
        agent_output = asyncio.StreamWriter(out_transport, out_protocol, None, loop)

        # Test harness reads agent output via this StreamReader:
        client_input = asyncio.StreamReader(limit=1024 * 1024, loop=loop)
        client_input_proto = asyncio.StreamReaderProtocol(client_input, loop=loop)
        await loop.connect_read_pipe(lambda: client_input_proto, out_read_file)

        # NOTE(review): acp.run_agent's input_stream is the writer and
        # output_stream the reader — the naming follows acp's convention,
        # not this test's; confirm against acp.run_agent's signature.
        agent_task = asyncio.create_task(
            acp.run_agent(
                _FakeAgent(),
                input_stream=agent_output,
                output_stream=agent_input,
                use_unstable_protocol=True,
            )
        )

        # Send a bare `ping`
        request = {"jsonrpc": "2.0", "id": 1, "method": "ping", "params": {}}
        in_write_file.write((json.dumps(request) + "\n").encode())
        in_write_file.flush()

        response_line = await asyncio.wait_for(client_input.readline(), timeout=5.0)
        # Give the supervisor task a tick to fire (filter should eat it)
        await asyncio.sleep(0.2)

        # -32601 is JSON-RPC "method not found"; data names the probe method.
        response = json.loads(response_line.decode())
        assert response["error"]["code"] == -32601, response
        assert response["error"]["data"] == {"method": "ping"}, response

        logs = stream.getvalue()
        assert "Background task failed" not in logs, (
            f"ping noise leaked to stderr:\n{logs}"
        )

        # Clean shutdown
        in_write_file.close()
        try:
            await asyncio.wait_for(agent_task, timeout=2.0)
        except (asyncio.TimeoutError, Exception):
            # Best-effort teardown: cancel and swallow whatever the task raises.
            agent_task.cancel()
            try:
                await agent_task
            except BaseException:  # noqa: BLE001
                pass
    finally:
        root.handlers = prior_handlers
        root.setLevel(prior_level)
|
||||
|
|
@ -95,19 +95,37 @@ class TestInitialize:
|
|||
|
||||
class TestAuthenticate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_with_provider_configured(self, agent, monkeypatch):
|
||||
async def test_authenticate_with_matching_method_id(self, agent, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"acp_adapter.server.has_provider",
|
||||
lambda: True,
|
||||
"acp_adapter.server.detect_provider",
|
||||
lambda: "openrouter",
|
||||
)
|
||||
resp = await agent.authenticate(method_id="openrouter")
|
||||
assert isinstance(resp, AuthenticateResponse)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_is_case_insensitive(self, agent, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"acp_adapter.server.detect_provider",
|
||||
lambda: "openrouter",
|
||||
)
|
||||
resp = await agent.authenticate(method_id="OpenRouter")
|
||||
assert isinstance(resp, AuthenticateResponse)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_rejects_mismatched_method_id(self, agent, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"acp_adapter.server.detect_provider",
|
||||
lambda: "openrouter",
|
||||
)
|
||||
resp = await agent.authenticate(method_id="totally-invalid-method")
|
||||
assert resp is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_without_provider(self, agent, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"acp_adapter.server.has_provider",
|
||||
lambda: False,
|
||||
"acp_adapter.server.detect_provider",
|
||||
lambda: None,
|
||||
)
|
||||
resp = await agent.authenticate(method_id="openrouter")
|
||||
assert resp is None
|
||||
|
|
@ -252,6 +270,57 @@ class TestListAndFork:
|
|||
|
||||
mock_list.assert_called_once_with(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_pagination_first_page(self, agent):
|
||||
from acp_adapter import server as acp_server
|
||||
|
||||
infos = [
|
||||
{"session_id": f"s{i}", "cwd": "/tmp", "title": None, "updated_at": 0.0}
|
||||
for i in range(acp_server._LIST_SESSIONS_PAGE_SIZE + 5)
|
||||
]
|
||||
with patch.object(agent.session_manager, "list_sessions", return_value=infos):
|
||||
resp = await agent.list_sessions()
|
||||
|
||||
assert len(resp.sessions) == acp_server._LIST_SESSIONS_PAGE_SIZE
|
||||
assert resp.next_cursor == resp.sessions[-1].session_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_pagination_no_more(self, agent):
|
||||
infos = [
|
||||
{"session_id": f"s{i}", "cwd": "/tmp", "title": None, "updated_at": 0.0}
|
||||
for i in range(3)
|
||||
]
|
||||
with patch.object(agent.session_manager, "list_sessions", return_value=infos):
|
||||
resp = await agent.list_sessions()
|
||||
|
||||
assert len(resp.sessions) == 3
|
||||
assert resp.next_cursor is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_cursor_resumes_after_match(self, agent):
|
||||
infos = [
|
||||
{"session_id": "s1", "cwd": "/tmp", "title": None, "updated_at": 0.0},
|
||||
{"session_id": "s2", "cwd": "/tmp", "title": None, "updated_at": 0.0},
|
||||
{"session_id": "s3", "cwd": "/tmp", "title": None, "updated_at": 0.0},
|
||||
]
|
||||
with patch.object(agent.session_manager, "list_sessions", return_value=infos):
|
||||
resp = await agent.list_sessions(cursor="s1")
|
||||
|
||||
assert [s.session_id for s in resp.sessions] == ["s2", "s3"]
|
||||
assert resp.next_cursor is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_unknown_cursor_returns_empty(self, agent):
|
||||
infos = [
|
||||
{"session_id": "s1", "cwd": "/tmp", "title": None, "updated_at": 0.0},
|
||||
{"session_id": "s2", "cwd": "/tmp", "title": None, "updated_at": 0.0},
|
||||
]
|
||||
with patch.object(agent.session_manager, "list_sessions", return_value=infos):
|
||||
resp = await agent.list_sessions(cursor="does-not-exist")
|
||||
|
||||
assert resp.sessions == []
|
||||
assert resp.next_cursor is None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session configuration / model routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -414,7 +414,11 @@ class TestRunOauthSetupToken:
|
|||
token = run_oauth_setup_token()
|
||||
|
||||
assert token == "from-cred-file"
|
||||
mock_run.assert_called_once()
|
||||
# Don't assert exact call count — the contract is "credentials flow
|
||||
# through", not "exactly one subprocess call". xdist cross-test
|
||||
# pollution (other tests shimming subprocess via plugins) has flaked
|
||||
# assert_called_once() in CI.
|
||||
assert mock_run.called
|
||||
|
||||
def test_returns_token_from_env_var(self, monkeypatch, tmp_path):
|
||||
"""Falls back to CLAUDE_CODE_OAUTH_TOKEN env var when no cred files."""
|
||||
|
|
|
|||
238
tests/agent/test_anthropic_normalize_v2.py
Normal file
238
tests/agent/test_anthropic_normalize_v2.py
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
"""Regression tests: normalize_anthropic_response_v2 vs v1.
|
||||
|
||||
Constructs mock Anthropic responses and asserts that the v2 function
|
||||
(returning NormalizedResponse) produces identical field values to the
|
||||
original v1 function (returning SimpleNamespace + finish_reason).
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from agent.anthropic_adapter import (
|
||||
normalize_anthropic_response,
|
||||
normalize_anthropic_response_v2,
|
||||
)
|
||||
from agent.transports.types import NormalizedResponse, ToolCall
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers to build mock Anthropic SDK responses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _text_block(text: str):
|
||||
return SimpleNamespace(type="text", text=text)
|
||||
|
||||
|
||||
def _thinking_block(thinking: str, signature: str = "sig_abc"):
|
||||
return SimpleNamespace(type="thinking", thinking=thinking, signature=signature)
|
||||
|
||||
|
||||
def _tool_use_block(id: str, name: str, input: dict):
|
||||
return SimpleNamespace(type="tool_use", id=id, name=name, input=input)
|
||||
|
||||
|
||||
def _response(content_blocks, stop_reason="end_turn"):
|
||||
return SimpleNamespace(
|
||||
content=content_blocks,
|
||||
stop_reason=stop_reason,
|
||||
usage=SimpleNamespace(
|
||||
input_tokens=10,
|
||||
output_tokens=5,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTextOnly:
    """Text-only response — no tools, no thinking."""

    def setup_method(self):
        self.resp = _response([_text_block("Hello world")])
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def test_type(self):
        """v2 returns the typed NormalizedResponse."""
        assert isinstance(self.v2, NormalizedResponse)

    def test_content_matches(self):
        """v1 and v2 agree on the assembled text."""
        assert self.v1_msg.content == self.v2.content

    def test_finish_reason_matches(self):
        assert self.v1_finish == self.v2.finish_reason

    def test_no_tool_calls(self):
        assert self.v1_msg.tool_calls is None
        assert self.v2.tool_calls is None

    def test_no_reasoning(self):
        assert self.v1_msg.reasoning is None
        assert self.v2.reasoning is None
|
||||
|
||||
|
||||
class TestWithToolCalls:
    """Response with tool calls."""

    def setup_method(self):
        blocks = [
            _text_block("I'll check that"),
            _tool_use_block("toolu_abc", "terminal", {"command": "ls"}),
            _tool_use_block("toolu_def", "read_file", {"path": "/tmp"}),
        ]
        self.resp = _response(blocks, stop_reason="tool_use")
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def test_finish_reason(self):
        """tool_use maps to the OpenAI-style 'tool_calls' in both versions."""
        assert self.v1_finish == "tool_calls"
        assert self.v2.finish_reason == "tool_calls"

    def test_tool_call_count(self):
        assert len(self.v1_msg.tool_calls) == 2
        assert len(self.v2.tool_calls) == 2

    def test_tool_call_ids_match(self):
        for v2_call, v1_call in zip(self.v2.tool_calls, self.v1_msg.tool_calls):
            assert v2_call.id == v1_call.id

    def test_tool_call_names_match(self):
        assert self.v2.tool_calls[0].name == "terminal"
        assert self.v2.tool_calls[1].name == "read_file"
        for v2_call, v1_call in zip(self.v2.tool_calls, self.v1_msg.tool_calls):
            assert v2_call.name == v1_call.function.name

    def test_tool_call_arguments_match(self):
        for v2_call, v1_call in zip(self.v2.tool_calls, self.v1_msg.tool_calls):
            assert v2_call.arguments == v1_call.function.arguments

    def test_content_preserved(self):
        assert self.v1_msg.content == self.v2.content
        assert "check that" in self.v2.content
|
||||
|
||||
|
||||
class TestWithThinking:
    """Response with thinking blocks (Claude 3.5+ extended thinking)."""

    def setup_method(self):
        blocks = [
            _thinking_block("Let me think about this carefully..."),
            _text_block("The answer is 42."),
        ]
        self.resp = _response(blocks)
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def test_reasoning_matches(self):
        assert self.v1_msg.reasoning == self.v2.reasoning
        assert "think about this" in self.v2.reasoning

    def test_reasoning_details_in_provider_data(self):
        """v2 carries the v1 reasoning_details inside provider_data."""
        v1_details = self.v1_msg.reasoning_details
        provider_data = self.v2.provider_data
        v2_details = provider_data.get("reasoning_details") if provider_data else None
        assert v1_details is not None
        assert v2_details is not None
        assert len(v2_details) == len(v1_details)

    def test_content_excludes_thinking(self):
        """Thinking text never leaks into the plain content field."""
        assert self.v2.content == "The answer is 42."
|
||||
|
||||
|
||||
class TestMixed:
    """Response with thinking + text + tool calls."""

    def setup_method(self):
        blocks = [
            _thinking_block("Planning my approach..."),
            _text_block("I'll run the command"),
            _tool_use_block("toolu_xyz", "terminal", {"command": "pwd"}),
        ]
        self.resp = _response(blocks, stop_reason="tool_use")
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def test_all_fields_present(self):
        """A mixed response populates every v2 field at once."""
        assert self.v2.content is not None
        assert self.v2.tool_calls is not None
        assert self.v2.reasoning is not None
        assert self.v2.finish_reason == "tool_calls"

    def test_content_matches(self):
        assert self.v1_msg.content == self.v2.content

    def test_reasoning_matches(self):
        assert self.v1_msg.reasoning == self.v2.reasoning

    def test_tool_call_matches(self):
        v1_call = self.v1_msg.tool_calls[0]
        v2_call = self.v2.tool_calls[0]
        assert v2_call.id == v1_call.id
        assert v2_call.name == v1_call.function.name
|
||||
|
||||
|
||||
class TestStopReasons:
    """Verify finish_reason mapping matches between v1 and v2."""

    @pytest.mark.parametrize("stop_reason,expected", [
        ("end_turn", "stop"),
        ("tool_use", "tool_calls"),
        ("max_tokens", "length"),
        ("stop_sequence", "stop"),
        ("refusal", "content_filter"),
        ("model_context_window_exceeded", "length"),
        ("unknown_future_reason", "stop"),
    ])
    def test_stop_reason_mapping(self, stop_reason, expected):
        """Both normalizers map the raw stop_reason to the same value."""
        resp = _response([_text_block("x")], stop_reason=stop_reason)
        _, v1_finish = normalize_anthropic_response(resp)
        v2 = normalize_anthropic_response_v2(resp)
        assert v1_finish == expected
        assert v2.finish_reason == expected
|
||||
|
||||
|
||||
class TestStripToolPrefix:
    """Verify mcp_ prefix stripping works identically."""

    @staticmethod
    def _mcp_resp():
        # One tool_use block whose name carries the mcp_ prefix.
        return _response(
            [_tool_use_block("toolu_1", "mcp_terminal", {"cmd": "ls"})],
            stop_reason="tool_use",
        )

    def test_prefix_stripped(self):
        resp = self._mcp_resp()
        v1_msg, _ = normalize_anthropic_response(resp, strip_tool_prefix=True)
        v2 = normalize_anthropic_response_v2(resp, strip_tool_prefix=True)
        assert v1_msg.tool_calls[0].function.name == "terminal"
        assert v2.tool_calls[0].name == "terminal"

    def test_prefix_kept(self):
        resp = self._mcp_resp()
        v1_msg, _ = normalize_anthropic_response(resp, strip_tool_prefix=False)
        v2 = normalize_anthropic_response_v2(resp, strip_tool_prefix=False)
        assert v1_msg.tool_calls[0].function.name == "mcp_terminal"
        assert v2.tool_calls[0].name == "mcp_terminal"
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Edge cases: empty content, no blocks, etc."""

    def test_empty_content_blocks(self):
        """No content blocks → both versions report None content."""
        resp = _response([])
        v1_msg, _ = normalize_anthropic_response(resp)
        v2 = normalize_anthropic_response_v2(resp)
        assert v1_msg.content == v2.content
        assert v2.content is None

    def test_no_reasoning_details_means_none_provider_data(self):
        v2 = normalize_anthropic_response_v2(_response([_text_block("hi")]))
        assert v2.provider_data is None

    def test_v2_returns_dataclass_not_namespace(self):
        v2 = normalize_anthropic_response_v2(_response([_text_block("hi")]))
        assert isinstance(v2, NormalizedResponse)
        assert not isinstance(v2, SimpleNamespace)
|
||||
|
|
@ -476,6 +476,133 @@ class TestGetTextAuxiliaryClient:
|
|||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
|
||||
class TestNousAuxiliaryRefresh:
|
||||
def test_try_nous_prefers_runtime_credentials(self):
|
||||
fresh_base = "https://inference-api.nousresearch.com/v1"
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value={"access_token": "stale-token"}),
|
||||
patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", fresh_base)),
|
||||
patch("hermes_cli.models.get_nous_recommended_aux_model", return_value=None),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
from agent.auxiliary_client import _try_nous
|
||||
|
||||
mock_openai.return_value = MagicMock()
|
||||
client, model = _try_nous()
|
||||
|
||||
assert client is not None
|
||||
# No Portal recommendation → falls back to the hardcoded default.
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "fresh-agent-key"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == fresh_base
|
||||
|
||||
def test_try_nous_uses_portal_recommendation_for_text(self):
|
||||
"""When the Portal recommends a compaction model, _try_nous honors it."""
|
||||
fresh_base = "https://inference-api.nousresearch.com/v1"
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value={"access_token": "***"}),
|
||||
patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", fresh_base)),
|
||||
patch("hermes_cli.models.get_nous_recommended_aux_model", return_value="minimax/minimax-m2.7") as mock_rec,
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
from agent.auxiliary_client import _try_nous
|
||||
|
||||
mock_openai.return_value = MagicMock()
|
||||
client, model = _try_nous(vision=False)
|
||||
|
||||
assert client is not None
|
||||
assert model == "minimax/minimax-m2.7"
|
||||
assert mock_rec.call_args.kwargs["vision"] is False
|
||||
|
||||
def test_try_nous_uses_portal_recommendation_for_vision(self):
|
||||
"""Vision tasks should ask for the vision-specific recommendation."""
|
||||
fresh_base = "https://inference-api.nousresearch.com/v1"
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value={"access_token": "***"}),
|
||||
patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", fresh_base)),
|
||||
patch("hermes_cli.models.get_nous_recommended_aux_model", return_value="google/gemini-3-flash-preview") as mock_rec,
|
||||
patch("agent.auxiliary_client.OpenAI"),
|
||||
):
|
||||
from agent.auxiliary_client import _try_nous
|
||||
client, model = _try_nous(vision=True)
|
||||
|
||||
assert client is not None
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert mock_rec.call_args.kwargs["vision"] is True
|
||||
|
||||
def test_try_nous_falls_back_when_recommendation_lookup_raises(self):
|
||||
"""If the Portal lookup throws, we must still return a usable model."""
|
||||
fresh_base = "https://inference-api.nousresearch.com/v1"
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value={"access_token": "***"}),
|
||||
patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", fresh_base)),
|
||||
patch("hermes_cli.models.get_nous_recommended_aux_model", side_effect=RuntimeError("portal down")),
|
||||
patch("agent.auxiliary_client.OpenAI"),
|
||||
):
|
||||
from agent.auxiliary_client import _try_nous
|
||||
client, model = _try_nous()
|
||||
|
||||
assert client is not None
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_call_llm_retries_nous_after_401(self):
|
||||
class _Auth401(Exception):
|
||||
status_code = 401
|
||||
|
||||
stale_client = MagicMock()
|
||||
stale_client.base_url = "https://inference-api.nousresearch.com/v1"
|
||||
stale_client.chat.completions.create.side_effect = _Auth401("stale nous key")
|
||||
|
||||
fresh_client = MagicMock()
|
||||
fresh_client.base_url = "https://inference-api.nousresearch.com/v1"
|
||||
fresh_client.chat.completions.create.return_value = {"ok": True}
|
||||
|
||||
with (
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model", return_value=("nous", "nous-model", None, None, None)),
|
||||
patch("agent.auxiliary_client._get_cached_client", return_value=(stale_client, "nous-model")),
|
||||
patch("agent.auxiliary_client.OpenAI", return_value=fresh_client),
|
||||
patch("agent.auxiliary_client._validate_llm_response", side_effect=lambda resp, _task: resp),
|
||||
patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", "https://inference-api.nousresearch.com/v1")),
|
||||
):
|
||||
result = call_llm(
|
||||
task="compression",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
assert stale_client.chat.completions.create.call_count == 1
|
||||
assert fresh_client.chat.completions.create.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_call_llm_retries_nous_after_401(self):
|
||||
class _Auth401(Exception):
|
||||
status_code = 401
|
||||
|
||||
stale_client = MagicMock()
|
||||
stale_client.base_url = "https://inference-api.nousresearch.com/v1"
|
||||
stale_client.chat.completions.create = AsyncMock(side_effect=_Auth401("stale nous key"))
|
||||
|
||||
fresh_async_client = MagicMock()
|
||||
fresh_async_client.base_url = "https://inference-api.nousresearch.com/v1"
|
||||
fresh_async_client.chat.completions.create = AsyncMock(return_value={"ok": True})
|
||||
|
||||
with (
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model", return_value=("nous", "nous-model", None, None, None)),
|
||||
patch("agent.auxiliary_client._get_cached_client", return_value=(stale_client, "nous-model")),
|
||||
patch("agent.auxiliary_client._to_async_client", return_value=(fresh_async_client, "nous-model")),
|
||||
patch("agent.auxiliary_client._validate_llm_response", side_effect=lambda resp, _task: resp),
|
||||
patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", "https://inference-api.nousresearch.com/v1")),
|
||||
):
|
||||
result = await async_call_llm(
|
||||
task="session_search",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
assert stale_client.chat.completions.create.await_count == 1
|
||||
assert fresh_async_client.chat.completions.create.await_count == 1
|
||||
|
||||
# ── Payment / credit exhaustion fallback ─────────────────────────────────
|
||||
|
||||
|
||||
|
|
@ -696,27 +823,46 @@ class TestIsConnectionError:
|
|||
assert _is_connection_error(err) is False
|
||||
|
||||
|
||||
class TestKimiForCodingTemperature:
|
||||
"""Moonshot kimi-for-coding models require fixed temperatures.
|
||||
class TestKimiTemperatureOmitted:
|
||||
"""Kimi/Moonshot models should have temperature OMITTED from API kwargs.
|
||||
|
||||
k2.5 / k2-turbo-preview / k2-0905-preview → 0.6 (non-thinking lock).
|
||||
k2-thinking / k2-thinking-turbo → 1.0 (thinking lock).
|
||||
kimi-k2-instruct* and every other model preserve the caller's temperature.
|
||||
The Kimi gateway selects the correct temperature server-side based on the
|
||||
active mode (thinking → 1.0, non-thinking → 0.6). Sending any temperature
|
||||
value conflicts with gateway-managed defaults.
|
||||
"""
|
||||
|
||||
def test_build_call_kwargs_forces_fixed_temperature(self):
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"kimi-for-coding",
|
||||
"kimi-k2.5",
|
||||
"kimi-k2.6",
|
||||
"kimi-k2-turbo-preview",
|
||||
"kimi-k2-0905-preview",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2-thinking-turbo",
|
||||
"kimi-k2-instruct",
|
||||
"kimi-k2-instruct-0905",
|
||||
"moonshotai/kimi-k2.5",
|
||||
"moonshotai/Kimi-K2-Thinking",
|
||||
"moonshotai/Kimi-K2-Instruct",
|
||||
],
|
||||
)
|
||||
def test_kimi_models_omit_temperature(self, model):
|
||||
"""No kimi model should have a temperature key in kwargs."""
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
provider="kimi-coding",
|
||||
model="kimi-for-coding",
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
assert kwargs["temperature"] == 0.6
|
||||
assert "temperature" not in kwargs
|
||||
|
||||
def test_build_call_kwargs_injects_temperature_when_missing(self):
|
||||
def test_kimi_for_coding_no_temperature_when_none(self):
|
||||
"""When caller passes temperature=None, still no temperature key."""
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
|
|
@ -726,9 +872,9 @@ class TestKimiForCodingTemperature:
|
|||
temperature=None,
|
||||
)
|
||||
|
||||
assert kwargs["temperature"] == 0.6
|
||||
assert "temperature" not in kwargs
|
||||
|
||||
def test_auto_routed_kimi_for_coding_sync_call_uses_fixed_temperature(self):
|
||||
def test_sync_call_omits_temperature(self):
|
||||
client = MagicMock()
|
||||
client.base_url = "https://api.kimi.com/coding/v1"
|
||||
response = MagicMock()
|
||||
|
|
@ -750,10 +896,10 @@ class TestKimiForCodingTemperature:
|
|||
assert result is response
|
||||
kwargs = client.chat.completions.create.call_args.kwargs
|
||||
assert kwargs["model"] == "kimi-for-coding"
|
||||
assert kwargs["temperature"] == 0.6
|
||||
assert "temperature" not in kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_routed_kimi_for_coding_async_call_uses_fixed_temperature(self):
|
||||
async def test_async_call_omits_temperature(self):
|
||||
client = MagicMock()
|
||||
client.base_url = "https://api.kimi.com/coding/v1"
|
||||
response = MagicMock()
|
||||
|
|
@ -775,52 +921,17 @@ class TestKimiForCodingTemperature:
|
|||
assert result is response
|
||||
kwargs = client.chat.completions.create.call_args.kwargs
|
||||
assert kwargs["model"] == "kimi-for-coding"
|
||||
assert kwargs["temperature"] == 0.6
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,expected",
|
||||
[
|
||||
("kimi-k2.5", 0.6),
|
||||
("kimi-k2-turbo-preview", 0.6),
|
||||
("kimi-k2-0905-preview", 0.6),
|
||||
("kimi-k2-thinking", 1.0),
|
||||
("kimi-k2-thinking-turbo", 1.0),
|
||||
("moonshotai/kimi-k2.5", 0.6),
|
||||
("moonshotai/Kimi-K2-Thinking", 1.0),
|
||||
],
|
||||
)
|
||||
def test_kimi_k2_family_temperature_override(self, model, expected):
|
||||
"""Moonshot kimi-k2.* models only accept fixed temperatures.
|
||||
|
||||
Non-thinking models → 0.6, thinking-mode models → 1.0.
|
||||
"""
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
provider="kimi-coding",
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
assert kwargs["temperature"] == expected
|
||||
assert "temperature" not in kwargs
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"anthropic/claude-sonnet-4-6",
|
||||
"gpt-5.4",
|
||||
# kimi-k2-instruct is the non-coding K2 family — temperature is
|
||||
# variable (recommended 0.6 but not enforced). Must not clamp.
|
||||
"kimi-k2-instruct",
|
||||
"moonshotai/Kimi-K2-Instruct",
|
||||
"moonshotai/Kimi-K2-Instruct-0905",
|
||||
"kimi-k2-instruct-0905",
|
||||
# Hypothetical future kimi name not in the whitelist.
|
||||
"kimi-k2-experimental",
|
||||
"deepseek-chat",
|
||||
],
|
||||
)
|
||||
def test_non_restricted_model_preserves_temperature(self, model):
|
||||
def test_non_kimi_models_preserve_temperature(self, model):
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
|
|
@ -832,6 +943,28 @@ class TestKimiForCodingTemperature:
|
|||
|
||||
assert kwargs["temperature"] == 0.3
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"base_url",
|
||||
[
|
||||
"https://api.moonshot.ai/v1",
|
||||
"https://api.moonshot.cn/v1",
|
||||
"https://api.kimi.com/coding/v1",
|
||||
],
|
||||
)
|
||||
def test_kimi_k2_5_omits_temperature_regardless_of_endpoint(self, base_url):
|
||||
"""Temperature is omitted regardless of which Kimi endpoint is used."""
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
provider="kimi-coding",
|
||||
model="kimi-k2.5",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.1,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
assert "temperature" not in kwargs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# async_call_llm payment / connection fallback (#7512 bug 2)
|
||||
|
|
@ -858,6 +991,70 @@ class TestStaleBaseUrlWarning:
|
|||
"Expected a warning about stale OPENAI_BASE_URL"
|
||||
assert mod._stale_base_url_warned is True
|
||||
|
||||
|
||||
class TestAuxiliaryTaskExtraBody:
|
||||
def test_sync_call_merges_task_extra_body_from_config(self):
|
||||
client = MagicMock()
|
||||
client.base_url = "https://api.example.com/v1"
|
||||
response = MagicMock()
|
||||
client.chat.completions.create.return_value = response
|
||||
|
||||
config = {
|
||||
"auxiliary": {
|
||||
"session_search": {
|
||||
"extra_body": {
|
||||
"enable_thinking": False,
|
||||
"reasoning": {"effort": "none"},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
with patch("hermes_cli.config.load_config", return_value=config), patch(
|
||||
"agent.auxiliary_client._get_cached_client",
|
||||
return_value=(client, "glm-4.5-air"),
|
||||
):
|
||||
result = call_llm(
|
||||
task="session_search",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
extra_body={"metadata": {"source": "test"}},
|
||||
)
|
||||
|
||||
assert result is response
|
||||
kwargs = client.chat.completions.create.call_args.kwargs
|
||||
assert kwargs["extra_body"]["enable_thinking"] is False
|
||||
assert kwargs["extra_body"]["reasoning"] == {"effort": "none"}
|
||||
assert kwargs["extra_body"]["metadata"] == {"source": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_call_explicit_extra_body_overrides_task_config(self):
|
||||
client = MagicMock()
|
||||
client.base_url = "https://api.example.com/v1"
|
||||
response = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(return_value=response)
|
||||
|
||||
config = {
|
||||
"auxiliary": {
|
||||
"session_search": {
|
||||
"extra_body": {"enable_thinking": False}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
with patch("hermes_cli.config.load_config", return_value=config), patch(
|
||||
"agent.auxiliary_client._get_cached_client",
|
||||
return_value=(client, "glm-4.5-air"),
|
||||
):
|
||||
result = await async_call_llm(
|
||||
task="session_search",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
extra_body={"enable_thinking": True},
|
||||
)
|
||||
|
||||
assert result is response
|
||||
kwargs = client.chat.completions.create.call_args.kwargs
|
||||
assert kwargs["extra_body"]["enable_thinking"] is True
|
||||
|
||||
def test_no_warning_when_provider_is_custom(self, monkeypatch, caplog):
|
||||
"""No warning when the provider is 'custom' — OPENAI_BASE_URL is expected."""
|
||||
import agent.auxiliary_client as mod
|
||||
|
|
|
|||
107
tests/agent/test_auxiliary_client_anthropic_custom.py
Normal file
107
tests/agent/test_auxiliary_client_anthropic_custom.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
"""Tests for agent.auxiliary_client._try_custom_endpoint's anthropic_messages branch.
|
||||
|
||||
When a user configures a custom endpoint with ``api_mode: anthropic_messages``
|
||||
(e.g. MiniMax, Zhipu GLM, LiteLLM in Anthropic-proxy mode), auxiliary tasks
|
||||
(compression, web_extract, session_search, title generation) must use the
|
||||
native Anthropic transport rather than being silently downgraded to an
|
||||
OpenAI-wire client that speaks the wrong protocol.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_env(monkeypatch):
|
||||
for key in (
|
||||
"OPENAI_API_KEY", "OPENAI_BASE_URL",
|
||||
"ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
|
||||
def _install_anthropic_adapter_mocks():
|
||||
"""Patch build_anthropic_client so the test doesn't need the SDK."""
|
||||
fake_client = MagicMock(name="anthropic_client")
|
||||
return patch(
|
||||
"agent.anthropic_adapter.build_anthropic_client",
|
||||
return_value=fake_client,
|
||||
), fake_client
|
||||
|
||||
|
||||
def test_custom_endpoint_anthropic_messages_builds_anthropic_wrapper():
|
||||
"""api_mode=anthropic_messages → returns AnthropicAuxiliaryClient, not OpenAI."""
|
||||
from agent.auxiliary_client import _try_custom_endpoint, AnthropicAuxiliaryClient
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._resolve_custom_runtime",
|
||||
return_value=(
|
||||
"https://api.minimax.io/anthropic",
|
||||
"minimax-key",
|
||||
"anthropic_messages",
|
||||
),
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="claude-sonnet-4-6",
|
||||
):
|
||||
adapter_patch, fake_client = _install_anthropic_adapter_mocks()
|
||||
with adapter_patch:
|
||||
client, model = _try_custom_endpoint()
|
||||
|
||||
assert isinstance(client, AnthropicAuxiliaryClient), (
|
||||
"Custom endpoint with api_mode=anthropic_messages must return the "
|
||||
f"native Anthropic wrapper, got {type(client).__name__}"
|
||||
)
|
||||
assert model == "claude-sonnet-4-6"
|
||||
# Wrapper should NOT be marked as OAuth — third-party endpoints are
|
||||
# always API-key authenticated.
|
||||
assert client.api_key == "minimax-key"
|
||||
assert client.base_url == "https://api.minimax.io/anthropic"
|
||||
|
||||
|
||||
def test_custom_endpoint_anthropic_messages_falls_back_when_sdk_missing():
|
||||
"""Graceful degradation when anthropic SDK is unavailable."""
|
||||
from agent.auxiliary_client import _try_custom_endpoint
|
||||
|
||||
import_error = ImportError("anthropic package not installed")
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._resolve_custom_runtime",
|
||||
return_value=("https://api.minimax.io/anthropic", "k", "anthropic_messages"),
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="claude-sonnet-4-6",
|
||||
), patch(
|
||||
"agent.anthropic_adapter.build_anthropic_client",
|
||||
side_effect=import_error,
|
||||
):
|
||||
client, model = _try_custom_endpoint()
|
||||
|
||||
# Should fall back to an OpenAI-wire client rather than returning
|
||||
# (None, None) — the tool still needs to do *something*.
|
||||
assert client is not None
|
||||
assert model == "claude-sonnet-4-6"
|
||||
# OpenAI client, not AnthropicAuxiliaryClient.
|
||||
from agent.auxiliary_client import AnthropicAuxiliaryClient
|
||||
assert not isinstance(client, AnthropicAuxiliaryClient)
|
||||
|
||||
|
||||
def test_custom_endpoint_chat_completions_still_uses_openai_wire():
|
||||
"""Regression: default path (no api_mode) must remain OpenAI client."""
|
||||
from agent.auxiliary_client import _try_custom_endpoint, AnthropicAuxiliaryClient
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._resolve_custom_runtime",
|
||||
return_value=("https://api.example.com/v1", "key", None),
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="my-model",
|
||||
):
|
||||
client, model = _try_custom_endpoint()
|
||||
|
||||
assert client is not None
|
||||
assert model == "my-model"
|
||||
assert not isinstance(client, AnthropicAuxiliaryClient)
|
||||
|
|
@ -167,7 +167,7 @@ class TestResolveAutoMainFirst:
|
|||
|
||||
|
||||
class TestResolveVisionMainFirst:
|
||||
"""Vision auto-detection prefers main provider + main model first."""
|
||||
"""Vision auto-detection prefers the main provider first."""
|
||||
|
||||
def test_openrouter_main_vision_uses_main_model(self, monkeypatch):
|
||||
"""OpenRouter main with vision-capable model → aux vision uses main model."""
|
||||
|
|
@ -200,28 +200,49 @@ class TestResolveVisionMainFirst:
|
|||
assert mock_resolve.call_args.args[0] == "openrouter"
|
||||
assert mock_resolve.call_args.args[1] == "anthropic/claude-sonnet-4.6"
|
||||
|
||||
def test_nous_main_vision_uses_main_model(self):
|
||||
"""Nous Portal main → aux vision uses main model, not free-tier MiMo-V2-Omni."""
|
||||
def test_nous_main_vision_uses_paid_nous_vision_backend(self):
|
||||
"""Paid Nous main → aux vision uses the dedicated Nous vision backend."""
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="nous",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="openai/gpt-5",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve, patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_strict_vision_backend",
|
||||
return_value=(MagicMock(), "google/gemini-3-flash-preview"),
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "openai/gpt-5")
|
||||
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "nous"
|
||||
assert model == "openai/gpt-5"
|
||||
assert client is not None
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_nous_main_vision_uses_free_tier_nous_vision_backend(self):
|
||||
"""Free-tier Nous main → aux vision uses MiMo omni, not the text main model."""
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="nous",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="xiaomi/mimo-v2-pro",
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_strict_vision_backend",
|
||||
return_value=(MagicMock(), "xiaomi/mimo-v2-omni"),
|
||||
):
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "nous"
|
||||
assert client is not None
|
||||
assert model == "xiaomi/mimo-v2-omni"
|
||||
|
||||
def test_exotic_provider_with_vision_override_preserved(self):
|
||||
"""xiaomi → mimo-v2-omni override still wins over main_model."""
|
||||
|
|
|
|||
|
|
@ -267,3 +267,174 @@ class TestPackaging:
|
|||
from pathlib import Path
|
||||
content = (Path(__file__).parent.parent.parent / "pyproject.toml").read_text()
|
||||
assert '"hermes-agent[bedrock]"' in content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model ID dot preservation — regression for #11976
|
||||
# ---------------------------------------------------------------------------
|
||||
# AWS Bedrock inference-profile model IDs embed structural dots:
|
||||
#
|
||||
# global.anthropic.claude-opus-4-7
|
||||
# us.anthropic.claude-sonnet-4-5-20250929-v1:0
|
||||
# apac.anthropic.claude-haiku-4-5
|
||||
#
|
||||
# ``agent.anthropic_adapter.normalize_model_name`` converts dots to hyphens
|
||||
# unless the caller opts in via ``preserve_dots=True``. Before this fix,
|
||||
# ``AIAgent._anthropic_preserve_dots`` returned False for the ``bedrock``
|
||||
# provider, so Claude-on-Bedrock requests went out with
|
||||
# ``global-anthropic-claude-opus-4-7`` (all dots mangled to hyphens) and
|
||||
# Bedrock rejected them with:
|
||||
#
|
||||
# HTTP 400: The provided model identifier is invalid.
|
||||
#
|
||||
# The fix adds ``bedrock`` to the preserve-dots provider allowlist and
|
||||
# ``bedrock-runtime.`` to the base-URL heuristic, mirroring the shape of
|
||||
# the opencode-go fix for #5211 (commit f77be22c), which extended this
|
||||
# same allowlist.
|
||||
|
||||
|
||||
class TestBedrockPreserveDotsFlag:
|
||||
"""``AIAgent._anthropic_preserve_dots`` must return True on Bedrock so
|
||||
inference-profile IDs survive the normalize step intact."""
|
||||
|
||||
def test_bedrock_provider_preserves_dots(self):
|
||||
from types import SimpleNamespace
|
||||
agent = SimpleNamespace(provider="bedrock", base_url="")
|
||||
from run_agent import AIAgent
|
||||
assert AIAgent._anthropic_preserve_dots(agent) is True
|
||||
|
||||
def test_bedrock_runtime_us_east_1_url_preserves_dots(self):
|
||||
"""Defense-in-depth: even without an explicit ``provider="bedrock"``,
|
||||
a ``bedrock-runtime.us-east-1.amazonaws.com`` base URL must not
|
||||
mangle dots."""
|
||||
from types import SimpleNamespace
|
||||
agent = SimpleNamespace(
|
||||
provider="custom",
|
||||
base_url="https://bedrock-runtime.us-east-1.amazonaws.com",
|
||||
)
|
||||
from run_agent import AIAgent
|
||||
assert AIAgent._anthropic_preserve_dots(agent) is True
|
||||
|
||||
def test_bedrock_runtime_ap_northeast_2_url_preserves_dots(self):
|
||||
"""Reporter-reported region (ap-northeast-2) exercises the same
|
||||
base-URL heuristic."""
|
||||
from types import SimpleNamespace
|
||||
agent = SimpleNamespace(
|
||||
provider="custom",
|
||||
base_url="https://bedrock-runtime.ap-northeast-2.amazonaws.com",
|
||||
)
|
||||
from run_agent import AIAgent
|
||||
assert AIAgent._anthropic_preserve_dots(agent) is True
|
||||
|
||||
def test_non_bedrock_aws_url_does_not_preserve_dots(self):
|
||||
"""Unrelated AWS endpoints (e.g. ``s3.us-east-1.amazonaws.com``)
|
||||
must not accidentally activate the dot-preservation heuristic —
|
||||
the heuristic is scoped to the ``bedrock-runtime.`` substring
|
||||
specifically."""
|
||||
from types import SimpleNamespace
|
||||
agent = SimpleNamespace(
|
||||
provider="custom",
|
||||
base_url="https://s3.us-east-1.amazonaws.com",
|
||||
)
|
||||
from run_agent import AIAgent
|
||||
assert AIAgent._anthropic_preserve_dots(agent) is False
|
||||
|
||||
def test_anthropic_native_still_does_not_preserve_dots(self):
|
||||
"""Canary: adding Bedrock to the allowlist must not weaken the
|
||||
existing Anthropic native behaviour — ``claude-sonnet-4.6`` still
|
||||
becomes ``claude-sonnet-4-6`` for the Anthropic API."""
|
||||
from types import SimpleNamespace
|
||||
agent = SimpleNamespace(provider="anthropic", base_url="https://api.anthropic.com")
|
||||
from run_agent import AIAgent
|
||||
assert AIAgent._anthropic_preserve_dots(agent) is False
|
||||
|
||||
|
||||
class TestBedrockModelNameNormalization:
|
||||
"""End-to-end: ``normalize_model_name`` + the preserve-dots flag
|
||||
reproduce the exact production request shape for each Bedrock model
|
||||
family, confirming the fix resolves the reporter's HTTP 400."""
|
||||
|
||||
def test_global_anthropic_inference_profile_preserved(self):
|
||||
"""The reporter's exact model ID."""
|
||||
from agent.anthropic_adapter import normalize_model_name
|
||||
assert normalize_model_name(
|
||||
"global.anthropic.claude-opus-4-7", preserve_dots=True
|
||||
) == "global.anthropic.claude-opus-4-7"
|
||||
|
||||
def test_us_anthropic_dated_inference_profile_preserved(self):
|
||||
"""Regional + dated Sonnet inference profile."""
|
||||
from agent.anthropic_adapter import normalize_model_name
|
||||
assert normalize_model_name(
|
||||
"us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
preserve_dots=True,
|
||||
) == "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
|
||||
|
||||
def test_apac_anthropic_haiku_inference_profile_preserved(self):
|
||||
"""APAC inference profile — same structural-dot shape."""
|
||||
from agent.anthropic_adapter import normalize_model_name
|
||||
assert normalize_model_name(
|
||||
"apac.anthropic.claude-haiku-4-5", preserve_dots=True
|
||||
) == "apac.anthropic.claude-haiku-4-5"
|
||||
|
||||
def test_preserve_false_mangles_as_documented(self):
|
||||
"""Canary: with ``preserve_dots=False`` the function still
|
||||
produces the broken all-hyphen form — this is the shape that
|
||||
Bedrock rejected and that the fix avoids. Keeping this test
|
||||
locks in the existing behaviour of ``normalize_model_name`` so a
|
||||
future refactor doesn't accidentally decouple the knob from its
|
||||
effect."""
|
||||
from agent.anthropic_adapter import normalize_model_name
|
||||
assert normalize_model_name(
|
||||
"global.anthropic.claude-opus-4-7", preserve_dots=False
|
||||
) == "global-anthropic-claude-opus-4-7"
|
||||
|
||||
def test_bare_foundation_model_id_preserved(self):
|
||||
"""Non-inference-profile Bedrock IDs
|
||||
(e.g. ``anthropic.claude-3-5-sonnet-20241022-v2:0``) use dots as
|
||||
vendor separators and must also survive intact under
|
||||
``preserve_dots=True``."""
|
||||
from agent.anthropic_adapter import normalize_model_name
|
||||
assert normalize_model_name(
|
||||
"anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
preserve_dots=True,
|
||||
) == "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
|
||||
|
||||
class TestBedrockBuildAnthropicKwargsEndToEnd:
|
||||
"""Integration: calling ``build_anthropic_kwargs`` with a Bedrock-
|
||||
shaped model ID and ``preserve_dots=True`` produces the unmangled
|
||||
model string in the outgoing kwargs — the exact body sent to the
|
||||
``bedrock-runtime.`` endpoint. This is the integration-level
|
||||
regression for the reporter's HTTP 400."""
|
||||
|
||||
def test_bedrock_inference_profile_survives_build_kwargs(self):
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="global.anthropic.claude-opus-4-7",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=None,
|
||||
max_tokens=1024,
|
||||
reasoning_config=None,
|
||||
preserve_dots=True,
|
||||
)
|
||||
assert kwargs["model"] == "global.anthropic.claude-opus-4-7", (
|
||||
"Bedrock inference-profile ID was mangled in build_anthropic_kwargs: "
|
||||
f"{kwargs['model']!r}"
|
||||
)
|
||||
|
||||
def test_bedrock_model_mangled_without_preserve_dots(self):
|
||||
"""Inverse canary: without the flag, ``build_anthropic_kwargs``
|
||||
still produces the broken form — so the fix in
|
||||
``_anthropic_preserve_dots`` is the load-bearing piece that
|
||||
wires ``preserve_dots=True`` through to this builder for the
|
||||
Bedrock case."""
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="global.anthropic.claude-opus-4-7",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=None,
|
||||
max_tokens=1024,
|
||||
reasoning_config=None,
|
||||
preserve_dots=False,
|
||||
)
|
||||
assert kwargs["model"] == "global-anthropic-claude-opus-4-7"
|
||||
|
|
|
|||
253
tests/agent/test_codex_cloudflare_headers.py
Normal file
253
tests/agent/test_codex_cloudflare_headers.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
"""Regression guard: Codex Cloudflare 403 mitigation headers.
|
||||
|
||||
The ``chatgpt.com/backend-api/codex`` endpoint sits behind a Cloudflare layer
|
||||
that whitelists a small set of first-party originators (``codex_cli_rs``,
|
||||
``codex_vscode``, ``codex_sdk_ts``, ``Codex*``). Requests from non-residential
|
||||
IPs (VPS, always-on servers, some corporate egress) that don't advertise an
|
||||
allowed originator are served 403 with ``cf-mitigated: challenge`` regardless
|
||||
of auth correctness.
|
||||
|
||||
``_codex_cloudflare_headers`` in ``agent.auxiliary_client`` centralizes the
|
||||
header set so the primary chat client (``run_agent.AIAgent.__init__`` +
|
||||
``_apply_client_headers_for_base_url``) and the auxiliary client paths
|
||||
(``_try_codex`` and the ``raw_codex`` branch of ``resolve_provider_client``)
|
||||
all emit the same headers.
|
||||
|
||||
These tests pin:
|
||||
- the originator value (must be ``codex_cli_rs`` — the whitelisted one)
|
||||
- the User-Agent shape (codex_cli_rs-prefixed)
|
||||
- ``ChatGPT-Account-ID`` extraction from the OAuth JWT (canonical casing,
|
||||
from codex-rs ``auth.rs``)
|
||||
- graceful handling of malformed tokens (drop the account-ID header, don't
|
||||
raise)
|
||||
- primary-client wiring at both entry points in ``run_agent.py``
|
||||
- aux-client wiring at both entry points in ``agent/auxiliary_client.py``
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_codex_jwt(account_id: str = "acct-test-123") -> str:
|
||||
"""Build a syntactically valid Codex-style JWT with the account_id claim."""
|
||||
def b64url(data: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode()
|
||||
header = b64url(b'{"alg":"RS256","typ":"JWT"}')
|
||||
claims = {
|
||||
"sub": "user-xyz",
|
||||
"exp": 9999999999,
|
||||
"https://api.openai.com/auth": {
|
||||
"chatgpt_account_id": account_id,
|
||||
"chatgpt_plan_type": "plus",
|
||||
},
|
||||
}
|
||||
payload = b64url(json.dumps(claims).encode())
|
||||
sig = b64url(b"fake-sig")
|
||||
return f"{header}.{payload}.{sig}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _codex_cloudflare_headers — the shared helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCodexCloudflareHeaders:
    """Contract tests for the shared ``_codex_cloudflare_headers`` helper."""

    def test_originator_is_codex_cli_rs(self):
        """Cloudflare whitelists codex_cli_rs — any other value is 403'd."""
        from agent.auxiliary_client import _codex_cloudflare_headers

        hdrs = _codex_cloudflare_headers(_make_codex_jwt())
        assert hdrs["originator"] == "codex_cli_rs"

    def test_user_agent_advertises_codex_cli_rs(self):
        from agent.auxiliary_client import _codex_cloudflare_headers

        hdrs = _codex_cloudflare_headers(_make_codex_jwt())
        assert hdrs["User-Agent"].startswith("codex_cli_rs/")

    def test_account_id_extracted_from_jwt(self):
        from agent.auxiliary_client import _codex_cloudflare_headers

        hdrs = _codex_cloudflare_headers(_make_codex_jwt("acct-abc-999"))
        # Canonical casing — matches codex-rs auth.rs
        assert hdrs["ChatGPT-Account-ID"] == "acct-abc-999"

    def test_canonical_header_casing(self):
        """Upstream codex-rs uses PascalCase with trailing -ID. Match exactly."""
        from agent.auxiliary_client import _codex_cloudflare_headers

        hdrs = _codex_cloudflare_headers(_make_codex_jwt())
        assert "ChatGPT-Account-ID" in hdrs
        # The lowercase/titlecase variants MUST NOT be used — pin to be explicit
        assert "chatgpt-account-id" not in hdrs
        assert "ChatGPT-Account-Id" not in hdrs

    def test_malformed_token_drops_account_id_without_raising(self):
        from agent.auxiliary_client import _codex_cloudflare_headers

        for malformed in ["not-a-jwt", "", "only.one", " ", "...."]:
            hdrs = _codex_cloudflare_headers(malformed)
            # Still returns base headers — never raises
            assert hdrs["originator"] == "codex_cli_rs"
            assert "ChatGPT-Account-ID" not in hdrs

    def test_non_string_token_handled(self):
        from agent.auxiliary_client import _codex_cloudflare_headers

        hdrs = _codex_cloudflare_headers(None)  # type: ignore[arg-type]
        assert hdrs["originator"] == "codex_cli_rs"
        assert "ChatGPT-Account-ID" not in hdrs

    def test_jwt_without_chatgpt_account_id_claim(self):
        """A valid JWT that lacks the account_id claim should still return headers."""
        from agent.auxiliary_client import _codex_cloudflare_headers

        def encode(raw: bytes) -> str:
            return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()

        body = encode(json.dumps({"sub": "user-xyz", "exp": 9999999999}).encode())
        token = f"{encode(b'{}')}.{body}.{encode(b'sig')}"
        hdrs = _codex_cloudflare_headers(token)
        assert hdrs["originator"] == "codex_cli_rs"
        assert "ChatGPT-Account-ID" not in hdrs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Primary chat client wiring (run_agent.AIAgent)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPrimaryClientWiring:
    """The primary chat client must carry codex headers iff pointed at chatgpt.com."""

    @staticmethod
    def _quiet_kwargs():
        # Keep AIAgent construction free of disk/terminal side effects.
        return {"quiet_mode": True, "skip_context_files": True, "skip_memory": True}

    def test_init_wires_codex_headers_for_chatgpt_base_url(self):
        from run_agent import AIAgent

        token = _make_codex_jwt("acct-primary-init")
        with patch("run_agent.OpenAI") as mock_openai:
            mock_openai.return_value = MagicMock()
            AIAgent(
                api_key=token,
                base_url="https://chatgpt.com/backend-api/codex",
                provider="openai-codex",
                model="gpt-5.4",
                **self._quiet_kwargs(),
            )
            hdrs = mock_openai.call_args.kwargs.get("default_headers") or {}
            assert hdrs.get("originator") == "codex_cli_rs"
            assert hdrs.get("ChatGPT-Account-ID") == "acct-primary-init"
            assert hdrs.get("User-Agent", "").startswith("codex_cli_rs/")

    def test_apply_client_headers_on_base_url_change(self):
        """Credential-rotation / base-url change path must also emit codex headers."""
        from run_agent import AIAgent

        token = _make_codex_jwt("acct-rotation")
        with patch("run_agent.OpenAI") as mock_openai:
            mock_openai.return_value = MagicMock()
            agent = AIAgent(
                api_key="placeholder-openrouter-key",
                base_url="https://openrouter.ai/api/v1",
                provider="openrouter",
                model="anthropic/claude-sonnet-4.6",
                **self._quiet_kwargs(),
            )
            # Simulate rotation into a Codex credential
            agent._client_kwargs["api_key"] = token
            agent._apply_client_headers_for_base_url(
                "https://chatgpt.com/backend-api/codex"
            )
            hdrs = agent._client_kwargs.get("default_headers") or {}
            assert hdrs.get("originator") == "codex_cli_rs"
            assert hdrs.get("ChatGPT-Account-ID") == "acct-rotation"
            assert hdrs.get("User-Agent", "").startswith("codex_cli_rs/")

    def test_apply_client_headers_clears_codex_headers_off_chatgpt(self):
        """Switching AWAY from chatgpt.com must drop the codex headers."""
        from run_agent import AIAgent

        token = _make_codex_jwt()
        with patch("run_agent.OpenAI") as mock_openai:
            mock_openai.return_value = MagicMock()
            agent = AIAgent(
                api_key=token,
                base_url="https://chatgpt.com/backend-api/codex",
                provider="openai-codex",
                model="gpt-5.4",
                **self._quiet_kwargs(),
            )
            # Sanity: headers are set initially
            assert "originator" in (agent._client_kwargs.get("default_headers") or {})
            agent._apply_client_headers_for_base_url(
                "https://api.anthropic.com"
            )
            # default_headers should be popped for anthropic base
            assert "default_headers" not in agent._client_kwargs

    def test_openrouter_base_url_does_not_get_codex_headers(self):
        from run_agent import AIAgent

        with patch("run_agent.OpenAI") as mock_openai:
            mock_openai.return_value = MagicMock()
            AIAgent(
                api_key="sk-or-test",
                base_url="https://openrouter.ai/api/v1",
                provider="openrouter",
                model="anthropic/claude-sonnet-4.6",
                **self._quiet_kwargs(),
            )
            hdrs = mock_openai.call_args.kwargs.get("default_headers") or {}
            assert hdrs.get("originator") != "codex_cli_rs"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auxiliary client wiring (agent.auxiliary_client)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAuxiliaryClientWiring:
    """Auxiliary clients (compression / vision / titles) must also send codex headers."""

    def test_try_codex_passes_codex_headers(self, monkeypatch):
        """_try_codex builds the OpenAI client used for compression / vision /
        title generation when routed through Codex. Must emit codex headers."""
        from agent import auxiliary_client

        token = _make_codex_jwt("acct-aux-try-codex")
        # Force _select_pool_entry to return "no pool" so we fall through to
        # _read_codex_access_token.
        monkeypatch.setattr(
            auxiliary_client, "_select_pool_entry", lambda provider: (False, None)
        )
        monkeypatch.setattr(
            auxiliary_client, "_read_codex_access_token", lambda: token
        )
        with patch("agent.auxiliary_client.OpenAI") as mock_openai:
            mock_openai.return_value = MagicMock()
            client, model = auxiliary_client._try_codex()
            assert client is not None
            hdrs = mock_openai.call_args.kwargs.get("default_headers") or {}
            assert hdrs.get("originator") == "codex_cli_rs"
            assert hdrs.get("ChatGPT-Account-ID") == "acct-aux-try-codex"
            assert hdrs.get("User-Agent", "").startswith("codex_cli_rs/")

    def test_resolve_provider_client_raw_codex_passes_codex_headers(self, monkeypatch):
        """The ``raw_codex=True`` branch (used by the main agent loop for direct
        responses.stream() access) must also emit codex headers."""
        from agent import auxiliary_client

        token = _make_codex_jwt("acct-aux-raw-codex")
        monkeypatch.setattr(
            auxiliary_client, "_read_codex_access_token", lambda: token
        )
        with patch("agent.auxiliary_client.OpenAI") as mock_openai:
            mock_openai.return_value = MagicMock()
            client, model = auxiliary_client.resolve_provider_client(
                "openai-codex", raw_codex=True
            )
            assert client is not None
            hdrs = mock_openai.call_args.kwargs.get("default_headers") or {}
            assert hdrs.get("originator") == "codex_cli_rs"
            assert hdrs.get("ChatGPT-Account-ID") == "acct-aux-raw-codex"
            assert hdrs.get("User-Agent", "").startswith("codex_cli_rs/")
||||
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -124,6 +125,31 @@ def test_expand_file_range_and_folder_listing(sample_repo: Path):
|
|||
assert not result.warnings
|
||||
|
||||
|
||||
def test_folder_listing_falls_back_when_rg_is_blocked(sample_repo: Path):
    """@folder expansion must degrade gracefully when ripgrep cannot be spawned."""
    from agent.context_references import preprocess_context_references

    passthrough = subprocess.run

    def deny_rg(*args, **kwargs):
        # Block only the `rg` invocation; everything else runs for real.
        argv = args[0] if args else kwargs.get("args")
        if isinstance(argv, list) and argv and argv[0] == "rg":
            raise PermissionError("rg blocked by policy")
        return passthrough(*args, **kwargs)

    with patch("agent.context_references.subprocess.run", side_effect=deny_rg):
        result = preprocess_context_references(
            "Review @folder:src/",
            cwd=sample_repo,
            context_length=100_000,
        )

    assert result.expanded
    for expected in ("src/", "main.py", "helper.py"):
        assert expected in result.message
    assert not result.warnings
|
||||
|
||||
|
||||
def test_expand_quoted_file_reference_with_spaces(tmp_path: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
|
|
|
|||
146
tests/agent/test_copilot_acp_client.py
Normal file
146
tests/agent/test_copilot_acp_client.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
"""Focused regressions for the Copilot ACP shim safety layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent.copilot_acp_client import CopilotACPClient
|
||||
|
||||
|
||||
class _FakeProcess:
|
||||
def __init__(self) -> None:
|
||||
self.stdin = io.StringIO()
|
||||
|
||||
|
||||
class CopilotACPClientSafetyTests(unittest.TestCase):
    """Safety-layer regressions: permission prompts, denylists, redaction."""

    def setUp(self) -> None:
        self.client = CopilotACPClient(acp_cwd="/tmp")

    def _dispatch(self, message: dict, *, cwd: str) -> dict:
        """Feed one server message through the handler; decode its JSON-RPC reply."""
        proc = _FakeProcess()
        handled = self.client._handle_server_message(
            message,
            process=proc,
            cwd=cwd,
            text_parts=[],
            reasoning_parts=[],
        )
        self.assertTrue(handled)
        raw = proc.stdin.getvalue().strip()
        self.assertTrue(raw)
        return json.loads(raw)

    def test_request_permission_is_not_auto_allowed(self) -> None:
        reply = self._dispatch(
            {
                "jsonrpc": "2.0",
                "id": 1,
                "method": "session/request_permission",
                "params": {},
            },
            cwd="/tmp",
        )
        result = reply.get("result") or {}
        outer = result.get("outcome") or {}
        self.assertEqual(outer.get("outcome"), "cancelled")

    def test_read_text_file_blocks_internal_hermes_hub_files(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdir:
            home = Path(tmpdir) / "home"
            blocked = home / ".hermes" / "skills" / ".hub" / "index-cache" / "entry.json"
            blocked.parent.mkdir(parents=True, exist_ok=True)
            blocked.write_text('{"token":"sk-test-secret-1234567890"}')

            env = {"HOME": str(home), "HERMES_HOME": str(home / ".hermes")}
            with patch.dict(os.environ, env, clear=False):
                reply = self._dispatch(
                    {
                        "jsonrpc": "2.0",
                        "id": 2,
                        "method": "fs/read_text_file",
                        "params": {"path": str(blocked)},
                    },
                    cwd=str(home),
                )

            self.assertIn("error", reply)

    def test_read_text_file_redacts_sensitive_content(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdir:
            root = Path(tmpdir)
            secret_file = root / "config.env"
            secret_file.write_text("OPENAI_API_KEY=sk-proj-abc123def456ghi789jkl012")

            reply = self._dispatch(
                {
                    "jsonrpc": "2.0",
                    "id": 3,
                    "method": "fs/read_text_file",
                    "params": {"path": str(secret_file)},
                },
                cwd=str(root),
            )

            content = (reply.get("result") or {}).get("content") or ""
            # Key value redacted, key name preserved.
            self.assertNotIn("abc123def456", content)
            self.assertIn("OPENAI_API_KEY=", content)

    def test_write_text_file_reuses_write_denylist(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdir:
            home = Path(tmpdir) / "home"
            target = home / ".ssh" / "id_rsa"
            target.parent.mkdir(parents=True, exist_ok=True)

            with patch("agent.copilot_acp_client.is_write_denied", return_value=True, create=True):
                reply = self._dispatch(
                    {
                        "jsonrpc": "2.0",
                        "id": 4,
                        "method": "fs/write_text_file",
                        "params": {
                            "path": str(target),
                            "content": "fake-private-key",
                        },
                    },
                    cwd=str(home),
                )

            self.assertIn("error", reply)
            self.assertFalse(target.exists())

    def test_write_text_file_respects_safe_root(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdir:
            root = Path(tmpdir)
            safe_root = root / "workspace"
            safe_root.mkdir()
            outside = root / "outside.txt"

            with patch.dict(os.environ, {"HERMES_WRITE_SAFE_ROOT": str(safe_root)}, clear=False):
                reply = self._dispatch(
                    {
                        "jsonrpc": "2.0",
                        "id": 5,
                        "method": "fs/write_text_file",
                        "params": {
                            "path": str(outside),
                            "content": "should-not-write",
                        },
                    },
                    cwd=str(root),
                )

            self.assertIn("error", reply)
            self.assertFalse(outside.exists())
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -1,129 +1,25 @@
|
|||
"""Tests for credential pool preservation through smart routing and 429 recovery.
|
||||
"""Tests for credential pool preservation through turn config and 429 recovery.
|
||||
|
||||
Covers:
|
||||
1. credential_pool flows through resolve_turn_route (no-route and fallback paths)
|
||||
2. CLI _resolve_turn_agent_config passes credential_pool to primary dict
|
||||
3. Gateway _resolve_turn_agent_config passes credential_pool to primary dict
|
||||
4. Eager fallback deferred when credential pool has credentials
|
||||
5. Eager fallback fires when no credential pool exists
|
||||
6. Full 429 rotation cycle: retry-same → rotate → exhaust → fallback
|
||||
1. CLI _resolve_turn_agent_config passes credential_pool to runtime dict
|
||||
2. Gateway _resolve_turn_agent_config passes credential_pool to runtime dict
|
||||
3. Eager fallback deferred when credential pool has credentials
|
||||
4. Eager fallback fires when no credential pool exists
|
||||
5. Full 429 rotation cycle: retry-same → rotate → exhaust → fallback
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. smart_model_routing: credential_pool preserved in no-route path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSmartRoutingPoolPreservation:
|
||||
def test_no_route_preserves_credential_pool(self):
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
|
||||
fake_pool = MagicMock(name="CredentialPool")
|
||||
primary = {
|
||||
"model": "gpt-5.4",
|
||||
"api_key": "sk-test",
|
||||
"base_url": None,
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"command": None,
|
||||
"args": [],
|
||||
"credential_pool": fake_pool,
|
||||
}
|
||||
# routing disabled
|
||||
result = resolve_turn_route("hello", None, primary)
|
||||
assert result["runtime"]["credential_pool"] is fake_pool
|
||||
|
||||
def test_no_route_none_pool(self):
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
|
||||
primary = {
|
||||
"model": "gpt-5.4",
|
||||
"api_key": "sk-test",
|
||||
"base_url": None,
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"command": None,
|
||||
"args": [],
|
||||
}
|
||||
result = resolve_turn_route("hello", None, primary)
|
||||
assert result["runtime"]["credential_pool"] is None
|
||||
|
||||
def test_routing_disabled_preserves_pool(self):
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
|
||||
fake_pool = MagicMock(name="CredentialPool")
|
||||
primary = {
|
||||
"model": "gpt-5.4",
|
||||
"api_key": "sk-test",
|
||||
"base_url": None,
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"command": None,
|
||||
"args": [],
|
||||
"credential_pool": fake_pool,
|
||||
}
|
||||
# routing explicitly disabled
|
||||
result = resolve_turn_route("hello", {"enabled": False}, primary)
|
||||
assert result["runtime"]["credential_pool"] is fake_pool
|
||||
|
||||
def test_route_fallback_on_resolve_error_preserves_pool(self, monkeypatch):
|
||||
"""When smart routing picks a cheap model but resolve_runtime_provider
|
||||
fails, the fallback to primary must still include credential_pool."""
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
|
||||
fake_pool = MagicMock(name="CredentialPool")
|
||||
primary = {
|
||||
"model": "gpt-5.4",
|
||||
"api_key": "sk-test",
|
||||
"base_url": None,
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"command": None,
|
||||
"args": [],
|
||||
"credential_pool": fake_pool,
|
||||
}
|
||||
routing_config = {
|
||||
"enabled": True,
|
||||
"cheap_model": "openai/gpt-4.1-mini",
|
||||
"cheap_provider": "openrouter",
|
||||
"max_tokens": 200,
|
||||
"patterns": ["^(hi|hello|hey)"],
|
||||
}
|
||||
# Force resolve_runtime_provider to fail so it falls back to primary
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
MagicMock(side_effect=RuntimeError("no credentials")),
|
||||
)
|
||||
result = resolve_turn_route("hi", routing_config, primary)
|
||||
assert result["runtime"]["credential_pool"] is fake_pool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2 & 3. CLI and Gateway _resolve_turn_agent_config include credential_pool
|
||||
# 1. CLI _resolve_turn_agent_config includes credential_pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCliTurnRoutePool:
|
||||
def test_resolve_turn_includes_pool(self, monkeypatch, tmp_path):
|
||||
"""CLI's _resolve_turn_agent_config must pass credential_pool to primary."""
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
captured = {}
|
||||
|
||||
def spy_resolve(user_message, routing_config, primary):
|
||||
captured["primary"] = primary
|
||||
return resolve_turn_route(user_message, routing_config, primary)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"agent.smart_model_routing.resolve_turn_route", spy_resolve
|
||||
)
|
||||
|
||||
# Build a minimal HermesCLI-like object with the method
|
||||
def test_resolve_turn_includes_pool(self):
|
||||
"""CLI's _resolve_turn_agent_config must pass credential_pool in runtime."""
|
||||
fake_pool = MagicMock(name="FakePool")
|
||||
shell = SimpleNamespace(
|
||||
model="gpt-5.4",
|
||||
api_key="sk-test",
|
||||
|
|
@ -132,58 +28,46 @@ class TestCliTurnRoutePool:
|
|||
api_mode="codex_responses",
|
||||
acp_command=None,
|
||||
acp_args=[],
|
||||
_credential_pool=MagicMock(name="FakePool"),
|
||||
_smart_model_routing={"enabled": False},
|
||||
_credential_pool=fake_pool,
|
||||
service_tier=None,
|
||||
)
|
||||
|
||||
# Import and bind the real method
|
||||
from cli import HermesCLI
|
||||
bound = HermesCLI._resolve_turn_agent_config.__get__(shell)
|
||||
bound("test message")
|
||||
route = bound("test message")
|
||||
|
||||
assert "credential_pool" in captured["primary"]
|
||||
assert captured["primary"]["credential_pool"] is shell._credential_pool
|
||||
assert route["runtime"]["credential_pool"] is fake_pool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Gateway _resolve_turn_agent_config includes credential_pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGatewayTurnRoutePool:
|
||||
def test_resolve_turn_includes_pool(self, monkeypatch):
|
||||
def test_resolve_turn_includes_pool(self):
|
||||
"""Gateway's _resolve_turn_agent_config must pass credential_pool."""
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
captured = {}
|
||||
|
||||
def spy_resolve(user_message, routing_config, primary):
|
||||
captured["primary"] = primary
|
||||
return resolve_turn_route(user_message, routing_config, primary)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"agent.smart_model_routing.resolve_turn_route", spy_resolve
|
||||
)
|
||||
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = SimpleNamespace(
|
||||
_smart_model_routing={"enabled": False},
|
||||
)
|
||||
|
||||
fake_pool = MagicMock(name="FakePool")
|
||||
runner = SimpleNamespace(_service_tier=None)
|
||||
runtime_kwargs = {
|
||||
"api_key": "sk-test",
|
||||
"api_key": "***",
|
||||
"base_url": None,
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"command": None,
|
||||
"args": [],
|
||||
"credential_pool": MagicMock(name="FakePool"),
|
||||
"credential_pool": fake_pool,
|
||||
}
|
||||
|
||||
bound = GatewayRunner._resolve_turn_agent_config.__get__(runner)
|
||||
bound("test message", "gpt-5.4", runtime_kwargs)
|
||||
route = bound("test message", "gpt-5.4", runtime_kwargs)
|
||||
|
||||
assert "credential_pool" in captured["primary"]
|
||||
assert captured["primary"]["credential_pool"] is runtime_kwargs["credential_pool"]
|
||||
assert route["runtime"]["credential_pool"] is fake_pool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4 & 5. Eager fallback deferred/fires based on credential pool
|
||||
# 3 & 4. Eager fallback deferred/fires based on credential pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEagerFallbackWithPool:
|
||||
|
|
@ -251,7 +135,7 @@ class TestEagerFallbackWithPool:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Full 429 rotation cycle via _recover_with_credential_pool
|
||||
# 5. Full 429 rotation cycle via _recover_with_credential_pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPoolRotationCycle:
|
||||
|
|
|
|||
27
tests/agent/test_direct_provider_url_detection.py
Normal file
27
tests/agent/test_direct_provider_url_detection.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
def _agent_with_base_url(base_url: str) -> AIAgent:
    """Build a skeletal AIAgent with only ``base_url`` set.

    ``object.__new__`` bypasses ``__init__`` so no client/network setup runs.
    """
    agent = object.__new__(AIAgent)
    agent.base_url = base_url
    return agent


def test_direct_openai_url_requires_openai_host():
    # A host that merely ends with the spoofed suffix must not match.
    agent = _agent_with_base_url("https://api.openai.com.example/v1")
    assert agent._is_direct_openai_url() is False


def test_direct_openai_url_ignores_path_segment_match():
    # The real host name appearing only in the URL path must not match.
    agent = _agent_with_base_url("https://proxy.example.test/api.openai.com/v1")
    assert agent._is_direct_openai_url() is False


def test_direct_openai_url_accepts_native_host():
    agent = _agent_with_base_url("https://api.openai.com/v1")
    assert agent._is_direct_openai_url() is True
|
||||
|
|
@ -83,6 +83,13 @@ class TestBuildToolPreview:
|
|||
assert result is not None
|
||||
assert "user" in result
|
||||
|
||||
def test_memory_replace_missing_old_text_marked(self):
|
||||
# Avoid empty quotes "" in the preview when old_text is missing/None.
|
||||
result = build_tool_preview("memory", {"action": "replace", "target": "memory"})
|
||||
assert result == '~memory: "<missing old_text>"'
|
||||
result = build_tool_preview("memory", {"action": "remove", "target": "memory", "old_text": None})
|
||||
assert result == '-memory: "<missing old_text>"'
|
||||
|
||||
def test_session_search_preview(self):
|
||||
result = build_tool_preview("session_search", {"query": "find something"})
|
||||
assert result is not None
|
||||
|
|
|
|||
|
|
@ -298,9 +298,15 @@ class TestClassifyApiError:
|
|||
assert result.retryable is False
|
||||
|
||||
def test_404_generic(self):
|
||||
# Generic 404 with no "model not found" signal — common for local
|
||||
# llama.cpp/Ollama/vLLM endpoints with slightly wrong paths. Treat
|
||||
# as unknown (retryable) so the real error surfaces, rather than
|
||||
# claiming the model is missing and silently falling back.
|
||||
e = MockAPIError("Not Found", status_code=404)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.model_not_found
|
||||
assert result.reason == FailoverReason.unknown
|
||||
assert result.retryable is True
|
||||
assert result.should_fallback is False
|
||||
|
||||
# ── Payload too large ──
|
||||
|
||||
|
|
@ -849,3 +855,97 @@ class TestAdversarialEdgeCases:
|
|||
)
|
||||
result = classify_api_error(e, provider="openrouter")
|
||||
assert result.reason == FailoverReason.model_not_found
|
||||
|
||||
# ── Regression: dict-typed message field (Issue #11233) ──
|
||||
|
||||
def test_pydantic_dict_message_no_crash(self):
|
||||
"""Pydantic validation errors return message as dict, not string.
|
||||
|
||||
Regression: classify_api_error must not crash when body['message']
|
||||
is a dict (e.g. {"detail": [...]} from FastAPI/Pydantic). The
|
||||
'or ""' fallback only handles None/falsy values — a non-empty
|
||||
dict is truthy and passed to .lower(), causing AttributeError.
|
||||
"""
|
||||
e = MockAPIError(
|
||||
"Unprocessable Entity",
|
||||
status_code=422,
|
||||
body={
|
||||
"object": "error",
|
||||
"message": {
|
||||
"detail": [
|
||||
{
|
||||
"type": "extra_forbidden",
|
||||
"loc": ["body", "think"],
|
||||
"msg": "Extra inputs are not permitted",
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.format_error
|
||||
assert result.status_code == 422
|
||||
assert result.retryable is False
|
||||
|
||||
def test_nested_error_dict_message_no_crash(self):
|
||||
"""Nested body['error']['message'] as dict must not crash.
|
||||
|
||||
Some providers wrap Pydantic errors in an 'error' object.
|
||||
"""
|
||||
e = MockAPIError(
|
||||
"Validation error",
|
||||
status_code=400,
|
||||
body={
|
||||
"error": {
|
||||
"message": {
|
||||
"detail": [
|
||||
{"type": "missing", "loc": ["body", "required"]}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
result = classify_api_error(e, approx_tokens=1000)
|
||||
assert result.reason == FailoverReason.format_error
|
||||
assert result.status_code == 400
|
||||
|
||||
def test_metadata_raw_dict_message_no_crash(self):
|
||||
"""OpenRouter metadata.raw with dict message must not crash."""
|
||||
e = MockAPIError(
|
||||
"Provider error",
|
||||
status_code=400,
|
||||
body={
|
||||
"error": {
|
||||
"message": "Provider error",
|
||||
"metadata": {
|
||||
"raw": '{"error":{"message":{"detail":[{"type":"invalid"}]}}}'
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.format_error
|
||||
|
||||
# Broader non-string type guards — defense against other provider quirks.
|
||||
|
||||
def test_list_message_no_crash(self):
|
||||
"""Some providers return message as a list of error entries."""
|
||||
e = MockAPIError(
|
||||
"validation",
|
||||
status_code=400,
|
||||
body={"message": [{"msg": "field required"}]},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result is not None
|
||||
|
||||
def test_int_message_no_crash(self):
|
||||
"""Any non-string type must be coerced safely."""
|
||||
e = MockAPIError("server error", status_code=500, body={"message": 42})
|
||||
result = classify_api_error(e)
|
||||
assert result is not None
|
||||
|
||||
def test_none_message_still_works(self):
|
||||
"""Regression: None fallback (the 'or \"\"' path) must still work."""
|
||||
e = MockAPIError("server error", status_code=500, body={"message": None})
|
||||
result = classify_api_error(e)
|
||||
assert result is not None
|
||||
|
|
|
|||
|
|
@ -652,6 +652,42 @@ class TestBuildGeminiRequest:
|
|||
assert decls[0]["description"] == "foo"
|
||||
assert decls[0]["parameters"] == {"type": "object"}
|
||||
|
||||
def test_tools_strip_json_schema_only_fields_from_parameters(self):
|
||||
from agent.gemini_cloudcode_adapter import build_gemini_request
|
||||
|
||||
req = build_gemini_request(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[
|
||||
{"type": "function", "function": {
|
||||
"name": "fn1",
|
||||
"description": "foo",
|
||||
"parameters": {
|
||||
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"$schema": "ignored",
|
||||
"description": "City name",
|
||||
"additionalProperties": False,
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
}},
|
||||
],
|
||||
)
|
||||
params = req["tools"][0]["functionDeclarations"][0]["parameters"]
|
||||
assert "$schema" not in params
|
||||
assert "additionalProperties" not in params
|
||||
assert params["type"] == "object"
|
||||
assert params["required"] == ["city"]
|
||||
assert params["properties"]["city"] == {
|
||||
"type": "string",
|
||||
"description": "City name",
|
||||
}
|
||||
|
||||
def test_tool_choice_auto(self):
|
||||
from agent.gemini_cloudcode_adapter import build_gemini_request
|
||||
|
||||
|
|
@ -814,6 +850,69 @@ class TestTranslateGeminiResponse:
|
|||
assert _map_gemini_finish_reason("RECITATION") == "content_filter"
|
||||
|
||||
|
||||
class TestTranslateStreamEvent:
    """Streaming translation: tool-call index assignment and finish-reason mapping."""

    @staticmethod
    def _call_event(name, args=None):
        """Minimal Gemini SSE event carrying a single functionCall part."""
        return {"response": {"candidates": [{
            "content": {"parts": [{"functionCall": {"name": name, "args": args or {}}}]},
        }]}}

    def test_parallel_calls_to_same_tool_get_unique_indices(self):
        """Gemini may emit several functionCall parts with the same name in a
        single turn (e.g. parallel file reads). Each must get its own OpenAI
        ``index`` — otherwise downstream aggregators collapse them into one.
        """
        from agent.gemini_cloudcode_adapter import _translate_stream_event

        parts = [
            {"functionCall": {"name": "read_file", "args": {"path": "a"}}},
            {"functionCall": {"name": "read_file", "args": {"path": "b"}}},
            {"functionCall": {"name": "read_file", "args": {"path": "c"}}},
        ]
        event = {"response": {"candidates": [{"content": {"parts": parts}}]}}
        counter = [0]
        chunks = _translate_stream_event(
            event, model="gemini-2.5-flash", tool_call_counter=counter
        )
        seen = [chunk.choices[0].delta.tool_calls[0].index for chunk in chunks]
        assert seen == [0, 1, 2]
        assert counter[0] == 3

    def test_counter_persists_across_events(self):
        """Index assignment must continue across SSE events in the same stream."""
        from agent.gemini_cloudcode_adapter import _translate_stream_event

        counter = [0]
        indices = []
        for name in ("foo", "bar", "foo"):
            chunks = _translate_stream_event(
                self._call_event(name), model="m", tool_call_counter=counter
            )
            indices.append(chunks[0].choices[0].delta.tool_calls[0].index)
        assert indices == [0, 1, 2]

    def test_finish_reason_switches_to_tool_calls_when_any_seen(self):
        from agent.gemini_cloudcode_adapter import _translate_stream_event

        counter = [0]
        # First event emits one tool call.
        _translate_stream_event(
            self._call_event("x"), model="m", tool_call_counter=counter
        )
        # Second event carries only the terminal finishReason.
        chunks = _translate_stream_event(
            {"response": {"candidates": [{"finishReason": "STOP"}]}},
            model="m", tool_call_counter=counter,
        )
        assert chunks[-1].choices[0].finish_reason == "tool_calls"
|
||||
|
||||
class TestGeminiCloudCodeClient:
|
||||
def test_client_exposes_openai_interface(self):
|
||||
from agent.gemini_cloudcode_adapter import GeminiCloudCodeClient
|
||||
|
|
|
|||
315
tests/agent/test_gemini_native_adapter.py
Normal file
315
tests/agent/test_gemini_native_adapter.py
Normal file
|
|
@ -0,0 +1,315 @@
|
|||
"""Tests for the native Google AI Studio Gemini adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class DummyResponse:
    """Minimal stand-in for an ``httpx.Response`` used throughout these tests.

    Holds a status code, decoded JSON payload, headers, and raw text.
    When *text* is not supplied it defaults to the JSON serialization of
    the payload, mirroring how a real JSON response body would read.
    """

    def __init__(self, status_code=200, payload=None, headers=None, text=None):
        self.status_code = status_code
        self._payload = payload or {}
        self.headers = headers or {}
        if text is None:
            self.text = json.dumps(self._payload)
        else:
            self.text = text

    def json(self):
        """Return the decoded JSON payload (no parsing — stored directly)."""
        return self._payload
|
||||
|
||||
|
||||
def test_build_native_request_preserves_thought_signature_on_tool_replay():
    """Replayed assistant tool calls must carry Google's thoughtSignature through."""
    from agent.gemini_native_adapter import build_gemini_request

    assistant_turn = {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": '{"city": "Paris"}',
                },
                # Signature as the OpenAI-compat layer round-trips it.
                "extra_content": {"google": {"thought_signature": "sig-123"}},
            }
        ],
    }
    request = build_gemini_request(
        messages=[{"role": "system", "content": "Be helpful."}, assistant_turn],
        tools=[],
        tool_choice=None,
    )

    first_part = request["contents"][0]["parts"][0]
    assert first_part["functionCall"]["name"] == "get_weather"
    assert first_part["thoughtSignature"] == "sig-123"
|
||||
|
||||
|
||||
def test_build_native_request_uses_original_function_name_for_tool_result():
    """functionResponse parts must be attributed to the originally-called name."""
    from agent.gemini_native_adapter import build_gemini_request

    call = {
        "id": "call_1",
        "type": "function",
        "function": {
            "name": "get_weather",
            "arguments": '{"city": "Paris"}',
        },
    }
    history = [
        {"role": "assistant", "content": "", "tool_calls": [call]},
        {
            "role": "tool",
            "tool_call_id": "call_1",
            "content": '{"forecast": "sunny"}',
        },
    ]

    request = build_gemini_request(messages=history, tools=[], tool_choice=None)

    # The tool result is the second content entry; its name must match the call.
    fn_response = request["contents"][1]["parts"][0]["functionResponse"]
    assert fn_response["name"] == "get_weather"
|
||||
|
||||
|
||||
def test_build_native_request_strips_json_schema_only_fields_from_tool_parameters():
    """JSON-Schema-only keys ($schema, additionalProperties) must be removed."""
    from agent.gemini_native_adapter import build_gemini_request

    weather_tool = {
        "type": "function",
        "function": {
            "name": "lookup_weather",
            "description": "Weather lookup",
            "parameters": {
                "$schema": "https://json-schema.org/draft/2020-12/schema",
                "type": "object",
                "additionalProperties": False,
                "properties": {
                    "city": {
                        "type": "string",
                        "$schema": "ignored",
                        "description": "City name",
                    }
                },
                "required": ["city"],
            },
        },
    }

    request = build_gemini_request(
        messages=[{"role": "user", "content": "Hello"}],
        tools=[weather_tool],
        tool_choice=None,
    )

    params = request["tools"][0]["functionDeclarations"][0]["parameters"]
    for forbidden_key in ("$schema", "additionalProperties"):
        assert forbidden_key not in params
    assert params["type"] == "object"
    assert params["required"] == ["city"]
    # Nested properties are cleaned too — only real schema fields survive.
    assert params["properties"]["city"] == {
        "type": "string",
        "description": "City name",
    }
|
||||
|
||||
|
||||
def test_translate_native_response_surfaces_reasoning_and_tool_calls():
    """Thought parts become ``reasoning``; functionCall parts become tool_calls."""
    from agent.gemini_native_adapter import translate_gemini_response

    candidate = {
        "content": {
            "parts": [
                {"thought": True, "text": "thinking..."},
                {"functionCall": {"name": "search", "args": {"q": "hermes"}}},
            ]
        },
        "finishReason": "STOP",
    }
    payload = {
        "candidates": [candidate],
        "usageMetadata": {
            "promptTokenCount": 10,
            "candidatesTokenCount": 5,
            "totalTokenCount": 15,
        },
    }

    choice = translate_gemini_response(payload, model="gemini-2.5-flash").choices[0]

    # Presence of a function call upgrades STOP to tool_calls.
    assert choice.finish_reason == "tool_calls"
    assert choice.message.reasoning == "thinking..."
    tool_call = choice.message.tool_calls[0]
    assert tool_call.function.name == "search"
    assert json.loads(tool_call.function.arguments) == {"q": "hermes"}
|
||||
|
||||
|
||||
def test_native_client_uses_x_goog_api_key_and_native_models_endpoint(monkeypatch):
    """The native client must POST to :generateContent with x-goog-api-key auth
    and must not fall back to an Authorization bearer header."""
    from agent.gemini_native_adapter import GeminiNativeClient

    recorded = {}
    success_payload = {
        "candidates": [
            {
                "content": {"parts": [{"text": "hello"}]},
                "finishReason": "STOP",
            }
        ],
        "usageMetadata": {
            "promptTokenCount": 1,
            "candidatesTokenCount": 1,
            "totalTokenCount": 2,
        },
    }

    class DummyHTTP:
        def post(self, url, json=None, headers=None, timeout=None):
            # Capture exactly what the client sent so we can assert on it below.
            recorded.update(url=url, json=json, headers=headers)
            return DummyResponse(payload=success_payload)

        def close(self):
            return None

    monkeypatch.setattr(
        "agent.gemini_native_adapter.httpx.Client", lambda *a, **k: DummyHTTP()
    )

    client = GeminiNativeClient(
        api_key="AIza-test",
        base_url="https://generativelanguage.googleapis.com/v1beta",
    )
    response = client.chat.completions.create(
        model="gemini-2.5-flash",
        messages=[{"role": "user", "content": "Hello"}],
    )

    assert recorded["url"] == "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent"
    assert recorded["headers"]["x-goog-api-key"] == "AIza-test"
    assert "Authorization" not in recorded["headers"]
    assert response.choices[0].message.content == "hello"
|
||||
|
||||
|
||||
def test_native_http_error_keeps_status_and_retry_after():
    """gemini_http_error must surface status code, Retry-After, and the message."""
    from agent.gemini_native_adapter import gemini_http_error

    error_body = {
        "error": {
            "code": 429,
            "message": "quota exhausted",
            "status": "RESOURCE_EXHAUSTED",
            "details": [
                {
                    "@type": "type.googleapis.com/google.rpc.ErrorInfo",
                    "reason": "RESOURCE_EXHAUSTED",
                    "metadata": {"service": "generativelanguage.googleapis.com"},
                }
            ],
        }
    }
    response = DummyResponse(
        status_code=429,
        headers={"Retry-After": "17"},
        payload=error_body,
    )

    err = gemini_http_error(response)

    assert getattr(err, "status_code", None) == 429
    # Header string is parsed into a float number of seconds.
    assert getattr(err, "retry_after", None) == 17.0
    assert "quota exhausted" in str(err)
|
||||
|
||||
|
||||
def test_native_client_accepts_injected_http_client():
    """A caller-supplied HTTP client must be used verbatim by the constructor."""
    from agent.gemini_native_adapter import GeminiNativeClient

    fake_http = SimpleNamespace(close=lambda: None)
    client = GeminiNativeClient(api_key="AIza-test", http_client=fake_http)
    assert client._http is fake_http
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_async_native_client_streams_without_requiring_async_iterator_from_sync_client():
    """The async wrapper must drive the sync stream via _advance_stream_iterator
    rather than expecting the sync client to hand back an async iterator."""
    from agent.gemini_native_adapter import AsyncGeminiNativeClient

    chunk = SimpleNamespace(
        choices=[SimpleNamespace(delta=SimpleNamespace(content="hi"), finish_reason=None)]
    )
    sync_stream = iter([chunk])

    def _advance(iterator):
        # Mirrors the sync client's helper: (exhausted?, next item or None).
        try:
            return False, next(iterator)
        except StopIteration:
            return True, None

    sync_client = SimpleNamespace(
        api_key="AIza-test",
        base_url="https://generativelanguage.googleapis.com/v1beta",
        chat=SimpleNamespace(
            completions=SimpleNamespace(create=lambda **kwargs: sync_stream)
        ),
        _advance_stream_iterator=_advance,
        close=lambda: None,
    )

    stream = await AsyncGeminiNativeClient(sync_client).chat.completions.create(stream=True)

    collected = [item async for item in stream]
    assert collected == [chunk]
|
||||
|
||||
|
||||
def test_stream_event_translation_emits_tool_call_delta_with_stable_index():
    """Re-seeing the same call keeps its index and id; arguments ship only once."""
    from agent.gemini_native_adapter import translate_stream_event

    event = {
        "candidates": [
            {
                "content": {
                    "parts": [
                        {"functionCall": {"name": "search", "args": {"q": "abc"}}}
                    ]
                },
                "finishReason": "STOP",
            }
        ]
    }

    indices = {}
    first = translate_stream_event(
        event, model="gemini-2.5-flash", tool_call_indices=indices
    )
    second = translate_stream_event(
        event, model="gemini-2.5-flash", tool_call_indices=indices
    )

    first_call = first[0].choices[0].delta.tool_calls[0]
    second_call = second[0].choices[0].delta.tool_calls[0]

    assert first_call.index == 0
    assert second_call.index == 0
    assert first_call.id == second_call.id
    # Arguments appear on first sight only; the repeat sends an empty delta.
    assert first_call.function.arguments == '{"q": "abc"}'
    assert second_call.function.arguments == ""
    assert first[-1].choices[0].finish_reason == "tool_calls"
|
||||
|
||||
|
||||
def test_stream_event_translation_keeps_identical_calls_in_distinct_parts():
    """Two byte-identical calls in one event must still be two distinct tool calls."""
    from agent.gemini_native_adapter import translate_stream_event

    # Two fully independent (but equal) functionCall parts.
    parts = [
        {"functionCall": {"name": "search", "args": {"q": "abc"}}}
        for _ in range(2)
    ]
    event = {
        "candidates": [
            {
                "content": {"parts": parts},
                "finishReason": "STOP",
            }
        ]
    }

    chunks = translate_stream_event(
        event, model="gemini-2.5-flash", tool_call_indices={}
    )
    calls = [
        chunk.choices[0].delta.tool_calls[0]
        for chunk in chunks
        if chunk.choices[0].delta.tool_calls
    ]

    assert [c.index for c in calls] == [0, 1]
    assert calls[0].id != calls[1].id
|
||||
111
tests/agent/test_image_gen_registry.py
Normal file
111
tests/agent/test_image_gen_registry.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Tests for agent/image_gen_registry.py — provider registration & active lookup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agent import image_gen_registry
|
||||
from agent.image_gen_provider import ImageGenProvider
|
||||
|
||||
|
||||
class _FakeProvider(ImageGenProvider):
    """In-memory provider double with a configurable name and availability.

    ``generate`` echoes the provider name and prompt so tests can verify
    which provider handled a request.
    """

    def __init__(self, name: str, available: bool = True):
        self._name = name
        self._available = available

    @property
    def name(self) -> str:
        return self._name

    def is_available(self) -> bool:
        return self._available

    def generate(self, prompt, aspect_ratio="landscape", **kw):
        # Encode routing info in the fake image URI: "<provider>://<prompt>".
        return {"success": True, "image": "{}://{}".format(self._name, prompt)}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_registry():
    """Run every test against an empty provider registry and wipe it afterwards."""
    image_gen_registry._reset_for_tests()
    yield
    image_gen_registry._reset_for_tests()
|
||||
|
||||
|
||||
class TestRegisterProvider:
    """register_provider: type/name validation, overwrite semantics, sorted listing."""

    def test_register_and_lookup(self):
        fake = _FakeProvider("fake")
        image_gen_registry.register_provider(fake)
        assert image_gen_registry.get_provider("fake") is fake

    def test_rejects_non_provider(self):
        # Anything that isn't an ImageGenProvider must be refused.
        with pytest.raises(TypeError):
            image_gen_registry.register_provider("not a provider")  # type: ignore[arg-type]

    def test_rejects_empty_name(self):
        class Empty(ImageGenProvider):
            @property
            def name(self) -> str:
                return ""

            def generate(self, prompt, aspect_ratio="landscape", **kw):
                return {}

        with pytest.raises(ValueError):
            image_gen_registry.register_provider(Empty())

    def test_reregister_overwrites(self):
        first, second = _FakeProvider("same"), _FakeProvider("same")
        image_gen_registry.register_provider(first)
        image_gen_registry.register_provider(second)
        # Last registration under a name wins.
        assert image_gen_registry.get_provider("same") is second

    def test_list_is_sorted(self):
        for provider_name in ("zeta", "alpha"):
            image_gen_registry.register_provider(_FakeProvider(provider_name))
        listed = [p.name for p in image_gen_registry.list_providers()]
        assert listed == ["alpha", "zeta"]
|
||||
|
||||
|
||||
class TestGetActiveProvider:
    """Active-provider resolution: config override, FAL preference, fallbacks."""

    @staticmethod
    def _write_config(home, provider_name):
        """Write a HERMES_HOME config.yaml selecting *provider_name*."""
        import yaml

        (home / "config.yaml").write_text(
            yaml.safe_dump({"image_gen": {"provider": provider_name}})
        )

    def test_single_provider_autoresolves(self, tmp_path, monkeypatch):
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        image_gen_registry.register_provider(_FakeProvider("solo"))
        active = image_gen_registry.get_active_provider()
        assert active is not None and active.name == "solo"

    def test_fal_preferred_on_multi_without_config(self, tmp_path, monkeypatch):
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        for name in ("fal", "openai"):
            image_gen_registry.register_provider(_FakeProvider(name))
        active = image_gen_registry.get_active_provider()
        assert active is not None and active.name == "fal"

    def test_explicit_config_wins(self, tmp_path, monkeypatch):
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        self._write_config(tmp_path, "openai")
        for name in ("fal", "openai"):
            image_gen_registry.register_provider(_FakeProvider(name))
        active = image_gen_registry.get_active_provider()
        assert active is not None and active.name == "openai"

    def test_missing_configured_provider_falls_back(self, tmp_path, monkeypatch):
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        self._write_config(tmp_path, "replicate")
        # Only FAL is registered — configured provider doesn't exist.
        image_gen_registry.register_provider(_FakeProvider("fal"))
        active = image_gen_registry.get_active_provider()
        # Falls back to FAL preference (legacy default) rather than None.
        assert active is not None and active.name == "fal"

    def test_none_when_empty(self, tmp_path, monkeypatch):
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        assert image_gen_registry.get_active_provider() is None
|
||||
|
|
@ -51,6 +51,12 @@ def populated_db(db):
|
|||
db.append_message("s1", role="assistant", content="I found the bug. Let me fix it.",
|
||||
tool_calls=[{"function": {"name": "patch"}}])
|
||||
db.append_message("s1", role="tool", content="patched successfully", tool_name="patch")
|
||||
db.append_message(
|
||||
"s1",
|
||||
role="assistant",
|
||||
content="Let me load the PR workflow skill.",
|
||||
tool_calls=[{"function": {"name": "skill_view", "arguments": '{"name":"github-pr-workflow"}'}}],
|
||||
)
|
||||
db.append_message("s1", role="user", content="Thanks!")
|
||||
db.append_message("s1", role="assistant", content="You're welcome!")
|
||||
|
||||
|
|
@ -88,6 +94,12 @@ def populated_db(db):
|
|||
db.append_message("s3", role="assistant", content="And search files",
|
||||
tool_calls=[{"function": {"name": "search_files"}}])
|
||||
db.append_message("s3", role="tool", content="found stuff", tool_name="search_files")
|
||||
db.append_message(
|
||||
"s3",
|
||||
role="assistant",
|
||||
content="Load the debugging skill.",
|
||||
tool_calls=[{"function": {"name": "skill_view", "arguments": '{"name":"systematic-debugging"}'}}],
|
||||
)
|
||||
|
||||
# Session 4: Discord, same model as s1, ended, 1 day ago
|
||||
db.create_session(
|
||||
|
|
@ -100,6 +112,15 @@ def populated_db(db):
|
|||
db.update_token_counts("s4", input_tokens=10000, output_tokens=5000)
|
||||
db.append_message("s4", role="user", content="Quick question")
|
||||
db.append_message("s4", role="assistant", content="Sure, go ahead")
|
||||
db.append_message(
|
||||
"s4",
|
||||
role="assistant",
|
||||
content="Load and update GitHub skills.",
|
||||
tool_calls=[
|
||||
{"function": {"name": "skill_view", "arguments": '{"name":"github-pr-workflow"}'}},
|
||||
{"function": {"name": "skill_manage", "arguments": '{"name":"github-code-review"}'}},
|
||||
],
|
||||
)
|
||||
|
||||
# Session 5: Old session, 45 days ago (should be excluded from 30-day window)
|
||||
db.create_session(
|
||||
|
|
@ -332,6 +353,35 @@ class TestInsightsPopulated:
|
|||
total_pct = sum(t["percentage"] for t in tools)
|
||||
assert total_pct == pytest.approx(100.0, abs=0.1)
|
||||
|
||||
def test_skill_breakdown(self, populated_db):
|
||||
engine = InsightsEngine(populated_db)
|
||||
report = engine.generate(days=30)
|
||||
skills = report["skills"]
|
||||
|
||||
assert skills["summary"]["distinct_skills_used"] == 3
|
||||
assert skills["summary"]["total_skill_loads"] == 3
|
||||
assert skills["summary"]["total_skill_edits"] == 1
|
||||
assert skills["summary"]["total_skill_actions"] == 4
|
||||
|
||||
top_skill = skills["top_skills"][0]
|
||||
assert top_skill["skill"] == "github-pr-workflow"
|
||||
assert top_skill["view_count"] == 2
|
||||
assert top_skill["manage_count"] == 0
|
||||
assert top_skill["total_count"] == 2
|
||||
assert top_skill["last_used_at"] is not None
|
||||
|
||||
def test_skill_breakdown_respects_days_filter(self, populated_db):
|
||||
engine = InsightsEngine(populated_db)
|
||||
report = engine.generate(days=3)
|
||||
skills = report["skills"]
|
||||
|
||||
assert skills["summary"]["distinct_skills_used"] == 2
|
||||
assert skills["summary"]["total_skill_loads"] == 2
|
||||
assert skills["summary"]["total_skill_edits"] == 1
|
||||
|
||||
skill_names = [s["skill"] for s in skills["top_skills"]]
|
||||
assert "systematic-debugging" not in skill_names
|
||||
|
||||
def test_activity_patterns(self, populated_db):
|
||||
engine = InsightsEngine(populated_db)
|
||||
report = engine.generate(days=30)
|
||||
|
|
@ -401,6 +451,7 @@ class TestTerminalFormatting:
|
|||
assert "Overview" in text
|
||||
assert "Models Used" in text
|
||||
assert "Top Tools" in text
|
||||
assert "Top Skills" in text
|
||||
assert "Activity Patterns" in text
|
||||
assert "Notable Sessions" in text
|
||||
|
||||
|
|
@ -465,12 +516,12 @@ class TestGatewayFormatting:
|
|||
assert "**" in text # Markdown bold
|
||||
|
||||
def test_gateway_format_hides_cost(self, populated_db):
|
||||
"""Gateway format omits dollar figures and internal cache details."""
|
||||
engine = InsightsEngine(populated_db)
|
||||
report = engine.generate(days=30)
|
||||
text = engine.format_gateway(report)
|
||||
|
||||
assert "$" not in text
|
||||
assert "Est. cost" not in text
|
||||
assert "cache" not in text.lower()
|
||||
|
||||
def test_gateway_format_shows_models(self, populated_db):
|
||||
|
|
|
|||
115
tests/agent/test_kimi_coding_anthropic_thinking.py
Normal file
115
tests/agent/test_kimi_coding_anthropic_thinking.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
"""Regression guard: don't send Anthropic ``thinking`` to Kimi's /coding endpoint.
|
||||
|
||||
Kimi's ``api.kimi.com/coding`` endpoint speaks the Anthropic Messages protocol
|
||||
but has its own thinking semantics. When ``thinking.enabled`` is present in
|
||||
the request, Kimi validates the message history and requires every prior
|
||||
assistant tool-call message to carry OpenAI-style ``reasoning_content``.
|
||||
|
||||
The Anthropic path never populates that field, and
|
||||
``convert_messages_to_anthropic`` strips Anthropic thinking blocks on
|
||||
third-party endpoints — so after one turn with tool calls the next request
|
||||
fails with HTTP 400::
|
||||
|
||||
thinking is enabled but reasoning_content is missing in assistant
|
||||
tool call message at index N
|
||||
|
||||
Kimi on the chat_completions route handles ``thinking`` via ``extra_body`` in
|
||||
``ChatCompletionsTransport`` (#13503). On the Anthropic route the right
|
||||
thing to do is drop the parameter entirely and let Kimi drive reasoning
|
||||
server-side.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestKimiCodingSkipsAnthropicThinking:
    """build_anthropic_kwargs must not inject ``thinking`` for Kimi /coding."""

    @staticmethod
    def _kwargs_for(base_url, reasoning, model="kimi-k2.5"):
        """Build kwargs for a one-message request against *base_url*."""
        from agent.anthropic_adapter import build_anthropic_kwargs

        return build_anthropic_kwargs(
            model=model,
            messages=[{"role": "user", "content": "hello"}],
            tools=None,
            max_tokens=4096,
            reasoning_config=reasoning,
            base_url=base_url,
        )

    @pytest.mark.parametrize(
        "base_url",
        [
            "https://api.kimi.com/coding",
            "https://api.kimi.com/coding/v1",
            "https://api.kimi.com/coding/anthropic",
            "https://api.kimi.com/coding/",
        ],
    )
    def test_kimi_coding_endpoint_omits_thinking(self, base_url: str) -> None:
        kwargs = self._kwargs_for(base_url, {"enabled": True, "effort": "medium"})
        assert "thinking" not in kwargs, (
            "Anthropic thinking must not be sent to Kimi /coding — "
            "endpoint requires reasoning_content on history we don't preserve."
        )
        assert "output_config" not in kwargs

    def test_kimi_coding_with_explicit_disabled_also_omits(self) -> None:
        kwargs = self._kwargs_for("https://api.kimi.com/coding", {"enabled": False})
        assert "thinking" not in kwargs

    def test_non_kimi_third_party_still_gets_thinking(self) -> None:
        """MiniMax and other third-party Anthropic endpoints must retain thinking."""
        kwargs = self._kwargs_for(
            "https://api.minimax.io/anthropic",
            {"enabled": True, "effort": "medium"},
            model="MiniMax-M2.7",
        )
        assert "thinking" in kwargs
        assert kwargs["thinking"]["type"] == "enabled"

    def test_native_anthropic_still_gets_thinking(self) -> None:
        kwargs = self._kwargs_for(
            None,
            {"enabled": True, "effort": "medium"},
            model="claude-sonnet-4-20250514",
        )
        assert "thinking" in kwargs

    def test_kimi_root_endpoint_unaffected(self) -> None:
        """Only the /coding route is special-cased — plain api.kimi.com is not.

        ``api.kimi.com`` without ``/coding`` uses the chat_completions transport
        (see runtime_provider._detect_api_mode_for_url); build_anthropic_kwargs
        should never see it, but if it somehow does we should not suppress
        thinking there — that path has different semantics.
        """
        kwargs = self._kwargs_for(
            "https://api.kimi.com/v1", {"enabled": True, "effort": "medium"}
        )
        assert "thinking" in kwargs
|
||||
|
|
@ -971,8 +971,6 @@ class TestHonchoCadenceTracking:
|
|||
class FakeManager:
|
||||
def prefetch_context(self, key, query=None):
|
||||
pass
|
||||
def prefetch_dialectic(self, key, query):
|
||||
pass
|
||||
|
||||
p._manager = FakeManager()
|
||||
|
||||
|
|
|
|||
|
|
@ -79,6 +79,28 @@ class TestMemoryManagerUserIdThreading:
|
|||
assert p._init_kwargs.get("platform") == "telegram"
|
||||
assert p._init_session_id == "sess-123"
|
||||
|
||||
def test_chat_context_forwarded_to_provider(self):
|
||||
mgr = MemoryManager()
|
||||
p = RecordingProvider()
|
||||
mgr.add_provider(p)
|
||||
|
||||
mgr.initialize_all(
|
||||
session_id="sess-chat",
|
||||
platform="discord",
|
||||
user_id="discord_u_7",
|
||||
user_name="fakeusername",
|
||||
chat_id="1485316232612941897",
|
||||
chat_name="fakeassistantname-forums",
|
||||
chat_type="thread",
|
||||
thread_id="1491249007475949698",
|
||||
)
|
||||
|
||||
assert p._init_kwargs.get("user_name") == "fakeusername"
|
||||
assert p._init_kwargs.get("chat_id") == "1485316232612941897"
|
||||
assert p._init_kwargs.get("chat_name") == "fakeassistantname-forums"
|
||||
assert p._init_kwargs.get("chat_type") == "thread"
|
||||
assert p._init_kwargs.get("thread_id") == "1491249007475949698"
|
||||
|
||||
def test_no_user_id_when_cli(self):
|
||||
"""CLI sessions should not have user_id in kwargs."""
|
||||
mgr = MemoryManager()
|
||||
|
|
@ -208,34 +230,81 @@ class TestMem0UserIdScoping:
|
|||
|
||||
|
||||
class TestHonchoUserIdScoping:
|
||||
"""Verify Honcho plugin uses gateway user_id for peer_name when provided."""
|
||||
"""Verify Honcho plugin keeps runtime user scoping separate from config peer_name."""
|
||||
|
||||
def test_gateway_user_id_overrides_peer_name(self):
|
||||
"""When user_id is in kwargs and no explicit peer_name, user_id should be used."""
|
||||
def test_gateway_user_id_is_passed_as_runtime_peer(self):
|
||||
"""Gateway user_id should scope Honcho sessions without mutating config peer_name."""
|
||||
from plugins.memory.honcho import HonchoMemoryProvider
|
||||
|
||||
provider = HonchoMemoryProvider()
|
||||
|
||||
# Create a mock config with NO explicit peer_name
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.enabled = True
|
||||
mock_cfg.api_key = "test-key"
|
||||
mock_cfg.base_url = None
|
||||
mock_cfg.peer_name = "" # No explicit peer_name — user_id should fill it
|
||||
mock_cfg.recall_mode = "tools" # Use tools mode to defer session init
|
||||
mock_cfg.peer_name = "static-user"
|
||||
mock_cfg.recall_mode = "context"
|
||||
mock_cfg.context_tokens = None
|
||||
mock_cfg.raw = {}
|
||||
mock_cfg.dialectic_depth = 1
|
||||
mock_cfg.dialectic_depth_levels = None
|
||||
mock_cfg.init_on_session_start = False
|
||||
mock_cfg.ai_peer = "hermes"
|
||||
mock_cfg.resolve_session_name.return_value = "test-sess"
|
||||
mock_cfg.session_strategy = "shared"
|
||||
|
||||
with patch(
|
||||
"plugins.memory.honcho.client.HonchoClientConfig.from_global_config",
|
||||
return_value=mock_cfg,
|
||||
):
|
||||
), patch(
|
||||
"plugins.memory.honcho.client.get_honcho_client",
|
||||
return_value=MagicMock(),
|
||||
), patch(
|
||||
"plugins.memory.honcho.session.HonchoSessionManager",
|
||||
) as mock_manager_cls:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.get_or_create.return_value = MagicMock(messages=[])
|
||||
mock_manager_cls.return_value = mock_manager
|
||||
provider.initialize(
|
||||
session_id="test-sess",
|
||||
user_id="discord_user_789",
|
||||
platform="discord",
|
||||
)
|
||||
|
||||
# The config's peer_name should have been overridden with the user_id
|
||||
assert mock_cfg.peer_name == "discord_user_789"
|
||||
assert mock_cfg.peer_name == "static-user"
|
||||
assert mock_manager_cls.call_args.kwargs["runtime_user_peer_name"] == "discord_user_789"
|
||||
|
||||
def test_session_manager_prefers_runtime_user_id_over_config_peer_name(self):
|
||||
"""Session manager should isolate gateway users even when config peer_name is static."""
|
||||
from plugins.memory.honcho.session import HonchoSessionManager
|
||||
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.peer_name = "static-user"
|
||||
mock_cfg.ai_peer = "hermes"
|
||||
mock_cfg.write_frequency = "sync"
|
||||
mock_cfg.dialectic_reasoning_level = "low"
|
||||
mock_cfg.dialectic_dynamic = True
|
||||
mock_cfg.dialectic_max_chars = 600
|
||||
mock_cfg.observation_mode = "directional"
|
||||
mock_cfg.user_observe_me = True
|
||||
mock_cfg.user_observe_others = True
|
||||
mock_cfg.ai_observe_me = True
|
||||
mock_cfg.ai_observe_others = True
|
||||
|
||||
manager = HonchoSessionManager(
|
||||
honcho=MagicMock(),
|
||||
config=mock_cfg,
|
||||
runtime_user_peer_name="discord_user_789",
|
||||
)
|
||||
|
||||
with patch.object(manager, "_get_or_create_peer", return_value=MagicMock()), patch.object(
|
||||
manager,
|
||||
"_get_or_create_honcho_session",
|
||||
return_value=(MagicMock(), []),
|
||||
):
|
||||
session = manager.get_or_create("discord:channel-1")
|
||||
|
||||
assert session.user_peer_id == "discord_user_789"
|
||||
|
||||
def test_no_user_id_preserves_config_peer_name(self):
|
||||
"""Without user_id, the config peer_name should be preserved."""
|
||||
|
|
@ -287,3 +356,4 @@ class TestAIAgentUserIdPropagation:
|
|||
agent = object.__new__(AIAgent)
|
||||
agent._user_id = None
|
||||
assert agent._user_id is None
|
||||
|
||||
|
|
|
|||
|
|
@ -84,38 +84,6 @@ class TestMinimaxAuxModel:
|
|||
assert "highspeed" not in _API_KEY_PROVIDER_AUX_MODELS["minimax-cn"]
|
||||
|
||||
|
||||
class TestMinimaxModelCatalog:
|
||||
"""Verify the model catalog matches official Anthropic-compat endpoint models.
|
||||
|
||||
Source: https://platform.minimax.io/docs/api-reference/text-anthropic-api
|
||||
"""
|
||||
|
||||
def test_catalog_includes_current_models(self):
|
||||
from hermes_cli.models import _PROVIDER_MODELS
|
||||
for provider in ("minimax", "minimax-cn"):
|
||||
models = _PROVIDER_MODELS[provider]
|
||||
assert "MiniMax-M2.7" in models
|
||||
assert "MiniMax-M2.5" in models
|
||||
assert "MiniMax-M2.1" in models
|
||||
assert "MiniMax-M2" in models
|
||||
|
||||
def test_catalog_excludes_m1_family(self):
|
||||
"""M1 models are not available on the /anthropic endpoint."""
|
||||
from hermes_cli.models import _PROVIDER_MODELS
|
||||
for provider in ("minimax", "minimax-cn"):
|
||||
models = _PROVIDER_MODELS[provider]
|
||||
assert "MiniMax-M1" not in models
|
||||
|
||||
def test_catalog_excludes_highspeed(self):
|
||||
"""Highspeed variants are available but not shown in default catalog
|
||||
(users can still specify them manually)."""
|
||||
from hermes_cli.models import _PROVIDER_MODELS
|
||||
for provider in ("minimax", "minimax-cn"):
|
||||
models = _PROVIDER_MODELS[provider]
|
||||
assert "MiniMax-M2.7-highspeed" not in models
|
||||
assert "MiniMax-M2.5-highspeed" not in models
|
||||
|
||||
|
||||
class TestMinimaxBetaHeaders:
|
||||
"""MiniMax Anthropic-compat endpoints reject fine-grained-tool-streaming beta.
|
||||
|
||||
|
|
|
|||
|
|
@ -385,6 +385,7 @@ class TestStripProviderPrefix:
|
|||
assert _strip_provider_prefix("local:my-model") == "my-model"
|
||||
assert _strip_provider_prefix("openrouter:anthropic/claude-sonnet-4") == "anthropic/claude-sonnet-4"
|
||||
assert _strip_provider_prefix("anthropic:claude-sonnet-4") == "claude-sonnet-4"
|
||||
assert _strip_provider_prefix("stepfun:step-3.5-flash") == "step-3.5-flash"
|
||||
|
||||
def test_ollama_model_tag_preserved(self):
|
||||
"""Ollama model:tag format must NOT be stripped."""
|
||||
|
|
|
|||
|
|
@ -424,6 +424,68 @@ class TestQueryLocalContextLengthLmStudio:
|
|||
)
|
||||
|
||||
|
||||
class TestDetectLocalServerTypeAuth:
|
||||
def test_passes_bearer_token_to_probe_requests(self):
|
||||
from agent.model_metadata import detect_local_server_type
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
|
||||
client_mock = MagicMock()
|
||||
client_mock.__enter__ = lambda s: client_mock
|
||||
client_mock.__exit__ = MagicMock(return_value=False)
|
||||
client_mock.get.return_value = resp
|
||||
|
||||
with patch("httpx.Client", return_value=client_mock) as mock_client:
|
||||
result = detect_local_server_type("http://localhost:1234/v1", api_key="lm-token")
|
||||
|
||||
assert result == "lm-studio"
|
||||
assert mock_client.call_args.kwargs["headers"] == {
|
||||
"Authorization": "Bearer lm-token"
|
||||
}
|
||||
|
||||
|
||||
class TestFetchEndpointModelMetadataLmStudio:
|
||||
"""fetch_endpoint_model_metadata should use LM Studio's native models endpoint."""
|
||||
|
||||
def _make_resp(self, body):
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status.return_value = None
|
||||
resp.json.return_value = body
|
||||
return resp
|
||||
|
||||
def test_uses_native_models_endpoint_only(self):
|
||||
from agent.model_metadata import fetch_endpoint_model_metadata
|
||||
|
||||
native_resp = self._make_resp(
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"key": "lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf",
|
||||
"id": "lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf",
|
||||
"max_context_length": 131072,
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
|
||||
patch("agent.model_metadata.requests.get", return_value=native_resp) as mock_get:
|
||||
result = fetch_endpoint_model_metadata(
|
||||
"http://localhost:1234/v1",
|
||||
api_key="lm-token",
|
||||
force_refresh=True,
|
||||
)
|
||||
|
||||
assert mock_get.call_count == 1
|
||||
assert mock_get.call_args[0][0] == "http://localhost:1234/api/v1/models"
|
||||
assert mock_get.call_args.kwargs["headers"] == {
|
||||
"Authorization": "Bearer lm-token"
|
||||
}
|
||||
assert result["lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf"]["context_length"] == 131072
|
||||
assert result["Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf"]["context_length"] == 131072
|
||||
|
||||
|
||||
class TestQueryLocalContextLengthNetworkError:
|
||||
"""_query_local_context_length handles network failures gracefully."""
|
||||
|
||||
|
|
|
|||
|
|
@ -82,6 +82,7 @@ class TestProviderMapping:
|
|||
def test_known_providers_mapped(self):
|
||||
assert PROVIDER_TO_MODELS_DEV["anthropic"] == "anthropic"
|
||||
assert PROVIDER_TO_MODELS_DEV["copilot"] == "github-copilot"
|
||||
assert PROVIDER_TO_MODELS_DEV["stepfun"] == "stepfun"
|
||||
assert PROVIDER_TO_MODELS_DEV["kilocode"] == "kilo"
|
||||
assert PROVIDER_TO_MODELS_DEV["ai-gateway"] == "vercel"
|
||||
|
||||
|
|
|
|||
|
|
@ -354,6 +354,24 @@ class TestBuildSkillsSystemPrompt:
|
|||
assert "web-search" in result
|
||||
assert "old-tool" not in result
|
||||
|
||||
def test_rebuilds_prompt_when_disabled_skills_change(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
skill_dir = tmp_path / "skills" / "tools" / "cached-skill"
|
||||
skill_dir.mkdir(parents=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\nname: cached-skill\ndescription: Cached skill\n---\n"
|
||||
)
|
||||
|
||||
first = build_skills_system_prompt()
|
||||
assert "cached-skill" in first
|
||||
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"skills:\n disabled: [cached-skill]\n"
|
||||
)
|
||||
|
||||
second = build_skills_system_prompt()
|
||||
assert "cached-skill" not in second
|
||||
|
||||
def test_includes_setup_needed_skills(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.delenv("MISSING_API_KEY_XYZ", raising=False)
|
||||
|
|
@ -771,6 +789,24 @@ class TestPromptBuilderConstants:
|
|||
assert "cron" in PLATFORM_HINTS
|
||||
assert "cli" in PLATFORM_HINTS
|
||||
|
||||
def test_cli_hint_does_not_suggest_media_tags(self):
|
||||
# Regression: MEDIA:/path tags are intercepted only by messaging
|
||||
# gateway platforms. On the CLI they render as literal text and
|
||||
# confuse users. The CLI hint must steer the agent away from them.
|
||||
cli_hint = PLATFORM_HINTS["cli"]
|
||||
assert "MEDIA:" in cli_hint, (
|
||||
"CLI hint should mention MEDIA: in order to tell the agent "
|
||||
"NOT to use it (negative guidance)."
|
||||
)
|
||||
# Must contain explicit "don't" language near the MEDIA reference.
|
||||
assert any(
|
||||
marker in cli_hint.lower()
|
||||
for marker in ("do not emit media", "not intercepted", "do not", "don't")
|
||||
), "CLI hint should explicitly discourage MEDIA: tags."
|
||||
# Messaging hints should still advertise MEDIA: positively (sanity
|
||||
# check that this test is calibrated correctly).
|
||||
assert "include MEDIA:" in PLATFORM_HINTS["telegram"]
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Environment hints
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ when proxy env vars or custom endpoint URLs are malformed.
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.auxiliary_client import _validate_base_url, _validate_proxy_env_urls
|
||||
|
|
@ -31,6 +33,12 @@ def test_proxy_env_accepts_empty(monkeypatch):
|
|||
_validate_proxy_env_urls() # should not raise
|
||||
|
||||
|
||||
def test_proxy_env_normalizes_socks_alias(monkeypatch):
|
||||
monkeypatch.setenv("ALL_PROXY", "socks://127.0.0.1:1080/")
|
||||
_validate_proxy_env_urls()
|
||||
assert os.environ["ALL_PROXY"] == "socks5://127.0.0.1:1080/"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("key", [
|
||||
"HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY",
|
||||
"http_proxy", "https_proxy", "all_proxy",
|
||||
|
|
|
|||
|
|
@ -376,3 +376,138 @@ class TestDiscordMentions:
|
|||
result = redact_sensitive_text(text)
|
||||
assert result.startswith("User ")
|
||||
assert result.endswith(" said hello")
|
||||
|
||||
|
||||
class TestUrlQueryParamRedaction:
|
||||
"""URL query-string redaction (ported from nearai/ironclaw#2529).
|
||||
|
||||
Catches opaque tokens that don't match vendor prefix regexes by
|
||||
matching on parameter NAME rather than value shape.
|
||||
"""
|
||||
|
||||
def test_oauth_callback_code(self):
|
||||
text = "GET https://api.example.com/oauth/cb?code=abc123xyz789&state=csrf_ok"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "abc123xyz789" not in result
|
||||
assert "code=***" in result
|
||||
assert "state=csrf_ok" in result # state is not sensitive
|
||||
|
||||
def test_access_token_query(self):
|
||||
text = "Fetching https://example.com/api?access_token=opaque_value_here_1234&format=json"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "opaque_value_here_1234" not in result
|
||||
assert "access_token=***" in result
|
||||
assert "format=json" in result
|
||||
|
||||
def test_refresh_token_query(self):
|
||||
text = "https://auth.example.com/token?refresh_token=somerefresh&grant_type=refresh"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "somerefresh" not in result
|
||||
assert "grant_type=refresh" in result
|
||||
|
||||
def test_api_key_query(self):
|
||||
text = "https://api.example.com/v1/data?api_key=kABCDEF12345&limit=10"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "kABCDEF12345" not in result
|
||||
assert "limit=10" in result
|
||||
|
||||
def test_presigned_signature(self):
|
||||
text = "https://s3.amazonaws.com/bucket/k?signature=LONG_PRESIGNED_SIG&id=public"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "LONG_PRESIGNED_SIG" not in result
|
||||
assert "id=public" in result
|
||||
|
||||
def test_case_insensitive_param_names(self):
|
||||
"""Lowercase/mixed-case sensitive param names are redacted."""
|
||||
# NOTE: All-caps names like TOKEN= are swallowed by _ENV_ASSIGN_RE
|
||||
# (which matches KEY=value patterns greedily) before URL regex runs.
|
||||
# This test uses lowercase names to isolate URL-query redaction.
|
||||
text = "https://example.com?api_key=abcdef&secret=ghijkl"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "abcdef" not in result
|
||||
assert "ghijkl" not in result
|
||||
assert "api_key=***" in result
|
||||
assert "secret=***" in result
|
||||
|
||||
def test_substring_match_does_not_trigger(self):
|
||||
"""`token_count` and `session_id` must NOT match `token` / `session`."""
|
||||
text = "https://example.com/cb?token_count=42&session_id=xyz&foo=bar"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "token_count=42" in result
|
||||
assert "session_id=xyz" in result
|
||||
|
||||
def test_url_without_query_unchanged(self):
|
||||
text = "https://example.com/path/to/resource"
|
||||
assert redact_sensitive_text(text) == text
|
||||
|
||||
def test_url_with_fragment(self):
|
||||
text = "https://example.com/page?token=xyz#section"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "token=xyz" not in result
|
||||
assert "#section" in result
|
||||
|
||||
def test_websocket_url_query(self):
|
||||
text = "wss://api.example.com/ws?token=opaqueWsToken123"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "opaqueWsToken123" not in result
|
||||
|
||||
|
||||
class TestUrlUserinfoRedaction:
|
||||
"""URL userinfo (`scheme://user:pass@host`) for non-DB schemes."""
|
||||
|
||||
def test_https_userinfo(self):
|
||||
text = "URL: https://user:supersecretpw@host.example.com/path"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "supersecretpw" not in result
|
||||
assert "https://user:***@host.example.com" in result
|
||||
|
||||
def test_http_userinfo(self):
|
||||
text = "http://admin:plaintextpass@internal.example.com/api"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "plaintextpass" not in result
|
||||
|
||||
def test_ftp_userinfo(self):
|
||||
text = "ftp://user:ftppass@ftp.example.com/file.txt"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "ftppass" not in result
|
||||
|
||||
def test_url_without_userinfo_unchanged(self):
|
||||
text = "https://example.com/path"
|
||||
assert redact_sensitive_text(text) == text
|
||||
|
||||
def test_db_connstr_still_handled(self):
|
||||
"""DB schemes are handled by _DB_CONNSTR_RE, not _URL_USERINFO_RE."""
|
||||
text = "postgres://admin:dbpass@db.internal:5432/app"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "dbpass" not in result
|
||||
|
||||
|
||||
class TestFormBodyRedaction:
|
||||
"""Form-urlencoded body redaction (k=v&k=v with no other text)."""
|
||||
|
||||
def test_pure_form_body(self):
|
||||
text = "password=mysecret&username=bob&token=opaqueValue"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "mysecret" not in result
|
||||
assert "opaqueValue" not in result
|
||||
assert "username=bob" in result
|
||||
|
||||
def test_oauth_token_request(self):
|
||||
text = "grant_type=password&client_id=app&client_secret=topsecret&username=alice&password=alicepw"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "topsecret" not in result
|
||||
assert "alicepw" not in result
|
||||
assert "client_id=app" in result
|
||||
|
||||
def test_non_form_text_unchanged(self):
|
||||
"""Sentences with `&` should NOT trigger form redaction."""
|
||||
text = "I have password=foo and other things" # contains spaces
|
||||
result = redact_sensitive_text(text)
|
||||
# The space breaks the form regex; passthrough expected.
|
||||
assert "I have" in result
|
||||
|
||||
def test_multiline_text_not_form(self):
|
||||
"""Multi-line text is never treated as form body."""
|
||||
text = "first=1\nsecond=2"
|
||||
# Should pass through (still subject to other redactors)
|
||||
assert "first=1" in redact_sensitive_text(text)
|
||||
|
|
|
|||
716
tests/agent/test_shell_hooks.py
Normal file
716
tests/agent/test_shell_hooks.py
Normal file
|
|
@ -0,0 +1,716 @@
|
|||
"""Tests for the shell-hooks subprocess bridge (agent.shell_hooks).
|
||||
|
||||
These tests focus on the pure translation layer — JSON serialisation,
|
||||
JSON parsing, matcher behaviour, block-schema correctness, and the
|
||||
subprocess runner's graceful error handling. Consent prompts are
|
||||
covered in ``test_shell_hooks_consent.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import stat
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from agent import shell_hooks
|
||||
|
||||
|
||||
# ── helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _write_script(tmp_path: Path, name: str, body: str) -> Path:
|
||||
path = tmp_path / name
|
||||
path.write_text(body)
|
||||
path.chmod(0o755)
|
||||
return path
|
||||
|
||||
|
||||
def _allowlist_pair(monkeypatch, tmp_path, event: str, command: str) -> None:
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_home"))
|
||||
shell_hooks._record_approval(event, command)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_registration_state():
|
||||
shell_hooks.reset_for_tests()
|
||||
yield
|
||||
shell_hooks.reset_for_tests()
|
||||
|
||||
|
||||
# ── _parse_response ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestParseResponse:
|
||||
def test_block_claude_code_style(self):
|
||||
r = shell_hooks._parse_response(
|
||||
"pre_tool_call",
|
||||
'{"decision": "block", "reason": "nope"}',
|
||||
)
|
||||
assert r == {"action": "block", "message": "nope"}
|
||||
|
||||
def test_block_canonical_style(self):
|
||||
r = shell_hooks._parse_response(
|
||||
"pre_tool_call",
|
||||
'{"action": "block", "message": "nope"}',
|
||||
)
|
||||
assert r == {"action": "block", "message": "nope"}
|
||||
|
||||
def test_block_canonical_wins_over_claude_style(self):
|
||||
r = shell_hooks._parse_response(
|
||||
"pre_tool_call",
|
||||
'{"action": "block", "message": "canonical", '
|
||||
'"decision": "block", "reason": "claude"}',
|
||||
)
|
||||
assert r == {"action": "block", "message": "canonical"}
|
||||
|
||||
def test_empty_stdout_returns_none(self):
|
||||
assert shell_hooks._parse_response("pre_tool_call", "") is None
|
||||
assert shell_hooks._parse_response("pre_tool_call", " ") is None
|
||||
|
||||
def test_invalid_json_returns_none(self):
|
||||
assert shell_hooks._parse_response("pre_tool_call", "not json") is None
|
||||
|
||||
def test_non_dict_json_returns_none(self):
|
||||
assert shell_hooks._parse_response("pre_tool_call", "[1, 2]") is None
|
||||
|
||||
def test_non_block_pre_tool_call_returns_none(self):
|
||||
r = shell_hooks._parse_response("pre_tool_call", '{"decision": "allow"}')
|
||||
assert r is None
|
||||
|
||||
def test_pre_llm_call_context_passthrough(self):
|
||||
r = shell_hooks._parse_response(
|
||||
"pre_llm_call", '{"context": "today is Friday"}',
|
||||
)
|
||||
assert r == {"context": "today is Friday"}
|
||||
|
||||
def test_subagent_stop_context_passthrough(self):
|
||||
r = shell_hooks._parse_response(
|
||||
"subagent_stop", '{"context": "child role=leaf"}',
|
||||
)
|
||||
assert r == {"context": "child role=leaf"}
|
||||
|
||||
def test_pre_llm_call_block_ignored(self):
|
||||
"""Only pre_tool_call honors block directives."""
|
||||
r = shell_hooks._parse_response(
|
||||
"pre_llm_call", '{"decision": "block", "reason": "no"}',
|
||||
)
|
||||
assert r is None
|
||||
|
||||
|
||||
# ── _serialize_payload ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSerializePayload:
|
||||
def test_basic_pre_tool_call_schema(self):
|
||||
raw = shell_hooks._serialize_payload(
|
||||
"pre_tool_call",
|
||||
{
|
||||
"tool_name": "terminal",
|
||||
"args": {"command": "ls"},
|
||||
"session_id": "sess-1",
|
||||
"task_id": "t-1",
|
||||
"tool_call_id": "c-1",
|
||||
},
|
||||
)
|
||||
payload = json.loads(raw)
|
||||
assert payload["hook_event_name"] == "pre_tool_call"
|
||||
assert payload["tool_name"] == "terminal"
|
||||
assert payload["tool_input"] == {"command": "ls"}
|
||||
assert payload["session_id"] == "sess-1"
|
||||
assert "cwd" in payload
|
||||
# task_id / tool_call_id end up under extra
|
||||
assert payload["extra"]["task_id"] == "t-1"
|
||||
assert payload["extra"]["tool_call_id"] == "c-1"
|
||||
|
||||
def test_args_not_dict_becomes_null(self):
|
||||
raw = shell_hooks._serialize_payload(
|
||||
"pre_tool_call", {"args": ["not", "a", "dict"]},
|
||||
)
|
||||
payload = json.loads(raw)
|
||||
assert payload["tool_input"] is None
|
||||
|
||||
def test_parent_session_id_used_when_no_session_id(self):
|
||||
raw = shell_hooks._serialize_payload(
|
||||
"subagent_stop", {"parent_session_id": "p-1"},
|
||||
)
|
||||
payload = json.loads(raw)
|
||||
assert payload["session_id"] == "p-1"
|
||||
|
||||
def test_unserialisable_extras_stringified(self):
|
||||
class Weird:
|
||||
def __repr__(self) -> str:
|
||||
return "<weird>"
|
||||
|
||||
raw = shell_hooks._serialize_payload(
|
||||
"on_session_start", {"obj": Weird()},
|
||||
)
|
||||
payload = json.loads(raw)
|
||||
assert payload["extra"]["obj"] == "<weird>"
|
||||
|
||||
|
||||
# ── Matcher behaviour ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestMatcher:
|
||||
def test_no_matcher_fires_for_any_tool(self):
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="pre_tool_call", command="echo", matcher=None,
|
||||
)
|
||||
assert spec.matches_tool("terminal")
|
||||
assert spec.matches_tool("write_file")
|
||||
|
||||
def test_single_name_matcher(self):
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="pre_tool_call", command="echo", matcher="terminal",
|
||||
)
|
||||
assert spec.matches_tool("terminal")
|
||||
assert not spec.matches_tool("web_search")
|
||||
|
||||
def test_alternation_matcher(self):
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="pre_tool_call", command="echo", matcher="terminal|file",
|
||||
)
|
||||
assert spec.matches_tool("terminal")
|
||||
assert spec.matches_tool("file")
|
||||
assert not spec.matches_tool("web")
|
||||
|
||||
def test_invalid_regex_falls_back_to_literal(self):
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="pre_tool_call", command="echo", matcher="foo[bar",
|
||||
)
|
||||
assert spec.matches_tool("foo[bar")
|
||||
assert not spec.matches_tool("foo")
|
||||
|
||||
def test_matcher_ignored_when_no_tool_name(self):
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="pre_tool_call", command="echo", matcher="terminal",
|
||||
)
|
||||
assert not spec.matches_tool(None)
|
||||
|
||||
def test_matcher_leading_whitespace_stripped(self):
|
||||
"""YAML quirks can introduce leading/trailing whitespace — must
|
||||
not silently break the matcher."""
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="pre_tool_call", command="echo", matcher=" terminal ",
|
||||
)
|
||||
assert spec.matcher == "terminal"
|
||||
assert spec.matches_tool("terminal")
|
||||
|
||||
def test_matcher_trailing_newline_stripped(self):
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="pre_tool_call", command="echo", matcher="terminal\n",
|
||||
)
|
||||
assert spec.matches_tool("terminal")
|
||||
|
||||
def test_whitespace_only_matcher_becomes_none(self):
|
||||
"""A matcher that's pure whitespace is treated as 'no matcher'."""
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="pre_tool_call", command="echo", matcher=" ",
|
||||
)
|
||||
assert spec.matcher is None
|
||||
assert spec.matches_tool("anything")
|
||||
|
||||
|
||||
# ── End-to-end subprocess behaviour ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestCallbackSubprocess:
|
||||
def test_timeout_returns_none(self, tmp_path):
|
||||
# Script that sleeps forever; we set a 1s timeout.
|
||||
script = _write_script(
|
||||
tmp_path, "slow.sh",
|
||||
"#!/usr/bin/env bash\nsleep 60\n",
|
||||
)
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="post_tool_call", command=str(script), timeout=1,
|
||||
)
|
||||
cb = shell_hooks._make_callback(spec)
|
||||
assert cb(tool_name="terminal") is None
|
||||
|
||||
def test_malformed_json_stdout_returns_none(self, tmp_path):
|
||||
script = _write_script(
|
||||
tmp_path, "bad_json.sh",
|
||||
"#!/usr/bin/env bash\necho 'not json at all'\n",
|
||||
)
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="pre_tool_call", command=str(script),
|
||||
)
|
||||
cb = shell_hooks._make_callback(spec)
|
||||
# Matcher is None so the callback fires for any tool.
|
||||
assert cb(tool_name="terminal") is None
|
||||
|
||||
def test_non_zero_exit_with_block_stdout_still_blocks(self, tmp_path):
|
||||
"""A script that signals failure via exit code AND prints a block
|
||||
directive must still block — scripts should be free to mix exit
|
||||
codes with parseable output."""
|
||||
script = _write_script(
|
||||
tmp_path, "exit1_block.sh",
|
||||
"#!/usr/bin/env bash\n"
|
||||
'printf \'{"decision": "block", "reason": "via exit 1"}\\n\'\n'
|
||||
"exit 1\n",
|
||||
)
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="pre_tool_call", command=str(script),
|
||||
)
|
||||
cb = shell_hooks._make_callback(spec)
|
||||
assert cb(tool_name="terminal") == {"action": "block", "message": "via exit 1"}
|
||||
|
||||
def test_block_translation_end_to_end(self, tmp_path):
|
||||
"""v1 schema-bug regression gate.
|
||||
|
||||
Shell hook returns the Claude-Code-style payload and the bridge
|
||||
must translate it to the canonical Hermes block shape so that
|
||||
get_pre_tool_call_block_message() surfaces the block.
|
||||
"""
|
||||
script = _write_script(
|
||||
tmp_path, "blocker.sh",
|
||||
"#!/usr/bin/env bash\n"
|
||||
'printf \'{"decision": "block", "reason": "no terminal"}\\n\'\n',
|
||||
)
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="pre_tool_call",
|
||||
command=str(script),
|
||||
matcher="terminal",
|
||||
)
|
||||
cb = shell_hooks._make_callback(spec)
|
||||
result = cb(tool_name="terminal", args={"command": "rm -rf /"})
|
||||
assert result == {"action": "block", "message": "no terminal"}
|
||||
|
||||
def test_block_aggregation_through_plugin_manager(self, tmp_path, monkeypatch):
|
||||
"""Registering via register_from_config makes
|
||||
get_pre_tool_call_block_message surface the block — the real
|
||||
end-to-end control flow used by run_agent._invoke_tool."""
|
||||
from hermes_cli import plugins
|
||||
|
||||
script = _write_script(
|
||||
tmp_path, "block.sh",
|
||||
"#!/usr/bin/env bash\n"
|
||||
'printf \'{"decision": "block", "reason": "blocked-by-shell"}\\n\'\n',
|
||||
)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))
|
||||
monkeypatch.setenv("HERMES_ACCEPT_HOOKS", "1")
|
||||
|
||||
# Fresh manager
|
||||
plugins._plugin_manager = plugins.PluginManager()
|
||||
|
||||
cfg = {
|
||||
"hooks": {
|
||||
"pre_tool_call": [
|
||||
{"matcher": "terminal", "command": str(script)},
|
||||
],
|
||||
},
|
||||
}
|
||||
registered = shell_hooks.register_from_config(cfg, accept_hooks=True)
|
||||
assert len(registered) == 1
|
||||
|
||||
msg = plugins.get_pre_tool_call_block_message(
|
||||
tool_name="terminal",
|
||||
args={"command": "rm"},
|
||||
)
|
||||
assert msg == "blocked-by-shell"
|
||||
|
||||
def test_matcher_regex_filters_callback(self, tmp_path, monkeypatch):
|
||||
"""A matcher set to 'terminal' must not fire for 'web_search'."""
|
||||
calls = tmp_path / "calls.log"
|
||||
script = _write_script(
|
||||
tmp_path, "log.sh",
|
||||
f"#!/usr/bin/env bash\n"
|
||||
f"echo \"$(cat -)\" >> {calls}\n"
|
||||
f"printf '{{}}\\n'\n",
|
||||
)
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="pre_tool_call",
|
||||
command=str(script),
|
||||
matcher="terminal",
|
||||
)
|
||||
cb = shell_hooks._make_callback(spec)
|
||||
cb(tool_name="terminal", args={"command": "ls"})
|
||||
cb(tool_name="web_search", args={"q": "x"})
|
||||
cb(tool_name="file_read", args={"path": "x"})
|
||||
assert calls.exists()
|
||||
# Only the terminal call wrote to the log
|
||||
assert calls.read_text().count("pre_tool_call") == 1
|
||||
|
||||
def test_payload_schema_delivered(self, tmp_path):
|
||||
capture = tmp_path / "payload.json"
|
||||
script = _write_script(
|
||||
tmp_path, "capture.sh",
|
||||
f"#!/usr/bin/env bash\ncat - > {capture}\nprintf '{{}}\\n'\n",
|
||||
)
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="pre_tool_call", command=str(script),
|
||||
)
|
||||
cb = shell_hooks._make_callback(spec)
|
||||
cb(
|
||||
tool_name="terminal",
|
||||
args={"command": "echo hi"},
|
||||
session_id="sess-77",
|
||||
task_id="task-77",
|
||||
)
|
||||
payload = json.loads(capture.read_text())
|
||||
assert payload["hook_event_name"] == "pre_tool_call"
|
||||
assert payload["tool_name"] == "terminal"
|
||||
assert payload["tool_input"] == {"command": "echo hi"}
|
||||
assert payload["session_id"] == "sess-77"
|
||||
assert "cwd" in payload
|
||||
assert payload["extra"]["task_id"] == "task-77"
|
||||
|
||||
def test_pre_llm_call_context_flows_through(self, tmp_path):
|
||||
script = _write_script(
|
||||
tmp_path, "ctx.sh",
|
||||
"#!/usr/bin/env bash\n"
|
||||
'printf \'{"context": "env-note"}\\n\'\n',
|
||||
)
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="pre_llm_call", command=str(script),
|
||||
)
|
||||
cb = shell_hooks._make_callback(spec)
|
||||
result = cb(
|
||||
session_id="s1", user_message="hello",
|
||||
conversation_history=[], is_first_turn=True,
|
||||
model="gpt-4", platform="cli",
|
||||
)
|
||||
assert result == {"context": "env-note"}
|
||||
|
||||
def test_shlex_handles_paths_with_spaces(self, tmp_path):
|
||||
dir_with_space = tmp_path / "path with space"
|
||||
dir_with_space.mkdir()
|
||||
script = _write_script(
|
||||
dir_with_space, "ok.sh",
|
||||
"#!/usr/bin/env bash\nprintf '{}\\n'\n",
|
||||
)
|
||||
# Quote the path so shlex keeps it as a single token.
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="post_tool_call",
|
||||
command=f'"{script}"',
|
||||
)
|
||||
cb = shell_hooks._make_callback(spec)
|
||||
# No crash = shlex parsed it correctly.
|
||||
assert cb(tool_name="terminal") is None # empty object parses to None
|
||||
|
||||
def test_missing_binary_logged_not_raised(self, tmp_path):
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="on_session_start",
|
||||
command=str(tmp_path / "does-not-exist"),
|
||||
)
|
||||
cb = shell_hooks._make_callback(spec)
|
||||
# Must not raise — agent loop should continue.
|
||||
assert cb(session_id="s") is None
|
||||
|
||||
def test_non_executable_binary_logged_not_raised(self, tmp_path):
|
||||
path = tmp_path / "no-exec"
|
||||
path.write_text("#!/usr/bin/env bash\necho hi\n")
|
||||
# Intentionally do NOT chmod +x.
|
||||
spec = shell_hooks.ShellHookSpec(
|
||||
event="on_session_start", command=str(path),
|
||||
)
|
||||
cb = shell_hooks._make_callback(spec)
|
||||
assert cb(session_id="s") is None
|
||||
|
||||
|
||||
# ── config parsing ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestParseHooksBlock:
|
||||
def test_valid_entry(self):
|
||||
specs = shell_hooks._parse_hooks_block({
|
||||
"pre_tool_call": [
|
||||
{"matcher": "terminal", "command": "/tmp/hook.sh", "timeout": 30},
|
||||
],
|
||||
})
|
||||
assert len(specs) == 1
|
||||
assert specs[0].event == "pre_tool_call"
|
||||
assert specs[0].matcher == "terminal"
|
||||
assert specs[0].command == "/tmp/hook.sh"
|
||||
assert specs[0].timeout == 30
|
||||
|
||||
def test_unknown_event_skipped(self, caplog):
|
||||
specs = shell_hooks._parse_hooks_block({
|
||||
"pre_tools_call": [ # typo
|
||||
{"command": "/tmp/hook.sh"},
|
||||
],
|
||||
})
|
||||
assert specs == []
|
||||
|
||||
def test_missing_command_skipped(self):
|
||||
specs = shell_hooks._parse_hooks_block({
|
||||
"pre_tool_call": [{"matcher": "terminal"}],
|
||||
})
|
||||
assert specs == []
|
||||
|
||||
def test_timeout_clamped_to_max(self):
|
||||
specs = shell_hooks._parse_hooks_block({
|
||||
"post_tool_call": [
|
||||
{"command": "/tmp/slow.sh", "timeout": 9999},
|
||||
],
|
||||
})
|
||||
assert specs[0].timeout == shell_hooks.MAX_TIMEOUT_SECONDS
|
||||
|
||||
def test_non_int_timeout_defaulted(self):
|
||||
specs = shell_hooks._parse_hooks_block({
|
||||
"post_tool_call": [
|
||||
{"command": "/tmp/x.sh", "timeout": "thirty"},
|
||||
],
|
||||
})
|
||||
assert specs[0].timeout == shell_hooks.DEFAULT_TIMEOUT_SECONDS
|
||||
|
||||
def test_non_list_event_skipped(self):
|
||||
specs = shell_hooks._parse_hooks_block({
|
||||
"pre_tool_call": "not a list",
|
||||
})
|
||||
assert specs == []
|
||||
|
||||
def test_none_hooks_block(self):
|
||||
assert shell_hooks._parse_hooks_block(None) == []
|
||||
assert shell_hooks._parse_hooks_block("string") == []
|
||||
assert shell_hooks._parse_hooks_block([]) == []
|
||||
|
||||
def test_non_tool_event_matcher_warns_and_drops(self, caplog):
|
||||
"""matcher: is only honored for pre/post_tool_call; must warn
|
||||
and drop on other events so the spec reflects runtime."""
|
||||
import logging
|
||||
cfg = {"pre_llm_call": [{"matcher": "terminal", "command": "/bin/echo"}]}
|
||||
with caplog.at_level(logging.WARNING, logger=shell_hooks.logger.name):
|
||||
specs = shell_hooks._parse_hooks_block(cfg)
|
||||
assert len(specs) == 1 and specs[0].matcher is None
|
||||
assert any(
|
||||
"only honored for pre_tool_call" in r.getMessage()
|
||||
and "pre_llm_call" in r.getMessage()
|
||||
for r in caplog.records
|
||||
)
|
||||
|
||||
|
||||
# ── Idempotent registration ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIdempotentRegistration:
    """register_from_config is idempotent per (event, matcher, command)."""

    def test_double_call_registers_once(self, tmp_path, monkeypatch):
        from hermes_cli import plugins

        script = _write_script(
            tmp_path, "h.sh", "#!/usr/bin/env bash\nprintf '{}\\n'\n",
        )
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))
        monkeypatch.setenv("HERMES_ACCEPT_HOOKS", "1")
        plugins._plugin_manager = plugins.PluginManager()

        cfg = {"hooks": {"on_session_start": [{"command": str(script)}]}}

        first = shell_hooks.register_from_config(cfg, accept_hooks=True)
        second = shell_hooks.register_from_config(cfg, accept_hooks=True)

        assert len(first) == 1
        assert second == []
        # The manager must hold exactly one callback for the event.
        manager = plugins.get_plugin_manager()
        assert len(manager._hooks.get("on_session_start", [])) == 1

    def test_same_command_different_matcher_registers_both(
        self, tmp_path, monkeypatch,
    ):
        """One script under two matchers of the same event registers twice:
        the dedupe key is (event, matcher, command), not just the command."""
        from hermes_cli import plugins

        script = _write_script(
            tmp_path, "h.sh", "#!/usr/bin/env bash\nprintf '{}\\n'\n",
        )
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))
        monkeypatch.setenv("HERMES_ACCEPT_HOOKS", "1")
        plugins._plugin_manager = plugins.PluginManager()

        cfg = {
            "hooks": {
                "pre_tool_call": [
                    {"matcher": "terminal", "command": str(script)},
                    {"matcher": "web_search", "command": str(script)},
                ],
            },
        }

        registered = shell_hooks.register_from_config(cfg, accept_hooks=True)
        assert len(registered) == 2
        manager = plugins.get_plugin_manager()
        assert len(manager._hooks.get("pre_tool_call", [])) == 2
|
||||
|
||||
|
||||
# ── Allowlist concurrency ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAllowlistConcurrency:
    """Regression coverage for the Codex#1 finding: concurrent
    _record_approval() calls once raced on a fixed tmp path and could
    silently lose entries via read-modify-write collisions."""

    def test_parallel_record_approval_does_not_lose_entries(
        self, tmp_path, monkeypatch,
    ):
        import threading

        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))

        worker_count = 32
        start_gate = threading.Barrier(worker_count)
        failures: list = []

        def record(idx: int) -> None:
            try:
                start_gate.wait(timeout=5)
                shell_hooks._record_approval(
                    "on_session_start", f"/bin/hook-{idx}.sh",
                )
            except Exception as exc:  # pragma: no cover
                failures.append(exc)

        workers = [
            threading.Thread(target=record, args=(i,)) for i in range(worker_count)
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        assert not failures, f"worker errors: {failures}"

        data = shell_hooks.load_allowlist()
        recorded = {entry["command"] for entry in data["approvals"]}
        expected = {f"/bin/hook-{i}.sh" for i in range(worker_count)}
        assert recorded == expected, (
            f"expected all {worker_count} entries, got {len(recorded)}"
        )

    def test_non_posix_fallback_does_not_self_deadlock(
        self, tmp_path, monkeypatch,
    ):
        """On platforms without fcntl the fallback lock must be distinct
        from _registered_lock: register_from_config holds _registered_lock
        while calling _record_approval (via the consent prompt), so a
        shared non-reentrant lock would deadlock against itself."""
        import threading

        monkeypatch.setattr(shell_hooks, "fcntl", None)
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))

        finished = threading.Event()
        failures: list = []

        def run() -> None:
            try:
                with shell_hooks._registered_lock:
                    shell_hooks._record_approval(
                        "on_session_start", "/bin/x.sh",
                    )
                finished.set()
            except Exception as exc:  # pragma: no cover
                failures.append(exc)
                finished.set()

        worker = threading.Thread(target=run, daemon=True)
        worker.start()
        if not finished.wait(timeout=3.0):
            pytest.fail(
                "non-POSIX fallback self-deadlocked — "
                "_locked_update_approvals must not reuse _registered_lock",
            )
        worker.join(timeout=1.0)
        assert not failures, f"errors: {failures}"
        assert shell_hooks._is_allowlisted(
            "on_session_start", "/bin/x.sh",
        )

    def test_save_allowlist_failure_logs_actionable_warning(
        self, tmp_path, monkeypatch, caplog,
    ):
        """A persistence failure must log the path, the errno text, and the
        re-prompt consequence so "hermes keeps asking" is debuggable."""
        import logging

        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))
        # Make the atomic-write path fail like a full disk.
        monkeypatch.setattr(
            shell_hooks.tempfile, "mkstemp",
            lambda *a, **kw: (_ for _ in ()).throw(OSError(28, "No space")),
        )
        with caplog.at_level(logging.WARNING, logger=shell_hooks.logger.name):
            shell_hooks.save_allowlist({"approvals": []})

        warning = next(
            (r.getMessage() for r in caplog.records
             if "Failed to persist" in r.getMessage()), "",
        )
        assert "shell-hooks-allowlist.json" in warning
        assert "No space" in warning
        assert "re-prompt" in warning

    def test_script_is_executable_handles_interpreter_prefix(self, tmp_path):
        """``python3 hook.py`` only needs the script readable (the
        interpreter does the reading, so R_OK suffices); a bare
        invocation of the same file still requires X_OK."""
        script = tmp_path / "hook.py"
        script.write_text("print()\n")  # readable, NOT executable

        # Interpreter-prefixed forms: R_OK is enough.
        assert shell_hooks.script_is_executable(f"python3 {script}")
        assert shell_hooks.script_is_executable(f"/usr/bin/env python3 {script}")

        # Bare invocation of the non-executable file: not runnable.
        assert not shell_hooks.script_is_executable(str(script))

        # After chmod +x the bare invocation works too.
        script.chmod(0o755)
        assert shell_hooks.script_is_executable(str(script))

    def test_command_script_path_resolution(self):
        """Regression: ``_command_script_path`` used to return the first
        shlex token, picking the interpreter (``python3``, ``bash``,
        ``/usr/bin/env``) instead of the actual script for any
        interpreter-prefixed command. That broke ``hermes hooks doctor``'s
        executability check and silently disabled mtime drift detection
        for such hooks."""
        cases = [
            # Bare paths and simple commands.
            ("/path/hook.sh", "/path/hook.sh"),
            ("/bin/echo hi", "/bin/echo"),
            ("~/hook.sh", "~/hook.sh"),
            ("hook.sh", "hook.sh"),
            # Interpreter prefixes must resolve to the script argument.
            ("python3 /path/hook.py", "/path/hook.py"),
            ("bash /path/hook.sh", "/path/hook.sh"),
            ("bash ~/hook.sh", "~/hook.sh"),
            ("python3 -u /path/hook.py", "/path/hook.py"),
            ("nice -n 10 /path/hook.sh", "/path/hook.sh"),
            # /usr/bin/env shebang form — must find the *script*, not env.
            ("/usr/bin/env python3 /path/hook.py", "/path/hook.py"),
            ("/usr/bin/env bash /path/hook.sh", "/path/hook.sh"),
            # No path-like tokens → fall back to the first token.
            ("my-binary --verbose", "my-binary"),
            ("python3 -c 'print(1)'", "python3"),
            # Unparseable (unbalanced quotes) → return the command as-is.
            ("python3 'unterminated", "python3 'unterminated"),
            # Empty input stays empty.
            ("", ""),
        ]
        for command, expected in cases:
            got = shell_hooks._command_script_path(command)
            assert got == expected, (
                f"{command!r} -> {got!r}, expected {expected!r}"
            )

    def test_save_allowlist_uses_unique_tmp_paths(self, tmp_path, monkeypatch):
        """Two in-flight save_allowlist calls must write distinct tmp files
        so the loser's os.replace cannot ENOENT on the winner's sweep."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))
        allowlist = shell_hooks.allowlist_path()
        allowlist.parent.mkdir(parents=True, exist_ok=True)

        seen: list = []
        real_mkstemp = shell_hooks.tempfile.mkstemp

        def spying_mkstemp(*args, **kwargs):
            fd, path = real_mkstemp(*args, **kwargs)
            seen.append(path)
            return fd, path

        monkeypatch.setattr(shell_hooks.tempfile, "mkstemp", spying_mkstemp)

        shell_hooks.save_allowlist({"approvals": [{"event": "a", "command": "x"}]})
        shell_hooks.save_allowlist({"approvals": [{"event": "b", "command": "y"}]})

        assert len(seen) == 2
        assert seen[0] != seen[1]
|
||||
242
tests/agent/test_shell_hooks_consent.py
Normal file
242
tests/agent/test_shell_hooks_consent.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
"""Consent-flow tests for the shell-hook allowlist.
|
||||
|
||||
Covers the prompt/non-prompt decision tree: TTY vs non-TTY, and the
|
||||
three accept-hooks channels (--accept-hooks, HERMES_ACCEPT_HOOKS env,
|
||||
hooks_auto_accept: config key).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent import shell_hooks
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _isolated_home(tmp_path, monkeypatch):
    """Isolate every test: per-test HERMES_HOME, no inherited accept-hooks
    env var, and module state reset before and after."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_home"))
    monkeypatch.delenv("HERMES_ACCEPT_HOOKS", raising=False)
    shell_hooks.reset_for_tests()
    yield
    # Teardown: clear anything the test registered.
    shell_hooks.reset_for_tests()
|
||||
|
||||
|
||||
def _write_hook_script(tmp_path: Path) -> Path:
    """Create an executable hook script that prints '{}' and return its path."""
    path = tmp_path / "hook.sh"
    path.write_text("#!/usr/bin/env bash\nprintf '{}\\n'\n")
    path.chmod(0o755)
    return path
|
||||
|
||||
|
||||
# ── TTY prompt flow ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestTTYPromptFlow:
    """Interactive consent: first use prompts; the answer is persisted."""

    def test_first_use_prompts_and_approves(self, tmp_path):
        from hermes_cli import plugins

        script = _write_hook_script(tmp_path)
        plugins._plugin_manager = plugins.PluginManager()

        with patch("sys.stdin") as fake_stdin, patch("builtins.input", return_value="y"):
            fake_stdin.isatty.return_value = True
            registered = shell_hooks.register_from_config(
                {"hooks": {"on_session_start": [{"command": str(script)}]}},
                accept_hooks=False,
            )
        assert len(registered) == 1

        entry = shell_hooks.allowlist_entry_for("on_session_start", str(script))
        assert entry is not None
        assert entry["event"] == "on_session_start"
        assert entry["command"] == str(script)

    def test_first_use_prompts_and_rejects(self, tmp_path):
        from hermes_cli import plugins

        script = _write_hook_script(tmp_path)
        plugins._plugin_manager = plugins.PluginManager()

        with patch("sys.stdin") as fake_stdin, patch("builtins.input", return_value="n"):
            fake_stdin.isatty.return_value = True
            registered = shell_hooks.register_from_config(
                {"hooks": {"on_session_start": [{"command": str(script)}]}},
                accept_hooks=False,
            )
        assert registered == []
        assert shell_hooks.allowlist_entry_for(
            "on_session_start", str(script),
        ) is None

    def test_subsequent_use_does_not_prompt(self, tmp_path):
        """Once approved, re-registration must be silent."""
        from hermes_cli import plugins

        script = _write_hook_script(tmp_path)
        plugins._plugin_manager = plugins.PluginManager()

        # First run: TTY, user approves.
        with patch("sys.stdin") as fake_stdin, patch("builtins.input", return_value="y"):
            fake_stdin.isatty.return_value = True
            shell_hooks.register_from_config(
                {"hooks": {"on_session_start": [{"command": str(script)}]}},
                accept_hooks=False,
            )

        # Clear in-memory registration; the on-disk allowlist survives.
        shell_hooks.reset_for_tests()

        # Second run: any call to input() is a failure.
        with patch("sys.stdin") as fake_stdin, patch(
            "builtins.input", side_effect=AssertionError("should not prompt"),
        ):
            fake_stdin.isatty.return_value = True
            registered = shell_hooks.register_from_config(
                {"hooks": {"on_session_start": [{"command": str(script)}]}},
                accept_hooks=False,
            )
        assert len(registered) == 1
|
||||
|
||||
|
||||
# ── non-TTY flow ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestNonTTYFlow:
    """Without a TTY, registration requires an explicit opt-in via the
    accept_hooks argument, the HERMES_ACCEPT_HOOKS env var, or the
    ``hooks_auto_accept`` config key."""

    def _register(self, script, *, accept_hooks, extra_config=None):
        """Run register_from_config for *script* with stdin mocked non-TTY."""
        config = dict(extra_config or {})
        config["hooks"] = {"on_session_start": [{"command": str(script)}]}
        with patch("sys.stdin") as fake_stdin:
            fake_stdin.isatty.return_value = False
            return shell_hooks.register_from_config(
                config, accept_hooks=accept_hooks,
            )

    def test_no_tty_no_flag_skips_registration(self, tmp_path):
        from hermes_cli import plugins

        script = _write_hook_script(tmp_path)
        plugins._plugin_manager = plugins.PluginManager()
        assert self._register(script, accept_hooks=False) == []

    def test_no_tty_with_argument_flag_accepts(self, tmp_path):
        from hermes_cli import plugins

        script = _write_hook_script(tmp_path)
        plugins._plugin_manager = plugins.PluginManager()
        assert len(self._register(script, accept_hooks=True)) == 1

    def test_no_tty_with_env_accepts(self, tmp_path, monkeypatch):
        from hermes_cli import plugins

        script = _write_hook_script(tmp_path)
        plugins._plugin_manager = plugins.PluginManager()
        monkeypatch.setenv("HERMES_ACCEPT_HOOKS", "1")
        assert len(self._register(script, accept_hooks=False)) == 1

    def test_no_tty_with_config_accepts(self, tmp_path):
        from hermes_cli import plugins

        script = _write_hook_script(tmp_path)
        plugins._plugin_manager = plugins.PluginManager()
        registered = self._register(
            script,
            accept_hooks=False,
            extra_config={"hooks_auto_accept": True},
        )
        assert len(registered) == 1
|
||||
|
||||
|
||||
# ── Allowlist + revoke + mtime ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAllowlistOps:
    """Allowlist bookkeeping: mtime capture, revocation, tilde expansion,
    and per-(event, command) dedupe on re-approval."""

    def test_mtime_recorded_on_approval(self, tmp_path):
        script = _write_hook_script(tmp_path)
        shell_hooks._record_approval("on_session_start", str(script))

        entry = shell_hooks.allowlist_entry_for(
            "on_session_start", str(script),
        )
        assert entry is not None
        assert entry["script_mtime_at_approval"] is not None
        # Timestamps are serialized as ISO-8601 with a Z (UTC) suffix.
        assert entry["script_mtime_at_approval"].endswith("Z")

    def test_revoke_removes_entry(self, tmp_path):
        script = _write_hook_script(tmp_path)
        shell_hooks._record_approval("on_session_start", str(script))
        assert shell_hooks.allowlist_entry_for(
            "on_session_start", str(script),
        ) is not None

        removed = shell_hooks.revoke(str(script))
        assert removed == 1
        assert shell_hooks.allowlist_entry_for(
            "on_session_start", str(script),
        ) is None

    def test_revoke_unknown_returns_zero(self, tmp_path):
        # Revoking a never-approved script is a no-op, not an error.
        assert shell_hooks.revoke(str(tmp_path / "never-approved.sh")) == 0

    def test_tilde_path_approval_records_resolvable_mtime(self, tmp_path, monkeypatch):
        """A ~-prefixed command must still resolve to the real file."""
        monkeypatch.setenv("HOME", str(tmp_path))
        target = tmp_path / "hook.sh"
        target.write_text("#!/usr/bin/env bash\n")
        target.chmod(0o755)

        shell_hooks._record_approval("on_session_start", "~/hook.sh")
        entry = shell_hooks.allowlist_entry_for(
            "on_session_start", "~/hook.sh",
        )
        assert entry is not None
        # A non-None mtime proves the tilde was expanded before stat().
        assert entry["script_mtime_at_approval"] is not None

    def test_duplicate_approval_replaces_mtime(self, tmp_path):
        """Re-approving the same (event, command) pair must replace the
        existing entry rather than append a duplicate."""
        script = _write_hook_script(tmp_path)
        shell_hooks._record_approval("on_session_start", str(script))
        original_entry = shell_hooks.allowlist_entry_for(
            "on_session_start", str(script),
        )
        assert original_entry is not None

        # Bump the script's mtime, then re-approve.
        # (Fix: the original read the old mtime into an unused local.)
        import os
        import time
        time.sleep(0.01)
        os.utime(script, None)  # set mtime to "now"

        shell_hooks._record_approval("on_session_start", str(script))

        # Exactly one entry may exist per (event, command).
        approvals = shell_hooks.load_allowlist().get("approvals", [])
        matching = [
            e for e in approvals
            if e.get("event") == "on_session_start"
            and e.get("command") == str(script)
        ]
        assert len(matching) == 1
|
||||
|
|
@ -405,3 +405,191 @@ class TestPlanSkillHelpers:
|
|||
assert "Add a /plan command" in msg
|
||||
assert ".hermes/plans/plan.md" in msg
|
||||
assert "Runtime note:" in msg
|
||||
|
||||
|
||||
class TestSkillDirectoryHeader:
    """The activation message must expose the absolute skill directory and
    explain how to resolve relative paths, so skills with bundled scripts
    don't force the agent into a second ``skill_view()`` round-trip."""

    def test_header_contains_absolute_skill_dir(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            skill_dir = _make_skill(tmp_path, "abs-dir-skill")
            scan_skill_commands()
            msg = build_skill_invocation_message("/abs-dir-skill", "go")

        assert msg is not None
        assert f"[Skill directory: {skill_dir}]" in msg
        assert "Resolve any relative paths" in msg

    def test_supporting_files_shown_with_absolute_paths(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            skill_dir = _make_skill(tmp_path, "scripted-skill")
            (skill_dir / "scripts").mkdir()
            (skill_dir / "scripts" / "run.js").write_text("console.log('hi')")
            scan_skill_commands()
            msg = build_skill_invocation_message("/scripted-skill")

        assert msg is not None
        # Both forms must appear: the relative path (so the agent can call
        # skill_view on it) and the absolute path (so it can run the script
        # directly via terminal).
        assert "scripts/run.js" in msg
        assert str(skill_dir / "scripts" / "run.js") in msg
        # NOTE(review): no foo.js is created by this test; this presumably
        # matches example text the message builder emits — confirm against
        # build_skill_invocation_message / _make_skill's default body.
        assert f"node {skill_dir}/scripts/foo.js" in msg
|
||||
|
||||
|
||||
class TestTemplateVarSubstitution:
    """``${HERMES_SKILL_DIR}`` and ``${HERMES_SESSION_ID}`` in a SKILL.md
    body are substituted before the agent sees the content."""

    def test_substitutes_skill_dir(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            skill_dir = _make_skill(
                tmp_path,
                "templated",
                body="Run: node ${HERMES_SKILL_DIR}/scripts/foo.js",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message("/templated")

        assert msg is not None
        assert f"node {skill_dir}/scripts/foo.js" in msg
        # The raw token must not leak into the body portion of the message.
        assert "${HERMES_SKILL_DIR}" not in msg.split("[Skill directory:")[0]

    def test_substitutes_session_id_when_available(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            _make_skill(
                tmp_path,
                "sess-templated",
                body="Session: ${HERMES_SESSION_ID}",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message(
                "/sess-templated", task_id="abc-123"
            )

        assert msg is not None
        assert "Session: abc-123" in msg

    def test_leaves_session_id_token_when_missing(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            _make_skill(
                tmp_path,
                "sess-missing",
                body="Session: ${HERMES_SESSION_ID}",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message("/sess-missing", task_id=None)

        assert msg is not None
        # With no session the token stays intact so the author can spot it.
        assert "Session: ${HERMES_SESSION_ID}" in msg

    def test_disable_template_vars_via_config(self, tmp_path):
        with (
            patch("tools.skills_tool.SKILLS_DIR", tmp_path),
            patch(
                "agent.skill_commands._load_skills_config",
                return_value={"template_vars": False},
            ),
        ):
            _make_skill(
                tmp_path,
                "no-sub",
                body="Run: node ${HERMES_SKILL_DIR}/scripts/foo.js",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message("/no-sub")

        assert msg is not None
        # Substitution disabled: the template token must survive verbatim.
        assert "${HERMES_SKILL_DIR}/scripts/foo.js" in msg
|
||||
|
||||
|
||||
class TestInlineShellExpansion:
    """Inline ``!`cmd`` snippets in SKILL.md run before the agent sees the
    content — but only when the user has opted in via config."""

    # Config that opts in to inline-shell expansion with a generous timeout.
    _ENABLED = {
        "template_vars": True,
        "inline_shell": True,
        "inline_shell_timeout": 5,
    }

    def test_inline_shell_is_off_by_default(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            _make_skill(
                tmp_path,
                "dyn-default-off",
                body="Today is !`echo INLINE_RAN`.",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message("/dyn-default-off")

        assert msg is not None
        # Default config has inline_shell=False — the snippet stays literal.
        assert "!`echo INLINE_RAN`" in msg
        assert "Today is INLINE_RAN." not in msg

    def test_inline_shell_runs_when_enabled(self, tmp_path):
        with (
            patch("tools.skills_tool.SKILLS_DIR", tmp_path),
            patch(
                "agent.skill_commands._load_skills_config",
                return_value=dict(self._ENABLED),
            ),
        ):
            _make_skill(
                tmp_path,
                "dyn-on",
                body="Marker: !`echo INLINE_RAN`.",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message("/dyn-on")

        assert msg is not None
        assert "Marker: INLINE_RAN." in msg
        assert "!`echo INLINE_RAN`" not in msg

    def test_inline_shell_runs_in_skill_directory(self, tmp_path):
        """Inline snippets get the skill dir as CWD so relative paths work."""
        with (
            patch("tools.skills_tool.SKILLS_DIR", tmp_path),
            patch(
                "agent.skill_commands._load_skills_config",
                return_value=dict(self._ENABLED),
            ),
        ):
            skill_dir = _make_skill(
                tmp_path,
                "dyn-cwd",
                body="Here: !`pwd`",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message("/dyn-cwd")

        assert msg is not None
        assert f"Here: {skill_dir}" in msg

    def test_inline_shell_timeout_does_not_break_message(self, tmp_path):
        timeout_cfg = dict(self._ENABLED, inline_shell_timeout=1)
        with (
            patch("tools.skills_tool.SKILLS_DIR", tmp_path),
            patch(
                "agent.skill_commands._load_skills_config",
                return_value=timeout_cfg,
            ),
        ):
            _make_skill(
                tmp_path,
                "dyn-slow",
                body="Slow: !`sleep 5 && printf DYN_MARKER`",
            )
            scan_skill_commands()
            msg = build_skill_invocation_message("/dyn-slow")

        assert msg is not None
        # A timeout surfaces as a marker instead of propagating as an error,
        # and the rest of the skill message still renders.
        assert "inline-shell timeout" in msg
        # Only the timeout marker (which echoes the command text) survives;
        # the command's own stdout never appears.
        assert "DYN_MARKER" not in msg.replace("sleep 5 && printf DYN_MARKER", "")
|
||||
|
|
|
|||
|
|
@ -1,61 +0,0 @@
|
|||
from agent.smart_model_routing import choose_cheap_model_route
|
||||
|
||||
|
||||
# Minimal routing config with smart routing enabled and a cheap target.
_BASE_CONFIG = {
    "enabled": True,
    "cheap_model": {
        "provider": "openrouter",
        "model": "google/gemini-2.5-flash",
    },
}


def test_returns_none_when_disabled():
    """A disabled config must short-circuit to None."""
    disabled = {**_BASE_CONFIG, "enabled": False}
    assert choose_cheap_model_route("what time is it in tokyo?", disabled) is None


def test_routes_short_simple_prompt():
    """A short conversational prompt routes to the configured cheap model."""
    route = choose_cheap_model_route("what time is it in tokyo?", _BASE_CONFIG)
    assert route is not None
    assert route["provider"] == "openrouter"
    assert route["model"] == "google/gemini-2.5-flash"
    assert route["routing_reason"] == "simple_turn"


def test_skips_long_prompt():
    # 20x repetition pushes the prompt past the simple-turn length cutoff.
    prompt = "please summarize this carefully " * 20
    assert choose_cheap_model_route(prompt, _BASE_CONFIG) is None


def test_skips_code_like_prompt():
    # Fenced code / traceback content must not be routed cheap.
    prompt = "debug this traceback: ```python\nraise ValueError('bad')\n```"
    assert choose_cheap_model_route(prompt, _BASE_CONFIG) is None


def test_skips_tool_heavy_prompt_keywords():
    # Keywords implying tool use keep the turn on the primary model.
    prompt = "implement a patch for this docker error"
    assert choose_cheap_model_route(prompt, _BASE_CONFIG) is None
|
||||
|
||||
|
||||
def test_resolve_turn_route_falls_back_to_primary_when_route_runtime_cannot_be_resolved(monkeypatch):
    """When the cheap route's runtime cannot be resolved, resolve_turn_route
    must fall back to the primary model/provider instead of raising."""
    from agent.smart_model_routing import resolve_turn_route

    # Make runtime resolution blow up for any route.
    monkeypatch.setattr(
        "hermes_cli.runtime_provider.resolve_runtime_provider",
        lambda **kwargs: (_ for _ in ()).throw(RuntimeError("bad route")),
    )

    primary = {
        "model": "anthropic/claude-sonnet-4",
        "provider": "openrouter",
        "base_url": "https://openrouter.ai/api/v1",
        "api_mode": "chat_completions",
        "api_key": "sk-primary",
    }
    result = resolve_turn_route(
        "what time is it in tokyo?", _BASE_CONFIG, primary,
    )

    assert result["model"] == "anthropic/claude-sonnet-4"
    assert result["runtime"]["provider"] == "openrouter"
    assert result["label"] is None
|
||||
|
|
@ -193,7 +193,7 @@ class TestBuildChildProgressCallback:
|
|||
|
||||
# task_index=0 in a batch of 3 → prefix "[1]"
|
||||
cb0 = _build_child_progress_callback(0, "test goal", parent, task_count=3)
|
||||
cb0("web_search", "test")
|
||||
cb0("tool.started", "web_search", "test", {})
|
||||
output = buf.getvalue()
|
||||
assert "[1]" in output
|
||||
|
||||
|
|
@ -201,7 +201,7 @@ class TestBuildChildProgressCallback:
|
|||
buf.truncate(0)
|
||||
buf.seek(0)
|
||||
cb2 = _build_child_progress_callback(2, "test goal", parent, task_count=3)
|
||||
cb2("web_search", "test")
|
||||
cb2("tool.started", "web_search", "test", {})
|
||||
output = buf.getvalue()
|
||||
assert "[3]" in output
|
||||
|
||||
|
|
|
|||
224
tests/agent/test_subagent_stop_hook.py
Normal file
224
tests/agent/test_subagent_stop_hook.py
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
"""Tests for the subagent_stop hook event.
|
||||
|
||||
Covers wire-up from tools.delegate_tool.delegate_task:
|
||||
* fires once per child in both single-task and batch modes
|
||||
* runs on the parent thread (no re-entrancy for hook authors)
|
||||
* carries child_role when the agent exposes _delegate_role
|
||||
* carries child_role=None when _delegate_role is not set (pre-M3)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.delegate_tool import delegate_task
|
||||
from hermes_cli import plugins
|
||||
|
||||
|
||||
def _make_parent(depth: int = 0, session_id: str = "parent-1"):
    """Build a MagicMock standing in for a parent agent.

    Only the attributes delegate_task reads are set explicitly; anything
    else falls back to default MagicMock attribute behavior.
    """
    parent = MagicMock()
    attrs = {
        "base_url": "https://openrouter.ai/api/v1",
        "api_key": "***",
        "provider": "openrouter",
        "api_mode": "chat_completions",
        "model": "anthropic/claude-sonnet-4",
        "platform": "cli",
        # No provider routing constraints.
        "providers_allowed": None,
        "providers_ignored": None,
        "providers_order": None,
        "provider_sort": None,
        "_session_db": None,
        "_delegate_depth": depth,
        "_active_children": [],
        "_active_children_lock": threading.Lock(),
        "_print_fn": None,
        "tool_progress_callback": None,
        "thinking_callback": None,
        "_memory_manager": None,
        "session_id": session_id,
    }
    for name, value in attrs.items():
        setattr(parent, name, value)
    return parent
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _fresh_plugin_manager():
    """Give every test its own PluginManager so hook callbacks never leak
    across tests; the original manager is restored on teardown."""
    saved = plugins._plugin_manager
    plugins._plugin_manager = plugins.PluginManager()
    yield
    plugins._plugin_manager = saved
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _stub_child_builder(monkeypatch):
    """Swap _build_child_agent for a MagicMock factory so delegate_task
    never transitively imports run_agent / openai. Keeps the suite
    runnable in environments without heavyweight runtime deps."""

    def _fake_build_child(task_index, **kwargs):
        child = MagicMock()
        child._delegate_saved_tool_names = []
        child._credential_pool = None
        return child

    monkeypatch.setattr(
        "tools.delegate_tool._build_child_agent", _fake_build_child,
    )
|
||||
|
||||
|
||||
def _register_capturing_hook():
    """Attach a subagent_stop hook that records each invocation's kwargs
    (annotated with the thread it ran on) and return the capture list."""
    captured: list = []

    def _capture(**kwargs):
        kwargs["_thread"] = threading.current_thread()
        captured.append(kwargs)

    manager = plugins.get_plugin_manager()
    manager._hooks.setdefault("subagent_stop", []).append(_capture)
    return captured
|
||||
|
||||
|
||||
# ── single-task mode ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSingleTask:
    """subagent_stop fires exactly once, on the parent thread, for a
    single delegated task."""

    @staticmethod
    def _child_result(**overrides):
        """A completed-child result dict; fields overridable per test."""
        result = {
            "task_index": 0,
            "status": "completed",
            "summary": "x",
            "api_calls": 1,
            "duration_seconds": 0.1,
            "_child_role": None,
        }
        result.update(overrides)
        return result

    def test_fires_once(self):
        captured = _register_capturing_hook()

        with patch("tools.delegate_tool._run_single_child") as run_child:
            run_child.return_value = self._child_result(
                summary="Done!",
                api_calls=3,
                duration_seconds=5.0,
                _child_role="analyst",
            )
            delegate_task(goal="do X", parent_agent=_make_parent())

        assert len(captured) == 1
        payload = captured[0]
        assert payload["child_role"] == "analyst"
        assert payload["child_status"] == "completed"
        assert payload["child_summary"] == "Done!"
        # duration_seconds=5.0 → 5000 ms.
        assert payload["duration_ms"] == 5000

    def test_fires_on_parent_thread(self):
        captured = _register_capturing_hook()
        main_thread = threading.current_thread()

        with patch("tools.delegate_tool._run_single_child") as run_child:
            run_child.return_value = self._child_result(
                summary="x", duration_seconds=0.1,
            )
            delegate_task(goal="go", parent_agent=_make_parent())

        assert captured[0]["_thread"] is main_thread

    def test_payload_includes_parent_session_id(self):
        captured = _register_capturing_hook()

        with patch("tools.delegate_tool._run_single_child") as run_child:
            run_child.return_value = self._child_result(
                summary="x", duration_seconds=0.1,
            )
            delegate_task(
                goal="go",
                parent_agent=_make_parent(session_id="sess-xyz"),
            )

        assert captured[0]["parent_session_id"] == "sess-xyz"
|
||||
|
||||
|
||||
# ── batch mode ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBatchMode:
    """subagent_stop fires once per child in batch mode, always on the
    parent thread."""

    @staticmethod
    def _batch_results(roles):
        """Completed-child result dicts for task indices 0..len(roles)-1:
        summaries "A", "B", …; api_calls/duration scale with the index."""
        return [
            {
                "task_index": i,
                "status": "completed",
                "summary": chr(ord("A") + i),
                "api_calls": i + 1,
                "duration_seconds": float(i + 1),
                "_child_role": role,
            }
            for i, role in enumerate(roles)
        ]

    def test_fires_per_child(self):
        captured = _register_capturing_hook()

        with patch("tools.delegate_tool._run_single_child") as run_child:
            run_child.side_effect = self._batch_results(
                ["role-a", "role-b", "role-c"],
            )
            delegate_task(
                tasks=[{"goal": "A"}, {"goal": "B"}, {"goal": "C"}],
                parent_agent=_make_parent(),
            )

        assert len(captured) == 3
        roles = sorted(c["child_role"] for c in captured)
        assert roles == ["role-a", "role-b", "role-c"]

    def test_all_fires_on_parent_thread(self):
        captured = _register_capturing_hook()
        main_thread = threading.current_thread()

        with patch("tools.delegate_tool._run_single_child") as run_child:
            run_child.side_effect = self._batch_results([None, None])
            delegate_task(
                tasks=[{"goal": "A"}, {"goal": "B"}],
                parent_agent=_make_parent(),
            )

        for payload in captured:
            assert payload["_thread"] is main_thread
|
||||
|
||||
|
||||
# ── payload shape ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPayloadShape:
    """Shape of the hook payload and of the JSON result returned to the parent."""

    def test_role_absent_becomes_none(self):
        """A child result without _child_role (pre-M3 shape) yields
        child_role=None in the hook payload rather than a KeyError."""
        captured = _register_capturing_hook()

        with patch("tools.delegate_tool._run_single_child") as mock_run:
            mock_run.return_value = {
                "task_index": 0, "status": "completed",
                "summary": "x", "api_calls": 1, "duration_seconds": 0.1,
                # Deliberately omit _child_role — pre-M3 shape.
            }
            delegate_task(goal="do X", parent_agent=_make_parent())

        assert captured[0]["child_role"] is None

    def test_result_does_not_leak_child_role_field(self):
        """The internal _child_role key must be stripped before the
        result dict is serialised to JSON."""
        _register_capturing_hook()

        with patch("tools.delegate_tool._run_single_child") as mock_run:
            mock_run.return_value = {
                "task_index": 0, "status": "completed",
                "summary": "x", "api_calls": 1, "duration_seconds": 0.1,
                "_child_role": "leaf",
            }
            # delegate_task returns a JSON string for the parent agent.
            raw = delegate_task(goal="do X", parent_agent=_make_parent())

        parsed = json.loads(raw)
        assert "results" in parsed
        assert "_child_role" not in parsed["results"][0]
|
||||
0
tests/agent/transports/__init__.py
Normal file
0
tests/agent/transports/__init__.py
Normal file
164
tests/agent/transports/test_bedrock_transport.py
Normal file
164
tests/agent/transports/test_bedrock_transport.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
"""Tests for the BedrockTransport."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from agent.transports import get_transport
|
||||
from agent.transports.types import NormalizedResponse, ToolCall
|
||||
|
||||
|
||||
@pytest.fixture
def transport():
    """Return the Bedrock converse transport from the registry.

    Importing the module is what registers the transport (side effect on
    import), hence the otherwise-unused import.
    """
    import agent.transports.bedrock  # noqa: F401
    return get_transport("bedrock_converse")
|
||||
|
||||
|
||||
class TestBedrockBasic:
    """Smoke checks: the Bedrock transport registers and reports its mode."""

    def test_registered(self, transport):
        """The registry resolved a transport instance for bedrock_converse."""
        assert transport is not None

    def test_api_mode(self, transport):
        """The transport identifies itself with the expected API-mode string."""
        mode = transport.api_mode
        assert mode == "bedrock_converse"
|
||||
|
||||
|
||||
class TestBedrockBuildKwargs:
    """build_kwargs produces boto3 Converse-call kwargs plus internal markers."""

    def test_basic_kwargs(self, transport):
        """Model id, converse marker, default region, and messages are present."""
        msgs = [{"role": "user", "content": "Hello"}]
        kw = transport.build_kwargs(model="anthropic.claude-3-5-sonnet-20241022-v2:0", messages=msgs)
        assert kw["modelId"] == "anthropic.claude-3-5-sonnet-20241022-v2:0"
        # Dunder keys are internal routing markers consumed by the dispatch site.
        assert kw["__bedrock_converse__"] is True
        assert kw["__bedrock_region__"] == "us-east-1"
        assert "messages" in kw

    def test_custom_region(self, transport):
        """An explicit region overrides the us-east-1 default."""
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="anthropic.claude-3-5-sonnet-20241022-v2:0",
            messages=msgs,
            region="eu-west-1",
        )
        assert kw["__bedrock_region__"] == "eu-west-1"

    def test_max_tokens(self, transport):
        """max_tokens maps to Bedrock's inferenceConfig.maxTokens."""
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="anthropic.claude-3-5-sonnet-20241022-v2:0",
            messages=msgs,
            max_tokens=8192,
        )
        assert kw["inferenceConfig"]["maxTokens"] == 8192
|
||||
|
||||
|
||||
class TestBedrockConvertTools:
    """OpenAI-style tool schemas convert to Bedrock toolSpec entries."""

    def test_convert_tools(self, transport):
        """A function tool becomes one toolSpec with the same name."""
        tools = [{
            "type": "function",
            "function": {
                "name": "terminal",
                "description": "Run commands",
                "parameters": {"type": "object", "properties": {"command": {"type": "string"}}},
            }
        }]
        result = transport.convert_tools(tools)
        assert len(result) == 1
        assert result[0]["toolSpec"]["name"] == "terminal"
|
||||
|
||||
|
||||
class TestBedrockValidate:
    """validate_response accepts both raw Bedrock dicts and normalized objects."""

    def test_none(self, transport):
        """None is never a valid response."""
        assert not transport.validate_response(None)

    def test_raw_dict_valid(self, transport):
        """A raw converse dict with an output.message is valid."""
        raw = {"output": {"message": {}}}
        assert transport.validate_response(raw) is True

    def test_raw_dict_invalid(self, transport):
        """A dict without the output envelope (e.g. an error) is invalid."""
        raw = {"error": "fail"}
        assert transport.validate_response(raw) is False

    def test_normalized_valid(self, transport):
        """A pre-normalized choices-style object is also accepted."""
        msg = SimpleNamespace(content="hi")
        resp = SimpleNamespace(choices=[SimpleNamespace(message=msg)])
        assert transport.validate_response(resp) is True
|
||||
|
||||
|
||||
class TestBedrockMapFinishReason:
    """Bedrock stopReason values map onto OpenAI-style finish reasons."""

    def _expect(self, transport, stop_reason, finish_reason):
        # Shared helper: assert a single stopReason -> finish_reason mapping.
        assert transport.map_finish_reason(stop_reason) == finish_reason

    def test_end_turn(self, transport):
        self._expect(transport, "end_turn", "stop")

    def test_tool_use(self, transport):
        self._expect(transport, "tool_use", "tool_calls")

    def test_max_tokens(self, transport):
        self._expect(transport, "max_tokens", "length")

    def test_guardrail(self, transport):
        self._expect(transport, "guardrail_intervened", "content_filter")

    def test_unknown(self, transport):
        # Unrecognised reasons fall back to "stop".
        self._expect(transport, "unknown", "stop")
|
||||
|
||||
|
||||
class TestBedrockNormalize:
    """normalize_response turns raw converse dicts (or already-normalized
    objects) into NormalizedResponse."""

    def _make_bedrock_response(self, text="Hello", tool_calls=None, stop_reason="end_turn"):
        """Build a raw Bedrock converse response dict."""
        content = []
        if text:
            content.append({"text": text})
        if tool_calls:
            # Each entry: {"id", "name", "input"} -> Bedrock toolUse block.
            for tc in tool_calls:
                content.append({
                    "toolUse": {
                        "toolUseId": tc["id"],
                        "name": tc["name"],
                        "input": tc["input"],
                    }
                })
        return {
            "output": {"message": {"role": "assistant", "content": content}},
            "stopReason": stop_reason,
            "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15},
        }

    def test_text_response(self, transport):
        """Plain text content normalizes with finish_reason 'stop'."""
        raw = self._make_bedrock_response(text="Hello world")
        nr = transport.normalize_response(raw)
        assert isinstance(nr, NormalizedResponse)
        assert nr.content == "Hello world"
        assert nr.finish_reason == "stop"

    def test_tool_call_response(self, transport):
        """toolUse blocks become tool_calls with finish_reason 'tool_calls'."""
        raw = self._make_bedrock_response(
            text=None,
            tool_calls=[{"id": "tool_1", "name": "terminal", "input": {"command": "ls"}}],
            stop_reason="tool_use",
        )
        nr = transport.normalize_response(raw)
        assert nr.finish_reason == "tool_calls"
        assert len(nr.tool_calls) == 1
        assert nr.tool_calls[0].name == "terminal"

    def test_already_normalized_response(self, transport):
        """Test normalize_response handles already-normalized SimpleNamespace (from dispatch site)."""
        pre_normalized = SimpleNamespace(
            choices=[SimpleNamespace(
                message=SimpleNamespace(
                    content="Hello from Bedrock",
                    tool_calls=None,
                    reasoning=None,
                    reasoning_content=None,
                ),
                finish_reason="stop",
            )],
            usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        )
        nr = transport.normalize_response(pre_normalized)
        assert isinstance(nr, NormalizedResponse)
        assert nr.content == "Hello from Bedrock"
        assert nr.finish_reason == "stop"
        # Usage must survive the passthrough normalization.
        assert nr.usage is not None
        assert nr.usage.prompt_tokens == 10
|
||||
349
tests/agent/transports/test_chat_completions.py
Normal file
349
tests/agent/transports/test_chat_completions.py
Normal file
|
|
@ -0,0 +1,349 @@
|
|||
"""Tests for the ChatCompletionsTransport."""
|
||||
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from agent.transports import get_transport
|
||||
from agent.transports.types import NormalizedResponse, ToolCall
|
||||
|
||||
|
||||
@pytest.fixture
def transport():
    """Return the chat-completions transport from the registry.

    The import registers the transport as a side effect.
    """
    import agent.transports.chat_completions  # noqa: F401
    return get_transport("chat_completions")
|
||||
|
||||
|
||||
class TestChatCompletionsBasic:
    """Registration, identity conversions, and Codex-field stripping."""

    def test_api_mode(self, transport):
        assert transport.api_mode == "chat_completions"

    def test_registered(self, transport):
        assert transport is not None

    def test_convert_tools_identity(self, transport):
        """Chat-completions tools need no conversion — same object returned."""
        tools = [{"type": "function", "function": {"name": "test", "parameters": {}}}]
        assert transport.convert_tools(tools) is tools

    def test_convert_messages_no_codex_leaks(self, transport):
        """Messages without Codex-only fields pass through uncopied."""
        msgs = [{"role": "user", "content": "hi"}]
        result = transport.convert_messages(msgs)
        assert result is msgs  # no copy needed

    def test_convert_messages_strips_codex_fields(self, transport):
        """Responses-API-only fields (codex_reasoning_items, call_id,
        response_item_id) are stripped from a copy; the input is untouched."""
        msgs = [
            {"role": "assistant", "content": "ok", "codex_reasoning_items": [{"id": "rs_1"}],
             "tool_calls": [{"id": "call_1", "call_id": "call_1", "response_item_id": "fc_1",
                             "type": "function", "function": {"name": "t", "arguments": "{}"}}]},
        ]
        result = transport.convert_messages(msgs)
        assert "codex_reasoning_items" not in result[0]
        assert "call_id" not in result[0]["tool_calls"][0]
        assert "response_item_id" not in result[0]["tool_calls"][0]
        # Original list untouched (deepcopy-on-demand)
        assert "codex_reasoning_items" in msgs[0]
|
||||
|
||||
|
||||
class TestChatCompletionsBuildKwargs:
    """build_kwargs: role swaps, provider-specific extra_body, reasoning,
    max-token defaults, and override precedence."""

    def test_basic_kwargs(self, transport):
        msgs = [{"role": "user", "content": "Hello"}]
        kw = transport.build_kwargs(model="gpt-4o", messages=msgs, timeout=30.0)
        assert kw["model"] == "gpt-4o"
        assert kw["messages"][0]["content"] == "Hello"
        assert kw["timeout"] == 30.0

    def test_developer_role_swap(self, transport):
        """GPT-5-family models use the 'developer' role for system prompts."""
        msgs = [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(model="gpt-5.4", messages=msgs, model_lower="gpt-5.4")
        assert kw["messages"][0]["role"] == "developer"

    def test_no_developer_swap_for_non_gpt5(self, transport):
        """Non-GPT-5 models keep the plain 'system' role."""
        msgs = [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(model="claude-sonnet-4", messages=msgs, model_lower="claude-sonnet-4")
        assert kw["messages"][0]["role"] == "system"

    def test_tools_included(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        tools = [{"type": "function", "function": {"name": "test", "parameters": {}}}]
        kw = transport.build_kwargs(model="gpt-4o", messages=msgs, tools=tools)
        assert kw["tools"] == tools

    def test_openrouter_provider_prefs(self, transport):
        """OpenRouter provider preferences go into extra_body.provider."""
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-4o", messages=msgs,
            is_openrouter=True,
            provider_preferences={"only": ["openai"]},
        )
        assert kw["extra_body"]["provider"] == {"only": ["openai"]}

    def test_nous_tags(self, transport):
        """Nous requests are tagged with the product identifier."""
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(model="gpt-4o", messages=msgs, is_nous=True)
        assert kw["extra_body"]["tags"] == ["product=hermes-agent"]

    def test_reasoning_default(self, transport):
        """Reasoning-capable models default to enabled/medium effort."""
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-4o", messages=msgs,
            supports_reasoning=True,
        )
        assert kw["extra_body"]["reasoning"] == {"enabled": True, "effort": "medium"}

    def test_nous_omits_disabled_reasoning(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-4o", messages=msgs,
            supports_reasoning=True,
            is_nous=True,
            reasoning_config={"enabled": False},
        )
        # Nous rejects enabled=false; reasoning omitted entirely
        assert "reasoning" not in kw.get("extra_body", {})

    def test_ollama_num_ctx(self, transport):
        """Ollama context size goes through extra_body.options.num_ctx."""
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="llama3", messages=msgs,
            ollama_num_ctx=32768,
        )
        assert kw["extra_body"]["options"]["num_ctx"] == 32768

    def test_custom_think_false(self, transport):
        """Custom providers with effort 'none' disable thinking via think=False."""
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="qwen3", messages=msgs,
            is_custom_provider=True,
            reasoning_config={"effort": "none"},
        )
        assert kw["extra_body"]["think"] is False

    def test_max_tokens_with_fn(self, transport):
        """max_tokens is rendered through the provider-specific param fn."""
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-4o", messages=msgs,
            max_tokens=4096,
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        assert kw["max_tokens"] == 4096

    def test_ephemeral_overrides_max_tokens(self, transport):
        """A per-request ephemeral cap beats the configured max_tokens."""
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-4o", messages=msgs,
            max_tokens=4096,
            ephemeral_max_output_tokens=2048,
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        assert kw["max_tokens"] == 2048

    def test_nvidia_default_max_tokens(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="glm-4.7", messages=msgs,
            is_nvidia_nim=True,
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        # NVIDIA default: 16384
        assert kw["max_tokens"] == 16384

    def test_qwen_default_max_tokens(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="qwen3-coder-plus", messages=msgs,
            is_qwen_portal=True,
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        # Qwen default: 65536
        assert kw["max_tokens"] == 65536

    def test_anthropic_max_output_for_claude_on_aggregator(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="anthropic/claude-sonnet-4.6", messages=msgs,
            is_openrouter=True,
            anthropic_max_output=64000,
        )
        # Set as plain max_tokens (not via fn) because the aggregator proxies to
        # Anthropic Messages API which requires the field.
        assert kw["max_tokens"] == 64000

    def test_request_overrides_last(self, transport):
        """request_overrides are merged last and win over computed kwargs."""
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-4o", messages=msgs,
            request_overrides={"service_tier": "priority"},
        )
        assert kw["service_tier"] == "priority"

    def test_fixed_temperature(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(model="gpt-4o", messages=msgs, fixed_temperature=0.6)
        assert kw["temperature"] == 0.6

    def test_omit_temperature(self, transport):
        msgs = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(model="gpt-4o", messages=msgs, omit_temperature=True, fixed_temperature=0.5)
        # omit wins
        assert "temperature" not in kw
|
||||
|
||||
|
||||
class TestChatCompletionsKimi:
    """Regression tests for the Kimi/Moonshot quirks migrated into the transport."""

    def test_kimi_max_tokens_default(self, transport):
        """Kimi defaults to 32000 max tokens when none is configured."""
        kw = transport.build_kwargs(
            model="kimi-k2", messages=[{"role": "user", "content": "Hi"}],
            is_kimi=True,
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        # Kimi CLI default: 32000
        assert kw["max_tokens"] == 32000

    def test_kimi_reasoning_effort_top_level(self, transport):
        kw = transport.build_kwargs(
            model="kimi-k2", messages=[{"role": "user", "content": "Hi"}],
            is_kimi=True,
            reasoning_config={"effort": "high"},
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        # Kimi requires reasoning_effort as a top-level parameter
        assert kw["reasoning_effort"] == "high"

    def test_kimi_reasoning_effort_omitted_when_thinking_disabled(self, transport):
        kw = transport.build_kwargs(
            model="kimi-k2", messages=[{"role": "user", "content": "Hi"}],
            is_kimi=True,
            reasoning_config={"enabled": False},
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        # Mirror Kimi CLI: omit reasoning_effort entirely when thinking off
        assert "reasoning_effort" not in kw

    def test_kimi_thinking_enabled_extra_body(self, transport):
        """Thinking defaults on -> extra_body.thinking type 'enabled'."""
        kw = transport.build_kwargs(
            model="kimi-k2", messages=[{"role": "user", "content": "Hi"}],
            is_kimi=True,
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        assert kw["extra_body"]["thinking"] == {"type": "enabled"}

    def test_kimi_thinking_disabled_extra_body(self, transport):
        """Disabling reasoning flips extra_body.thinking to 'disabled'."""
        kw = transport.build_kwargs(
            model="kimi-k2", messages=[{"role": "user", "content": "Hi"}],
            is_kimi=True,
            reasoning_config={"enabled": False},
            max_tokens_param_fn=lambda n: {"max_tokens": n},
        )
        assert kw["extra_body"]["thinking"] == {"type": "disabled"}
|
||||
|
||||
|
||||
class TestChatCompletionsValidate:
    """validate_response requires a non-empty choices list."""

    def test_none(self, transport):
        """None is rejected outright."""
        assert transport.validate_response(None) is False

    def test_no_choices(self, transport):
        """choices=None is not a usable response."""
        resp = SimpleNamespace(choices=None)
        assert not transport.validate_response(resp)

    def test_empty_choices(self, transport):
        """An empty choices list is rejected."""
        resp = SimpleNamespace(choices=[])
        assert not transport.validate_response(resp)

    def test_valid(self, transport):
        """One choice with a message makes the response valid."""
        message = SimpleNamespace(content="hi")
        resp = SimpleNamespace(choices=[SimpleNamespace(message=message)])
        assert transport.validate_response(resp) is True
|
||||
|
||||
|
||||
class TestChatCompletionsNormalize:
    """normalize_response: text, tool calls, provider extras, and reasoning."""

    def test_text_response(self, transport):
        r = SimpleNamespace(
            choices=[SimpleNamespace(
                message=SimpleNamespace(content="Hello", tool_calls=None, reasoning_content=None),
                finish_reason="stop",
            )],
            usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        )
        nr = transport.normalize_response(r)
        assert isinstance(nr, NormalizedResponse)
        assert nr.content == "Hello"
        assert nr.finish_reason == "stop"
        assert nr.tool_calls is None

    def test_tool_call_response(self, transport):
        """Tool calls carry through with id, name, and JSON arguments."""
        tc = SimpleNamespace(
            id="call_123",
            function=SimpleNamespace(name="terminal", arguments='{"command": "ls"}'),
        )
        r = SimpleNamespace(
            choices=[SimpleNamespace(
                message=SimpleNamespace(content=None, tool_calls=[tc], reasoning_content=None),
                finish_reason="tool_calls",
            )],
            usage=SimpleNamespace(prompt_tokens=10, completion_tokens=20, total_tokens=30),
        )
        nr = transport.normalize_response(r)
        assert len(nr.tool_calls) == 1
        assert nr.tool_calls[0].name == "terminal"
        assert nr.tool_calls[0].id == "call_123"

    def test_tool_call_extra_content_preserved(self, transport):
        """Gemini 3 thinking models attach extra_content with thought_signature
        on tool_calls. Without this replay on the next turn, the API rejects
        the request with 400. The transport MUST surface extra_content so the
        agent loop can write it back into the assistant message."""
        tc = SimpleNamespace(
            id="call_gem",
            function=SimpleNamespace(name="terminal", arguments='{"command": "ls"}'),
            extra_content={"google": {"thought_signature": "SIG_ABC123"}},
        )
        r = SimpleNamespace(
            choices=[SimpleNamespace(
                message=SimpleNamespace(content=None, tool_calls=[tc], reasoning_content=None),
                finish_reason="tool_calls",
            )],
            usage=None,
        )
        nr = transport.normalize_response(r)
        assert nr.tool_calls[0].provider_data == {
            "extra_content": {"google": {"thought_signature": "SIG_ABC123"}}
        }

    def test_reasoning_content_preserved_separately(self, transport):
        """DeepSeek/Moonshot use reasoning_content distinct from reasoning.
        Don't merge them — the thinking-prefill retry check reads each field
        separately."""
        r = SimpleNamespace(
            choices=[SimpleNamespace(
                message=SimpleNamespace(
                    content=None, tool_calls=None,
                    reasoning="summary text",
                    reasoning_content="detailed scratchpad",
                ),
                finish_reason="stop",
            )],
            usage=None,
        )
        nr = transport.normalize_response(r)
        assert nr.reasoning == "summary text"
        assert nr.provider_data == {"reasoning_content": "detailed scratchpad"}
|
||||
|
||||
|
||||
class TestChatCompletionsCacheStats:
    """extract_cache_stats reads prompt_tokens_details (or returns None)."""

    def test_no_usage(self, transport):
        r = SimpleNamespace(usage=None)
        assert transport.extract_cache_stats(r) is None

    def test_no_details(self, transport):
        r = SimpleNamespace(usage=SimpleNamespace(prompt_tokens_details=None))
        assert transport.extract_cache_stats(r) is None

    def test_with_cache(self, transport):
        """cached/cache_write token counts map to cached/creation tokens."""
        details = SimpleNamespace(cached_tokens=500, cache_write_tokens=100)
        r = SimpleNamespace(usage=SimpleNamespace(prompt_tokens_details=details))
        result = transport.extract_cache_stats(r)
        assert result == {"cached_tokens": 500, "creation_tokens": 100}
|
||||
220
tests/agent/transports/test_codex_transport.py
Normal file
220
tests/agent/transports/test_codex_transport.py
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
"""Tests for the ResponsesApiTransport (Codex)."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from agent.transports import get_transport
|
||||
from agent.transports.types import NormalizedResponse, ToolCall
|
||||
|
||||
|
||||
@pytest.fixture
def transport():
    """Return the Codex (Responses API) transport from the registry.

    The import registers the transport as a side effect.
    """
    import agent.transports.codex  # noqa: F401
    return get_transport("codex_responses")
|
||||
|
||||
|
||||
class TestCodexTransportBasic:
    """Registration and tool-schema conversion for the Responses API."""

    def test_api_mode(self, transport):
        assert transport.api_mode == "codex_responses"

    def test_registered_on_import(self, transport):
        assert transport is not None

    def test_convert_tools(self, transport):
        """Chat-style nested tool schema flattens to the Responses API shape
        (name/description at the top level)."""
        tools = [{
            "type": "function",
            "function": {
                "name": "terminal",
                "description": "Run a command",
                "parameters": {"type": "object", "properties": {"command": {"type": "string"}}},
            }
        }]
        result = transport.convert_tools(tools)
        assert len(result) == 1
        assert result[0]["type"] == "function"
        assert result[0]["name"] == "terminal"
|
||||
|
||||
|
||||
class TestCodexBuildKwargs:
    """build_kwargs: instructions extraction, reasoning, caching keys, and
    backend-specific exclusions."""

    def test_basic_kwargs(self, transport):
        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello"},
        ]
        kw = transport.build_kwargs(
            model="gpt-5.4",
            messages=messages,
            tools=[],
        )
        assert kw["model"] == "gpt-5.4"
        # System message is lifted into the top-level `instructions` field.
        assert kw["instructions"] == "You are helpful."
        assert "input" in kw
        # store=False: don't persist conversations server-side.
        assert kw["store"] is False

    def test_system_extracted_from_messages(self, transport):
        messages = [
            {"role": "system", "content": "Custom system prompt"},
            {"role": "user", "content": "Hi"},
        ]
        kw = transport.build_kwargs(model="gpt-5.4", messages=messages, tools=[])
        assert kw["instructions"] == "Custom system prompt"

    def test_no_system_uses_default(self, transport):
        """Without a system message, a non-empty default instruction is used."""
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(model="gpt-5.4", messages=messages, tools=[])
        assert kw["instructions"]  # should be non-empty default

    def test_reasoning_config(self, transport):
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-5.4", messages=messages, tools=[],
            reasoning_config={"effort": "high"},
        )
        assert kw.get("reasoning", {}).get("effort") == "high"

    def test_reasoning_disabled(self, transport):
        """Disabled reasoning: either no reasoning block or empty include."""
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-5.4", messages=messages, tools=[],
            reasoning_config={"enabled": False},
        )
        assert "reasoning" not in kw or kw.get("include") == []

    def test_session_id_sets_cache_key(self, transport):
        """session_id doubles as the prompt cache key for server-side caching."""
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-5.4", messages=messages, tools=[],
            session_id="test-session-123",
        )
        assert kw.get("prompt_cache_key") == "test-session-123"

    def test_github_responses_no_cache_key(self, transport):
        """GitHub's Responses endpoint doesn't accept prompt_cache_key."""
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-5.4", messages=messages, tools=[],
            session_id="test-session",
            is_github_responses=True,
        )
        assert "prompt_cache_key" not in kw

    def test_max_tokens(self, transport):
        """max_tokens maps to the Responses API's max_output_tokens."""
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-5.4", messages=messages, tools=[],
            max_tokens=4096,
        )
        assert kw.get("max_output_tokens") == 4096

    def test_codex_backend_no_max_output_tokens(self, transport):
        """The Codex backend rejects max_output_tokens; it must be dropped."""
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-5.4", messages=messages, tools=[],
            max_tokens=4096,
            is_codex_backend=True,
        )
        assert "max_output_tokens" not in kw

    def test_xai_headers(self, transport):
        """xAI Responses requests carry the conversation id as a header."""
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="grok-3", messages=messages, tools=[],
            session_id="conv-123",
            is_xai_responses=True,
        )
        assert kw.get("extra_headers", {}).get("x-grok-conv-id") == "conv-123"

    def test_minimal_effort_clamped(self, transport):
        messages = [{"role": "user", "content": "Hi"}]
        kw = transport.build_kwargs(
            model="gpt-5.4", messages=messages, tools=[],
            reasoning_config={"effort": "minimal"},
        )
        # "minimal" should be clamped to "low"
        assert kw.get("reasoning", {}).get("effort") == "low"
|
||||
|
||||
|
||||
class TestCodexValidateResponse:
    """validate_response: strict — requires a non-empty output list."""

    def test_none_response(self, transport):
        assert transport.validate_response(None) is False

    def test_empty_output(self, transport):
        r = SimpleNamespace(output=[], output_text=None)
        assert transport.validate_response(r) is False

    def test_valid_output(self, transport):
        r = SimpleNamespace(output=[{"type": "message", "content": []}])
        assert transport.validate_response(r) is True

    def test_output_text_fallback_not_valid(self, transport):
        """validate_response is strict — output_text doesn't make it valid.
        The caller handles output_text fallback with diagnostic logging."""
        r = SimpleNamespace(output=None, output_text="Some text")
        assert transport.validate_response(r) is False
|
||||
|
||||
|
||||
class TestCodexMapFinishReason:
    """Responses-API status strings map onto OpenAI-style finish reasons."""

    def _expect(self, transport, status, finish_reason):
        # Shared helper: assert a single status -> finish_reason mapping.
        assert transport.map_finish_reason(status) == finish_reason

    def test_completed(self, transport):
        self._expect(transport, "completed", "stop")

    def test_incomplete(self, transport):
        self._expect(transport, "incomplete", "length")

    def test_failed(self, transport):
        self._expect(transport, "failed", "stop")

    def test_unknown(self, transport):
        # Unrecognised statuses fall back to "stop".
        self._expect(transport, "unknown_status", "stop")
|
||||
|
||||
|
||||
class TestCodexNormalizeResponse:
    """normalize_response over Responses-API output items."""

    def test_text_response(self, transport):
        """Normalize a simple text Codex response."""
        r = SimpleNamespace(
            output=[
                SimpleNamespace(
                    type="message",
                    role="assistant",
                    content=[SimpleNamespace(type="output_text", text="Hello world")],
                    status="completed",
                ),
            ],
            status="completed",
            incomplete_details=None,
            usage=SimpleNamespace(input_tokens=10, output_tokens=5,
                                  input_tokens_details=None, output_tokens_details=None),
        )
        nr = transport.normalize_response(r)
        assert isinstance(nr, NormalizedResponse)
        assert nr.content == "Hello world"
        assert nr.finish_reason == "stop"

    def test_tool_call_response(self, transport):
        """Normalize a Codex response with tool calls."""
        r = SimpleNamespace(
            output=[
                SimpleNamespace(
                    type="function_call",
                    call_id="call_abc123",
                    name="terminal",
                    arguments=json.dumps({"command": "ls"}),
                    id="fc_abc123",
                    status="completed",
                ),
            ],
            status="completed",
            incomplete_details=None,
            usage=SimpleNamespace(input_tokens=10, output_tokens=20,
                                  input_tokens_details=None, output_tokens_details=None),
        )
        nr = transport.normalize_response(r)
        # function_call items force finish_reason to "tool_calls".
        assert nr.finish_reason == "tool_calls"
        assert len(nr.tool_calls) == 1
        tc = nr.tool_calls[0]
        assert tc.name == "terminal"
        assert '"command"' in tc.arguments
|
||||
220
tests/agent/transports/test_transport.py
Normal file
220
tests/agent/transports/test_transport.py
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
"""Tests for the transport ABC, registry, and AnthropicTransport."""
|
||||
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from agent.transports.base import ProviderTransport
|
||||
from agent.transports.types import NormalizedResponse, ToolCall, Usage
|
||||
from agent.transports import get_transport, register_transport, _REGISTRY
|
||||
|
||||
|
||||
# ── ABC contract tests ──────────────────────────────────────────────────
|
||||
|
||||
class TestProviderTransportABC:
    """Verify the ABC contract is enforceable."""

    def test_cannot_instantiate_abc(self):
        """The abstract base itself cannot be instantiated."""
        with pytest.raises(TypeError):
            ProviderTransport()

    def test_concrete_must_implement_all_abstract(self):
        """A subclass missing abstract methods still cannot be instantiated."""
        class Incomplete(ProviderTransport):
            @property
            def api_mode(self):
                return "test"
        with pytest.raises(TypeError):
            Incomplete()

    def test_minimal_concrete(self):
        """Implementing the four abstract members yields a usable transport,
        and the optional hooks have sensible defaults."""
        class Minimal(ProviderTransport):
            @property
            def api_mode(self):
                return "test_minimal"
            def convert_messages(self, messages, **kw):
                return messages
            def convert_tools(self, tools):
                return tools
            def build_kwargs(self, model, messages, tools=None, **params):
                return {"model": model, "messages": messages}
            def normalize_response(self, response, **kw):
                return NormalizedResponse(content="ok", tool_calls=None, finish_reason="stop")

        t = Minimal()
        assert t.api_mode == "test_minimal"
        assert t.validate_response(None) is True  # default
        assert t.extract_cache_stats(None) is None  # default
        assert t.map_finish_reason("end_turn") == "end_turn"  # default passthrough
|
||||
|
||||
|
||||
# ── Registry tests ───────────────────────────────────────────────────────
|
||||
|
||||
class TestTransportRegistry:
|
||||
|
||||
def test_get_unregistered_returns_none(self):
|
||||
assert get_transport("nonexistent_mode") is None
|
||||
|
||||
def test_anthropic_registered_on_import(self):
|
||||
import agent.transports.anthropic # noqa: F401
|
||||
t = get_transport("anthropic_messages")
|
||||
assert t is not None
|
||||
assert t.api_mode == "anthropic_messages"
|
||||
|
||||
def test_register_and_get(self):
|
||||
class DummyTransport(ProviderTransport):
|
||||
@property
|
||||
def api_mode(self):
|
||||
return "dummy_test"
|
||||
def convert_messages(self, messages, **kw):
|
||||
return messages
|
||||
def convert_tools(self, tools):
|
||||
return tools
|
||||
def build_kwargs(self, model, messages, tools=None, **params):
|
||||
return {}
|
||||
def normalize_response(self, response, **kw):
|
||||
return NormalizedResponse(content=None, tool_calls=None, finish_reason="stop")
|
||||
|
||||
register_transport("dummy_test", DummyTransport)
|
||||
t = get_transport("dummy_test")
|
||||
assert t.api_mode == "dummy_test"
|
||||
# Cleanup
|
||||
_REGISTRY.pop("dummy_test", None)
|
||||
|
||||
|
||||
# ── AnthropicTransport tests ────────────────────────────────────────────
|
||||
|
||||
class TestAnthropicTransport:
|
||||
|
||||
@pytest.fixture
|
||||
def transport(self):
|
||||
import agent.transports.anthropic # noqa: F401
|
||||
return get_transport("anthropic_messages")
|
||||
|
||||
def test_api_mode(self, transport):
|
||||
assert transport.api_mode == "anthropic_messages"
|
||||
|
||||
def test_convert_tools_simple(self, transport):
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"description": "A test",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
}
|
||||
}]
|
||||
result = transport.convert_tools(tools)
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "test_tool"
|
||||
assert "input_schema" in result[0]
|
||||
|
||||
def test_validate_response_none(self, transport):
|
||||
assert transport.validate_response(None) is False
|
||||
|
||||
def test_validate_response_empty_content(self, transport):
|
||||
r = SimpleNamespace(content=[])
|
||||
assert transport.validate_response(r) is False
|
||||
|
||||
def test_validate_response_valid(self, transport):
|
||||
r = SimpleNamespace(content=[SimpleNamespace(type="text", text="hello")])
|
||||
assert transport.validate_response(r) is True
|
||||
|
||||
def test_map_finish_reason(self, transport):
|
||||
assert transport.map_finish_reason("end_turn") == "stop"
|
||||
assert transport.map_finish_reason("tool_use") == "tool_calls"
|
||||
assert transport.map_finish_reason("max_tokens") == "length"
|
||||
assert transport.map_finish_reason("stop_sequence") == "stop"
|
||||
assert transport.map_finish_reason("refusal") == "content_filter"
|
||||
assert transport.map_finish_reason("model_context_window_exceeded") == "length"
|
||||
assert transport.map_finish_reason("unknown") == "stop"
|
||||
|
||||
def test_extract_cache_stats_none_usage(self, transport):
|
||||
r = SimpleNamespace(usage=None)
|
||||
assert transport.extract_cache_stats(r) is None
|
||||
|
||||
def test_extract_cache_stats_with_cache(self, transport):
|
||||
usage = SimpleNamespace(cache_read_input_tokens=100, cache_creation_input_tokens=50)
|
||||
r = SimpleNamespace(usage=usage)
|
||||
result = transport.extract_cache_stats(r)
|
||||
assert result == {"cached_tokens": 100, "creation_tokens": 50}
|
||||
|
||||
def test_extract_cache_stats_zero(self, transport):
|
||||
usage = SimpleNamespace(cache_read_input_tokens=0, cache_creation_input_tokens=0)
|
||||
r = SimpleNamespace(usage=usage)
|
||||
assert transport.extract_cache_stats(r) is None
|
||||
|
||||
def test_normalize_response_text(self, transport):
|
||||
"""Test normalization of a simple text response."""
|
||||
r = SimpleNamespace(
|
||||
content=[SimpleNamespace(type="text", text="Hello world")],
|
||||
stop_reason="end_turn",
|
||||
usage=SimpleNamespace(input_tokens=10, output_tokens=5),
|
||||
model="claude-sonnet-4-6",
|
||||
)
|
||||
nr = transport.normalize_response(r)
|
||||
assert isinstance(nr, NormalizedResponse)
|
||||
assert nr.content == "Hello world"
|
||||
assert nr.tool_calls is None or nr.tool_calls == []
|
||||
assert nr.finish_reason == "stop"
|
||||
|
||||
def test_normalize_response_tool_calls(self, transport):
|
||||
"""Test normalization of a tool-use response."""
|
||||
r = SimpleNamespace(
|
||||
content=[
|
||||
SimpleNamespace(
|
||||
type="tool_use",
|
||||
id="toolu_123",
|
||||
name="terminal",
|
||||
input={"command": "ls"},
|
||||
),
|
||||
],
|
||||
stop_reason="tool_use",
|
||||
usage=SimpleNamespace(input_tokens=10, output_tokens=20),
|
||||
model="claude-sonnet-4-6",
|
||||
)
|
||||
nr = transport.normalize_response(r)
|
||||
assert nr.finish_reason == "tool_calls"
|
||||
assert len(nr.tool_calls) == 1
|
||||
tc = nr.tool_calls[0]
|
||||
assert tc.name == "terminal"
|
||||
assert tc.id == "toolu_123"
|
||||
assert '"command"' in tc.arguments
|
||||
|
||||
def test_normalize_response_thinking(self, transport):
|
||||
"""Test normalization preserves thinking content."""
|
||||
r = SimpleNamespace(
|
||||
content=[
|
||||
SimpleNamespace(type="thinking", thinking="Let me think..."),
|
||||
SimpleNamespace(type="text", text="The answer is 42"),
|
||||
],
|
||||
stop_reason="end_turn",
|
||||
usage=SimpleNamespace(input_tokens=10, output_tokens=15),
|
||||
model="claude-sonnet-4-6",
|
||||
)
|
||||
nr = transport.normalize_response(r)
|
||||
assert nr.content == "The answer is 42"
|
||||
assert nr.reasoning == "Let me think..."
|
||||
|
||||
def test_build_kwargs_returns_dict(self, transport):
|
||||
"""Test build_kwargs produces a usable kwargs dict."""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
kw = transport.build_kwargs(
|
||||
model="claude-sonnet-4-6",
|
||||
messages=messages,
|
||||
max_tokens=1024,
|
||||
)
|
||||
assert isinstance(kw, dict)
|
||||
assert "model" in kw
|
||||
assert "max_tokens" in kw
|
||||
assert "messages" in kw
|
||||
|
||||
def test_convert_messages_extracts_system(self, transport):
|
||||
"""Test convert_messages separates system from messages."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hi"},
|
||||
]
|
||||
system, msgs = transport.convert_messages(messages)
|
||||
# System should be extracted
|
||||
assert system is not None
|
||||
# Messages should only have user
|
||||
assert len(msgs) >= 1
|
||||
151
tests/agent/transports/test_types.py
Normal file
151
tests/agent/transports/test_types.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
"""Tests for agent/transports/types.py — dataclass construction + helpers."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from agent.transports.types import (
|
||||
NormalizedResponse,
|
||||
ToolCall,
|
||||
Usage,
|
||||
build_tool_call,
|
||||
map_finish_reason,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ToolCall
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToolCall:
|
||||
def test_basic_construction(self):
|
||||
tc = ToolCall(id="call_abc", name="terminal", arguments='{"cmd": "ls"}')
|
||||
assert tc.id == "call_abc"
|
||||
assert tc.name == "terminal"
|
||||
assert tc.arguments == '{"cmd": "ls"}'
|
||||
assert tc.provider_data is None
|
||||
|
||||
def test_none_id(self):
|
||||
tc = ToolCall(id=None, name="read_file", arguments="{}")
|
||||
assert tc.id is None
|
||||
|
||||
def test_provider_data(self):
|
||||
tc = ToolCall(
|
||||
id="call_x",
|
||||
name="t",
|
||||
arguments="{}",
|
||||
provider_data={"call_id": "call_x", "response_item_id": "fc_x"},
|
||||
)
|
||||
assert tc.provider_data["call_id"] == "call_x"
|
||||
assert tc.provider_data["response_item_id"] == "fc_x"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Usage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUsage:
|
||||
def test_defaults(self):
|
||||
u = Usage()
|
||||
assert u.prompt_tokens == 0
|
||||
assert u.completion_tokens == 0
|
||||
assert u.total_tokens == 0
|
||||
assert u.cached_tokens == 0
|
||||
|
||||
def test_explicit(self):
|
||||
u = Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150, cached_tokens=80)
|
||||
assert u.total_tokens == 150
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# NormalizedResponse
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNormalizedResponse:
|
||||
def test_text_only(self):
|
||||
r = NormalizedResponse(content="hello", tool_calls=None, finish_reason="stop")
|
||||
assert r.content == "hello"
|
||||
assert r.tool_calls is None
|
||||
assert r.finish_reason == "stop"
|
||||
assert r.reasoning is None
|
||||
assert r.usage is None
|
||||
assert r.provider_data is None
|
||||
|
||||
def test_with_tool_calls(self):
|
||||
tcs = [ToolCall(id="call_1", name="terminal", arguments='{"cmd":"pwd"}')]
|
||||
r = NormalizedResponse(content=None, tool_calls=tcs, finish_reason="tool_calls")
|
||||
assert r.finish_reason == "tool_calls"
|
||||
assert len(r.tool_calls) == 1
|
||||
assert r.tool_calls[0].name == "terminal"
|
||||
|
||||
def test_with_reasoning(self):
|
||||
r = NormalizedResponse(
|
||||
content="answer",
|
||||
tool_calls=None,
|
||||
finish_reason="stop",
|
||||
reasoning="I thought about it",
|
||||
)
|
||||
assert r.reasoning == "I thought about it"
|
||||
|
||||
def test_with_provider_data(self):
|
||||
r = NormalizedResponse(
|
||||
content=None,
|
||||
tool_calls=None,
|
||||
finish_reason="stop",
|
||||
provider_data={"reasoning_details": [{"type": "thinking", "thinking": "hmm"}]},
|
||||
)
|
||||
assert r.provider_data["reasoning_details"][0]["type"] == "thinking"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_tool_call
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildToolCall:
|
||||
def test_dict_arguments_serialized(self):
|
||||
tc = build_tool_call(id="call_1", name="terminal", arguments={"cmd": "ls"})
|
||||
assert tc.arguments == json.dumps({"cmd": "ls"})
|
||||
assert tc.provider_data is None
|
||||
|
||||
def test_string_arguments_passthrough(self):
|
||||
tc = build_tool_call(id="call_2", name="read_file", arguments='{"path": "/tmp"}')
|
||||
assert tc.arguments == '{"path": "/tmp"}'
|
||||
|
||||
def test_provider_fields(self):
|
||||
tc = build_tool_call(
|
||||
id="call_3",
|
||||
name="terminal",
|
||||
arguments="{}",
|
||||
call_id="call_3",
|
||||
response_item_id="fc_3",
|
||||
)
|
||||
assert tc.provider_data == {"call_id": "call_3", "response_item_id": "fc_3"}
|
||||
|
||||
def test_none_id(self):
|
||||
tc = build_tool_call(id=None, name="t", arguments="{}")
|
||||
assert tc.id is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# map_finish_reason
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMapFinishReason:
|
||||
ANTHROPIC_MAP = {
|
||||
"end_turn": "stop",
|
||||
"tool_use": "tool_calls",
|
||||
"max_tokens": "length",
|
||||
"stop_sequence": "stop",
|
||||
"refusal": "content_filter",
|
||||
}
|
||||
|
||||
def test_known_reason(self):
|
||||
assert map_finish_reason("end_turn", self.ANTHROPIC_MAP) == "stop"
|
||||
assert map_finish_reason("tool_use", self.ANTHROPIC_MAP) == "tool_calls"
|
||||
assert map_finish_reason("max_tokens", self.ANTHROPIC_MAP) == "length"
|
||||
assert map_finish_reason("refusal", self.ANTHROPIC_MAP) == "content_filter"
|
||||
|
||||
def test_unknown_reason_defaults_to_stop(self):
|
||||
assert map_finish_reason("something_new", self.ANTHROPIC_MAP) == "stop"
|
||||
|
||||
def test_none_reason(self):
|
||||
assert map_finish_reason(None, self.ANTHROPIC_MAP) == "stop"
|
||||
|
|
@ -254,3 +254,88 @@ class TestCliApprovalUi:
|
|||
|
||||
# Command got truncated with a marker.
|
||||
assert "(command truncated" in rendered
|
||||
|
||||
|
||||
class TestApprovalCallbackThreadLocalWiring:
|
||||
"""Regression guard for the thread-local callback freeze (#13617 / #13618).
|
||||
|
||||
After 62348cff made _approval_callback / _sudo_password_callback thread-local
|
||||
(ACP GHSA-qg5c-hvr5-hjgr), the CLI agent thread could no longer see callbacks
|
||||
registered in the main thread — the dangerous-command prompt silently fell
|
||||
back to stdin input() and deadlocked against prompt_toolkit. The fix is to
|
||||
register the callbacks INSIDE the agent worker thread (matching the ACP
|
||||
pattern). These tests lock in that invariant.
|
||||
"""
|
||||
|
||||
def test_main_thread_registration_is_invisible_to_child_thread(self):
|
||||
"""Confirms the underlying threading.local semantics that drove the bug.
|
||||
|
||||
If this ever starts passing as "visible", the thread-local isolation
|
||||
is gone and the ACP race GHSA-qg5c-hvr5-hjgr may be back.
|
||||
"""
|
||||
from tools.terminal_tool import (
|
||||
set_approval_callback,
|
||||
_get_approval_callback,
|
||||
)
|
||||
|
||||
def main_cb(_cmd, _desc):
|
||||
return "once"
|
||||
|
||||
set_approval_callback(main_cb)
|
||||
try:
|
||||
seen = {}
|
||||
|
||||
def _child():
|
||||
seen["value"] = _get_approval_callback()
|
||||
|
||||
t = threading.Thread(target=_child, daemon=True)
|
||||
t.start()
|
||||
t.join(timeout=2)
|
||||
assert seen["value"] is None
|
||||
finally:
|
||||
set_approval_callback(None)
|
||||
|
||||
def test_child_thread_registration_is_visible_and_cleared_in_finally(self):
|
||||
"""The fix pattern: register INSIDE the worker thread, clear in finally.
|
||||
|
||||
This is exactly what cli.py's run_agent() closure does. If this test
|
||||
fails, the CLI approval prompt freeze (#13617) has regressed.
|
||||
"""
|
||||
from tools.terminal_tool import (
|
||||
set_approval_callback,
|
||||
set_sudo_password_callback,
|
||||
_get_approval_callback,
|
||||
_get_sudo_password_callback,
|
||||
)
|
||||
|
||||
def approval_cb(_cmd, _desc):
|
||||
return "once"
|
||||
|
||||
def sudo_cb():
|
||||
return "hunter2"
|
||||
|
||||
seen = {}
|
||||
|
||||
def _worker():
|
||||
# Mimic cli.py's run_agent() thread target.
|
||||
set_approval_callback(approval_cb)
|
||||
set_sudo_password_callback(sudo_cb)
|
||||
try:
|
||||
seen["approval"] = _get_approval_callback()
|
||||
seen["sudo"] = _get_sudo_password_callback()
|
||||
finally:
|
||||
set_approval_callback(None)
|
||||
set_sudo_password_callback(None)
|
||||
seen["approval_after"] = _get_approval_callback()
|
||||
seen["sudo_after"] = _get_sudo_password_callback()
|
||||
|
||||
t = threading.Thread(target=_worker, daemon=True)
|
||||
t.start()
|
||||
t.join(timeout=2)
|
||||
|
||||
assert seen["approval"] is approval_cb
|
||||
assert seen["sudo"] is sudo_cb
|
||||
# Finally block must clear both slots — otherwise a reused thread
|
||||
# would hold a stale reference to a disposed CLI instance.
|
||||
assert seen["approval_after"] is None
|
||||
assert seen["sudo_after"] is None
|
||||
|
|
|
|||
105
tests/cli/test_cli_external_editor.py
Normal file
105
tests/cli/test_cli_external_editor.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
"""Tests for CLI external-editor support."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
class _FakeBuffer:
|
||||
def __init__(self, text=""):
|
||||
self.calls = []
|
||||
self.text = text
|
||||
self.cursor_position = len(text)
|
||||
|
||||
def open_in_editor(self, validate_and_handle=False):
|
||||
self.calls.append(validate_and_handle)
|
||||
|
||||
|
||||
class _FakeApp:
|
||||
def __init__(self):
|
||||
self.current_buffer = _FakeBuffer()
|
||||
|
||||
|
||||
def _make_cli(with_app=True):
|
||||
cli_obj = HermesCLI.__new__(HermesCLI)
|
||||
cli_obj._app = _FakeApp() if with_app else None
|
||||
cli_obj._command_running = False
|
||||
cli_obj._command_status = ""
|
||||
cli_obj._command_display = ""
|
||||
cli_obj._sudo_state = None
|
||||
cli_obj._secret_state = None
|
||||
cli_obj._approval_state = None
|
||||
cli_obj._clarify_state = None
|
||||
cli_obj._skip_paste_collapse = False
|
||||
return cli_obj
|
||||
|
||||
def test_open_external_editor_uses_prompt_toolkit_buffer_editor():
|
||||
cli_obj = _make_cli()
|
||||
|
||||
assert cli_obj._open_external_editor() is True
|
||||
assert cli_obj._app.current_buffer.calls == [False]
|
||||
|
||||
|
||||
def test_open_external_editor_rejects_when_no_tui():
|
||||
cli_obj = _make_cli(with_app=False)
|
||||
|
||||
with patch("cli._cprint") as mock_cprint:
|
||||
assert cli_obj._open_external_editor() is False
|
||||
|
||||
assert mock_cprint.called
|
||||
assert "interactive cli" in str(mock_cprint.call_args).lower()
|
||||
|
||||
|
||||
def test_open_external_editor_rejects_modal_prompts():
|
||||
cli_obj = _make_cli()
|
||||
cli_obj._approval_state = {"selected": 0}
|
||||
|
||||
with patch("cli._cprint") as mock_cprint:
|
||||
assert cli_obj._open_external_editor() is False
|
||||
|
||||
assert mock_cprint.called
|
||||
assert "active prompt" in str(mock_cprint.call_args).lower()
|
||||
|
||||
def test_open_external_editor_uses_explicit_buffer_when_provided():
|
||||
cli_obj = _make_cli()
|
||||
external_buffer = _FakeBuffer()
|
||||
|
||||
assert cli_obj._open_external_editor(buffer=external_buffer) is True
|
||||
assert external_buffer.calls == [False]
|
||||
assert cli_obj._app.current_buffer.calls == []
|
||||
|
||||
|
||||
def test_expand_paste_references_replaces_placeholder_with_file_contents(tmp_path):
|
||||
cli_obj = _make_cli()
|
||||
paste_file = tmp_path / "paste.txt"
|
||||
paste_file.write_text("line one\nline two", encoding="utf-8")
|
||||
|
||||
text = f"before [Pasted text #1: 2 lines → {paste_file}] after"
|
||||
expanded = cli_obj._expand_paste_references(text)
|
||||
|
||||
assert expanded == "before line one\nline two after"
|
||||
|
||||
|
||||
def test_open_external_editor_expands_paste_placeholders_before_open(tmp_path):
|
||||
cli_obj = _make_cli()
|
||||
paste_file = tmp_path / "paste.txt"
|
||||
paste_file.write_text("alpha\nbeta", encoding="utf-8")
|
||||
buffer = _FakeBuffer(text=f"[Pasted text #1: 2 lines → {paste_file}]")
|
||||
|
||||
assert cli_obj._open_external_editor(buffer=buffer) is True
|
||||
assert buffer.text == "alpha\nbeta"
|
||||
assert buffer.cursor_position == len("alpha\nbeta")
|
||||
assert buffer.calls == [False]
|
||||
|
||||
|
||||
def test_open_external_editor_sets_skip_collapse_flag_during_expansion(tmp_path):
|
||||
cli_obj = _make_cli()
|
||||
paste_file = tmp_path / "paste.txt"
|
||||
paste_file.write_text("a\nb\nc\nd\ne\nf", encoding="utf-8")
|
||||
buffer = _FakeBuffer(text=f"[Pasted text #1: 6 lines \u2192 {paste_file}]")
|
||||
|
||||
# After expansion the flag should have been set (to prevent re-collapse)
|
||||
assert cli_obj._open_external_editor(buffer=buffer) is True
|
||||
# Flag is consumed by _on_text_changed, but since no handler is attached
|
||||
# in tests it stays True until the handler resets it.
|
||||
assert cli_obj._skip_paste_collapse is True
|
||||
|
|
@ -147,6 +147,37 @@ class TestEscapedSpaces:
|
|||
assert result["path"] == tmp_image_with_spaces
|
||||
assert result["remainder"] == "what is this?"
|
||||
|
||||
def test_unquoted_spaces_in_path(self, tmp_image_with_spaces):
|
||||
result = _detect_file_drop(str(tmp_image_with_spaces))
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_image_with_spaces
|
||||
assert result["is_image"] is True
|
||||
assert result["remainder"] == ""
|
||||
|
||||
def test_unquoted_spaces_with_trailing_text(self, tmp_image_with_spaces):
|
||||
user_input = f"{tmp_image_with_spaces} what is this?"
|
||||
result = _detect_file_drop(user_input)
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_image_with_spaces
|
||||
assert result["remainder"] == "what is this?"
|
||||
|
||||
def test_mixed_escaped_and_literal_spaces_in_path(self, tmp_path):
|
||||
img = tmp_path / "Screenshot 2026-04-21 at 1.04.43 PM.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n")
|
||||
mixed = str(img).replace("Screenshot ", "Screenshot\\ ").replace("2026-04-21 ", "2026-04-21\\ ").replace("at ", "at\\ ")
|
||||
result = _detect_file_drop(mixed)
|
||||
assert result is not None
|
||||
assert result["path"] == img
|
||||
assert result["is_image"] is True
|
||||
assert result["remainder"] == ""
|
||||
|
||||
def test_file_uri_image_path(self, tmp_image_with_spaces):
|
||||
uri = tmp_image_with_spaces.as_uri()
|
||||
result = _detect_file_drop(uri)
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_image_with_spaces
|
||||
assert result["is_image"] is True
|
||||
|
||||
def test_tilde_prefixed_path(self, tmp_path, monkeypatch):
|
||||
home = tmp_path / "home"
|
||||
img = home / "storage" / "shared" / "Pictures" / "cat.png"
|
||||
|
|
|
|||
141
tests/cli/test_cli_markdown_rendering.py
Normal file
141
tests/cli/test_cli_markdown_rendering.py
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
from io import StringIO
|
||||
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
from cli import _render_final_assistant_content
|
||||
|
||||
|
||||
def _render_to_text(renderable) -> str:
|
||||
buf = StringIO()
|
||||
Console(file=buf, width=80, force_terminal=False, color_system=None).print(renderable)
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def test_final_assistant_content_uses_markdown_renderable():
|
||||
renderable = _render_final_assistant_content("# Title\n\n- one\n- two")
|
||||
|
||||
assert isinstance(renderable, Markdown)
|
||||
output = _render_to_text(renderable)
|
||||
assert "Title" in output
|
||||
assert "one" in output
|
||||
assert "two" in output
|
||||
|
||||
|
||||
def test_final_assistant_content_strips_ansi_before_markdown_rendering():
|
||||
renderable = _render_final_assistant_content("\x1b[31m# Title\x1b[0m")
|
||||
|
||||
output = _render_to_text(renderable)
|
||||
assert "Title" in output
|
||||
assert "\x1b" not in output
|
||||
|
||||
|
||||
def test_final_assistant_content_can_strip_markdown_syntax():
|
||||
renderable = _render_final_assistant_content(
|
||||
"***Bold italic***\n~~Strike~~\n- item\n# Title\n`code`",
|
||||
mode="strip",
|
||||
)
|
||||
|
||||
output = _render_to_text(renderable)
|
||||
assert "Bold italic" in output
|
||||
assert "Strike" in output
|
||||
assert "item" in output
|
||||
assert "Title" in output
|
||||
assert "code" in output
|
||||
assert "***" not in output
|
||||
assert "~~" not in output
|
||||
assert "`" not in output
|
||||
|
||||
|
||||
def test_strip_mode_preserves_lists():
|
||||
renderable = _render_final_assistant_content(
|
||||
"**Formatting**\n- Ran prettier\n- Files changed\n- Verified clean",
|
||||
mode="strip",
|
||||
)
|
||||
|
||||
output = _render_to_text(renderable)
|
||||
assert "- Ran prettier" in output
|
||||
assert "- Files changed" in output
|
||||
assert "- Verified clean" in output
|
||||
assert "**" not in output
|
||||
|
||||
|
||||
def test_strip_mode_preserves_ordered_lists():
|
||||
renderable = _render_final_assistant_content(
|
||||
"1. First item\n2. Second item\n3. Third item",
|
||||
mode="strip",
|
||||
)
|
||||
|
||||
output = _render_to_text(renderable)
|
||||
assert "1. First" in output
|
||||
assert "2. Second" in output
|
||||
assert "3. Third" in output
|
||||
|
||||
|
||||
def test_strip_mode_preserves_blockquotes():
|
||||
renderable = _render_final_assistant_content(
|
||||
"> This is quoted text\n> Another quoted line",
|
||||
mode="strip",
|
||||
)
|
||||
|
||||
output = _render_to_text(renderable)
|
||||
assert "> This is quoted" in output
|
||||
assert "> Another quoted" in output
|
||||
|
||||
|
||||
def test_strip_mode_preserves_checkboxes():
|
||||
renderable = _render_final_assistant_content(
|
||||
"- [ ] Todo item\n- [x] Done item",
|
||||
mode="strip",
|
||||
)
|
||||
|
||||
output = _render_to_text(renderable)
|
||||
assert "- [ ] Todo" in output
|
||||
assert "- [x] Done" in output
|
||||
|
||||
|
||||
def test_strip_mode_preserves_table_structure_while_cleaning_cell_markdown():
|
||||
renderable = _render_final_assistant_content(
|
||||
"| Syntax | Example |\n|---|---|\n| Bold | `**bold**` |\n| Strike | `~~strike~~` |",
|
||||
mode="strip",
|
||||
)
|
||||
|
||||
output = _render_to_text(renderable)
|
||||
assert "| Syntax | Example |" in output
|
||||
assert "|---|---|" in output
|
||||
assert "| Bold | bold |" in output
|
||||
assert "| Strike | strike |" in output
|
||||
assert "**" not in output
|
||||
assert "~~" not in output
|
||||
assert "`" not in output
|
||||
|
||||
|
||||
def test_final_assistant_content_can_leave_markdown_raw():
|
||||
renderable = _render_final_assistant_content("***Bold italic***", mode="raw")
|
||||
|
||||
output = _render_to_text(renderable)
|
||||
assert "***Bold italic***" in output
|
||||
|
||||
|
||||
def test_strip_mode_preserves_intraword_underscores_in_snake_case_identifiers():
|
||||
renderable = _render_final_assistant_content(
|
||||
"Let me look at test_case_with_underscores and SOME_CONST "
|
||||
"then /tmp/snake_case_dir/file_with_name.py",
|
||||
mode="strip",
|
||||
)
|
||||
|
||||
output = _render_to_text(renderable)
|
||||
assert "test_case_with_underscores" in output
|
||||
assert "SOME_CONST" in output
|
||||
assert "snake_case_dir" in output
|
||||
assert "file_with_name" in output
|
||||
|
||||
|
||||
def test_strip_mode_still_strips_boundary_underscore_emphasis():
|
||||
renderable = _render_final_assistant_content(
|
||||
"say _hi_ and __bold__ now",
|
||||
mode="strip",
|
||||
)
|
||||
|
||||
output = _render_to_text(renderable)
|
||||
assert "say hi and bold now" in output
|
||||
|
|
@ -207,48 +207,11 @@ def test_cli_turn_routing_uses_primary_when_disabled(monkeypatch):
|
|||
shell.api_mode = "chat_completions"
|
||||
shell.base_url = "https://openrouter.ai/api/v1"
|
||||
shell.api_key = "sk-primary"
|
||||
shell._smart_model_routing = {"enabled": False}
|
||||
|
||||
result = shell._resolve_turn_agent_config("what time is it in tokyo?")
|
||||
|
||||
assert result["model"] == "gpt-5"
|
||||
assert result["runtime"]["provider"] == "openrouter"
|
||||
assert result["label"] is None
|
||||
|
||||
|
||||
def test_cli_turn_routing_uses_cheap_model_when_simple(monkeypatch):
|
||||
cli = _import_cli()
|
||||
|
||||
def _runtime_resolve(**kwargs):
|
||||
assert kwargs["requested"] == "zai"
|
||||
return {
|
||||
"provider": "zai",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": "https://open.z.ai/api/v1",
|
||||
"api_key": "cheap-key",
|
||||
"source": "env/config",
|
||||
}
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
|
||||
|
||||
shell = cli.HermesCLI(model="anthropic/claude-sonnet-4", compact=True, max_turns=1)
|
||||
shell.provider = "openrouter"
|
||||
shell.api_mode = "chat_completions"
|
||||
shell.base_url = "https://openrouter.ai/api/v1"
|
||||
shell.api_key = "primary-key"
|
||||
shell._smart_model_routing = {
|
||||
"enabled": True,
|
||||
"cheap_model": {"provider": "zai", "model": "glm-5-air"},
|
||||
"max_simple_chars": 160,
|
||||
"max_simple_words": 28,
|
||||
}
|
||||
|
||||
result = shell._resolve_turn_agent_config("what time is it in tokyo?")
|
||||
|
||||
assert result["model"] == "glm-5-air"
|
||||
assert result["runtime"]["provider"] == "zai"
|
||||
assert result["runtime"]["api_key"] == "cheap-key"
|
||||
assert result["label"] is not None
|
||||
|
||||
|
||||
def test_cli_prefers_config_provider_over_stale_env_override(monkeypatch):
|
||||
|
|
|
|||
146
tests/cli/test_cli_steer_busy_path.py
Normal file
146
tests/cli/test_cli_steer_busy_path.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
"""Regression tests for classic-CLI mid-run /steer dispatch.
|
||||
|
||||
Background
|
||||
----------
|
||||
/steer sent while the agent is running used to be queued through
|
||||
``self._pending_input`` alongside ordinary user input. ``process_loop``
|
||||
pulls from that queue and calls ``process_command()`` — but while the
|
||||
agent is running, ``process_loop`` is blocked inside ``self.chat()``.
|
||||
By the time the queued /steer was pulled, ``_agent_running`` had
|
||||
already flipped back to False, so ``process_command()`` took the idle
|
||||
fallback (``"No agent running; queued as next turn"``) and delivered
|
||||
the steer as an ordinary next-turn message.
|
||||
|
||||
The fix dispatches /steer inline on the UI thread when the agent is
|
||||
running — matching the existing pattern for /model — so the steer
|
||||
reaches ``agent.steer()`` (thread-safe) without touching the queue.
|
||||
|
||||
These tests exercise the detector + inline dispatch without starting a
|
||||
prompt_toolkit app.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def _make_cli():
    """Construct a HermesCLI with prompt_toolkit fully stubbed out.

    Importing the real prompt_toolkit would start terminal machinery, so
    every submodule that ``cli`` touches is replaced with a MagicMock
    before ``cli`` is (re)imported. Env vars that could override the test
    config are blanked for the duration of construction.
    """
    config = {
        "model": {
            "default": "anthropic/claude-opus-4.6",
            "base_url": "https://openrouter.ai/api/v1",
            "provider": "auto",
        },
        "display": {"compact": False, "tool_progress": "all"},
        "agent": {},
        "terminal": {"env_type": "local"},
    }
    # Neutralize environment overrides so the config above wins.
    env_overrides = {"LLM_MODEL": "", "HERMES_MAX_ITERATIONS": ""}
    stub_names = (
        "prompt_toolkit",
        "prompt_toolkit.history",
        "prompt_toolkit.styles",
        "prompt_toolkit.patch_stdout",
        "prompt_toolkit.application",
        "prompt_toolkit.layout",
        "prompt_toolkit.layout.processors",
        "prompt_toolkit.filters",
        "prompt_toolkit.layout.dimension",
        "prompt_toolkit.layout.menus",
        "prompt_toolkit.widgets",
        "prompt_toolkit.key_binding",
        "prompt_toolkit.completion",
        "prompt_toolkit.formatted_text",
        "prompt_toolkit.auto_suggest",
    )
    stubs = {name: MagicMock() for name in stub_names}
    with patch.dict(sys.modules, stubs), patch.dict(
        "os.environ", env_overrides, clear=False
    ):
        import cli as _cli_mod

        _cli_mod = importlib.reload(_cli_mod)
        with patch.object(_cli_mod, "get_tool_definitions", return_value=[]), patch.dict(
            _cli_mod.__dict__, {"CLI_CONFIG": config}
        ):
            return _cli_mod.HermesCLI()
|
||||
|
||||
|
||||
class TestSteerInlineDetector:
    """_should_handle_steer_command_inline gates the busy-path fast dispatch."""

    def test_detects_steer_when_agent_running(self):
        cli = _make_cli()
        cli._agent_running = True

        assert cli._should_handle_steer_command_inline("/steer focus on error handling") is True

    def test_ignores_steer_when_agent_idle(self):
        """Idle-path /steer must fall through to the normal process_loop
        dispatch so the queue-style fallback message is emitted."""
        cli = _make_cli()
        cli._agent_running = False

        assert cli._should_handle_steer_command_inline("/steer do something") is False

    def test_ignores_non_slash_input(self):
        cli = _make_cli()
        cli._agent_running = True

        assert cli._should_handle_steer_command_inline("steer without slash") is False
        assert cli._should_handle_steer_command_inline("") is False

    def test_ignores_other_slash_commands(self):
        cli = _make_cli()
        cli._agent_running = True

        # Only the /steer verb qualifies for inline dispatch.
        for other in ("/queue hello", "/stop", "/help"):
            assert cli._should_handle_steer_command_inline(other) is False

    def test_ignores_steer_with_attached_images(self):
        """Image payloads take the normal path; steer doesn't accept images."""
        cli = _make_cli()
        cli._agent_running = True

        assert cli._should_handle_steer_command_inline("/steer text", has_images=True) is False
|
||||
|
||||
|
||||
class TestSteerBusyPathDispatch:
    """When the detector fires, process_command('/steer ...') must call
    agent.steer() directly rather than the idle-path fallback."""

    @staticmethod
    def _cli_with_mock_agent(running):
        """Build a CLI whose agent and input queue are both observable mocks."""
        cli = _make_cli()
        cli._agent_running = running
        cli.agent = MagicMock()
        cli.agent.steer = MagicMock(return_value=True)
        # If the idle-path fallback were taken, this mock would record it.
        cli._pending_input = MagicMock()
        return cli

    def test_process_command_routes_to_agent_steer(self):
        """With _agent_running=True and agent.steer present, /steer reaches
        agent.steer(payload), NOT _pending_input."""
        cli = self._cli_with_mock_agent(running=True)

        cli.process_command("/steer focus on errors")

        cli.agent.steer.assert_called_once_with("focus on errors")
        cli._pending_input.put.assert_not_called()

    def test_idle_path_queues_as_next_turn(self):
        """Control — when the agent is NOT running, /steer correctly falls
        back to next-turn queue semantics. Demonstrates why the fix was
        needed: the queue path only works when you can actually drain it."""
        cli = self._cli_with_mock_agent(running=False)

        cli.process_command("/steer would-be-next-turn")

        # Idle path does NOT call agent.steer; the payload is queued as a
        # normal next-turn message instead.
        cli.agent.steer.assert_not_called()
        cli._pending_input.put.assert_called_once_with("would-be-next-turn")
|
||||
|
||||
|
||||
if __name__ == "__main__":  # pragma: no cover
    # Allow running this test module directly: delegate to pytest with
    # verbose output instead of requiring a pytest invocation.
    import pytest

    pytest.main([__file__, "-v"])
|
||||
92
tests/cli/test_cli_user_message_preview.py
Normal file
92
tests/cli/test_cli_user_message_preview.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
|
||||
_cli_mod = None
|
||||
|
||||
|
||||
def _make_cli(user_message_preview=None):
    """Construct a HermesCLI with prompt_toolkit stubbed out.

    Accepts an optional ``display.user_message_preview`` override; defaults
    to two leading and two trailing lines. The reloaded ``cli`` module is
    stashed in the module-global ``_cli_mod`` for tests that need it.
    """
    global _cli_mod
    config = {
        "model": {
            "default": "anthropic/claude-opus-4.6",
            "base_url": "https://openrouter.ai/api/v1",
            "provider": "auto",
        },
        "display": {
            "compact": False,
            "tool_progress": "all",
            "user_message_preview": user_message_preview or {"first_lines": 2, "last_lines": 2},
        },
        "agent": {},
        "terminal": {"env_type": "local"},
    }
    # Blank out env vars that would override the config above.
    env_overrides = {"LLM_MODEL": "", "HERMES_MAX_ITERATIONS": ""}
    stub_names = (
        "prompt_toolkit",
        "prompt_toolkit.history",
        "prompt_toolkit.styles",
        "prompt_toolkit.patch_stdout",
        "prompt_toolkit.application",
        "prompt_toolkit.layout",
        "prompt_toolkit.layout.processors",
        "prompt_toolkit.filters",
        "prompt_toolkit.layout.dimension",
        "prompt_toolkit.layout.menus",
        "prompt_toolkit.widgets",
        "prompt_toolkit.key_binding",
        "prompt_toolkit.completion",
        "prompt_toolkit.formatted_text",
        "prompt_toolkit.auto_suggest",
    )
    stubs = {name: MagicMock() for name in stub_names}
    with patch.dict(sys.modules, stubs), patch.dict("os.environ", env_overrides, clear=False):
        import cli as mod

        mod = importlib.reload(mod)
        _cli_mod = mod
        with patch.object(mod, "get_tool_definitions", return_value=[]), patch.dict(mod.__dict__, {"CLI_CONFIG": config}):
            return mod.HermesCLI()
|
||||
|
||||
|
||||
class TestSubmittedUserMessagePreview:
    """Preview rendering of submitted multi-line user messages."""

    # Shared six-line fixture used by the first/last-line tests.
    _SIX_LINES = "line1\nline2\nline3\nline4\nline5\nline6"

    def test_default_preview_shows_first_two_lines_and_last_two_lines(self):
        cli = _make_cli()

        rendered = cli._format_submitted_user_message_preview(self._SIX_LINES)

        # Head and tail survive; the middle is elided with a count.
        assert "line1" in rendered
        assert "line2" in rendered
        assert "line5" in rendered
        assert "line6" in rendered
        assert "line3" not in rendered
        assert "line4" not in rendered
        assert "(+2 more lines)" in rendered

    def test_preview_can_hide_last_lines(self):
        cli = _make_cli({"first_lines": 2, "last_lines": 0})

        rendered = cli._format_submitted_user_message_preview(self._SIX_LINES)

        # With last_lines=0 only the head remains; all four hidden lines
        # are rolled into the elision count.
        assert "line1" in rendered
        assert "line2" in rendered
        assert "line5" not in rendered
        assert "line6" not in rendered
        assert "(+4 more lines)" in rendered

    def test_invalid_first_lines_value_falls_back_to_one(self):
        cli = _make_cli({"first_lines": 0, "last_lines": 2})

        rendered = cli._format_submitted_user_message_preview("line1\nline2\nline3\nline4")

        # first_lines=0 is invalid; the renderer clamps it to one line
        # and the singular "(+1 more line)" form is used.
        assert "line1" in rendered
        assert "line3" in rendered
        assert "line4" in rendered
        assert "(+1 more line)" in rendered
|
||||
|
|
@ -183,27 +183,10 @@ class TestFastModeRouting(unittest.TestCase):
|
|||
acp_command=None,
|
||||
acp_args=[],
|
||||
_credential_pool=None,
|
||||
_smart_model_routing={},
|
||||
service_tier="priority",
|
||||
)
|
||||
|
||||
original_runtime = {
|
||||
"api_key": "***",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
"command": None,
|
||||
"args": [],
|
||||
"credential_pool": None,
|
||||
}
|
||||
|
||||
with patch("agent.smart_model_routing.resolve_turn_route", return_value={
|
||||
"model": "gpt-5.4",
|
||||
"runtime": dict(original_runtime),
|
||||
"label": None,
|
||||
"signature": ("gpt-5.4", "openrouter", "https://openrouter.ai/api/v1", "chat_completions", None, ()),
|
||||
}):
|
||||
route = cli_mod.HermesCLI._resolve_turn_agent_config(stub, "hi")
|
||||
route = cli_mod.HermesCLI._resolve_turn_agent_config(stub, "hi")
|
||||
|
||||
# Provider should NOT have changed
|
||||
assert route["runtime"]["provider"] == "openrouter"
|
||||
|
|
@ -222,26 +205,10 @@ class TestFastModeRouting(unittest.TestCase):
|
|||
acp_command=None,
|
||||
acp_args=[],
|
||||
_credential_pool=None,
|
||||
_smart_model_routing={},
|
||||
service_tier="priority",
|
||||
)
|
||||
|
||||
primary_route = {
|
||||
"model": "gpt-5.3-codex",
|
||||
"runtime": {
|
||||
"api_key": "***",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
"command": None,
|
||||
"args": [],
|
||||
"credential_pool": None,
|
||||
},
|
||||
"label": None,
|
||||
"signature": ("gpt-5.3-codex", "openrouter", "https://openrouter.ai/api/v1", "chat_completions", None, ()),
|
||||
}
|
||||
with patch("agent.smart_model_routing.resolve_turn_route", return_value=primary_route):
|
||||
route = cli_mod.HermesCLI._resolve_turn_agent_config(stub, "hi")
|
||||
route = cli_mod.HermesCLI._resolve_turn_agent_config(stub, "hi")
|
||||
|
||||
assert route["runtime"]["provider"] == "openrouter"
|
||||
assert route.get("request_overrides") is None
|
||||
|
|
@ -329,27 +296,10 @@ class TestAnthropicFastMode(unittest.TestCase):
|
|||
acp_command=None,
|
||||
acp_args=[],
|
||||
_credential_pool=None,
|
||||
_smart_model_routing={},
|
||||
service_tier="priority",
|
||||
)
|
||||
|
||||
original_runtime = {
|
||||
"api_key": "***",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"provider": "anthropic",
|
||||
"api_mode": "anthropic_messages",
|
||||
"command": None,
|
||||
"args": [],
|
||||
"credential_pool": None,
|
||||
}
|
||||
|
||||
with patch("agent.smart_model_routing.resolve_turn_route", return_value={
|
||||
"model": "claude-opus-4-6",
|
||||
"runtime": dict(original_runtime),
|
||||
"label": None,
|
||||
"signature": ("claude-opus-4-6", "anthropic", "https://api.anthropic.com", "anthropic_messages", None, ()),
|
||||
}):
|
||||
route = cli_mod.HermesCLI._resolve_turn_agent_config(stub, "hi")
|
||||
route = cli_mod.HermesCLI._resolve_turn_agent_config(stub, "hi")
|
||||
|
||||
assert route["runtime"]["provider"] == "anthropic"
|
||||
assert route["request_overrides"] == {"speed": "fast"}
|
||||
|
|
|
|||
21
tests/cli/test_gquota_command.py
Normal file
21
tests/cli/test_gquota_command.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def test_gquota_uses_chat_console_when_tui_is_live():
    """/gquota output must route through the live ChatConsole — not the
    plain cli.console — whenever the TUI app object is present."""
    from agent.google_oauth import GoogleOAuthError
    from cli import HermesCLI

    cli = HermesCLI.__new__(HermesCLI)
    cli.console = MagicMock()
    cli._app = object()  # a truthy _app marks the TUI as live

    live_console = MagicMock()
    oauth_error = GoogleOAuthError("No Google OAuth credentials found")

    with patch("cli.ChatConsole", return_value=live_console), \
         patch("agent.google_oauth.get_valid_access_token", side_effect=oauth_error), \
         patch("agent.google_oauth.load_credentials", return_value=None), \
         patch("agent.google_code_assist.retrieve_user_quota"):
        cli._handle_gquota_command("/gquota")

    # Both output lines went to the live console; the plain console was bypassed.
    assert live_console.print.call_count == 2
    cli.console.print.assert_not_called()
|
||||
|
|
@ -21,6 +21,7 @@ def test_manual_compress_reports_noop_without_success_banner(capsys):
|
|||
shell.agent = MagicMock()
|
||||
shell.agent.compression_enabled = True
|
||||
shell.agent._cached_system_prompt = ""
|
||||
shell.agent.session_id = shell.session_id # no-op compression: no split
|
||||
shell.agent._compress_context.return_value = (list(history), "")
|
||||
|
||||
def _estimate(messages):
|
||||
|
|
@ -48,6 +49,7 @@ def test_manual_compress_explains_when_token_estimate_rises(capsys):
|
|||
shell.agent = MagicMock()
|
||||
shell.agent.compression_enabled = True
|
||||
shell.agent._cached_system_prompt = ""
|
||||
shell.agent.session_id = shell.session_id # no-op: no split
|
||||
shell.agent._compress_context.return_value = (compressed, "")
|
||||
|
||||
def _estimate(messages):
|
||||
|
|
@ -64,3 +66,64 @@ def test_manual_compress_explains_when_token_estimate_rises(capsys):
|
|||
assert "✅ Compressed: 4 → 3 messages" in output
|
||||
assert "Rough transcript estimate: ~100 → ~120 tokens" in output
|
||||
assert "denser summaries" in output
|
||||
|
||||
|
||||
def test_manual_compress_syncs_session_id_after_split():
    """Regression for cli.session_id desync after /compress.

    _compress_context ends the parent session and creates a new child session,
    mutating agent.session_id. Without syncing, cli.session_id still points
    at the ended parent — causing /status, /resume, exit summary, and the
    next end_session() call (e.g. from /resume <id>) to target the wrong row.
    """
    shell = _make_cli()
    history = _make_history()
    old_id = shell.session_id
    new_child_id = "20260101_000000_child1"

    # Compressed transcript: a summary message plus the final original turn.
    compressed = [
        {"role": "user", "content": "[summary]"},
        history[-1],
    ]
    shell.conversation_history = history
    shell.agent = MagicMock()
    shell.agent.compression_enabled = True
    shell.agent._cached_system_prompt = ""
    # Simulate _compress_context mutating agent.session_id as a side effect
    # (the real implementation splits off a child session during compression).
    def _fake_compress(*args, **kwargs):
        shell.agent.session_id = new_child_id
        return (compressed, "")
    shell.agent._compress_context.side_effect = _fake_compress
    shell.agent.session_id = old_id  # starts in sync with the CLI
    shell._pending_title = "stale title"

    with patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=100):
        shell._manual_compress()

    # CLI session_id must now point at the continuation child, not the parent.
    assert shell.session_id == new_child_id
    assert shell.session_id != old_id
    # Pending title must be cleared — titles belong to the parent lineage and
    # get regenerated for the continuation.
    assert shell._pending_title is None
|
||||
|
||||
|
||||
def test_manual_compress_no_sync_when_session_id_unchanged():
    """A no-op compression (agent.session_id unchanged) must not clear
    _pending_title or otherwise disturb the CLI's session state."""
    shell = _make_cli()
    history = _make_history()
    shell.conversation_history = history

    shell.agent = MagicMock()
    shell.agent.compression_enabled = True
    shell.agent._cached_system_prompt = ""
    shell.agent.session_id = shell.session_id  # already in sync → no split
    shell.agent._compress_context.return_value = (list(history), "")
    shell._pending_title = "keep me"

    with patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=100):
        shell._manual_compress()

    # No split happened, so the pending title survives untouched.
    assert shell._pending_title == "keep me"
|
||||
|
|
|
|||
|
|
@ -33,6 +33,20 @@ class TestCLIQuickCommands:
|
|||
printed = self._printed_plain(cli.console.print.call_args[0][0])
|
||||
assert printed == "daily-note"
|
||||
|
||||
def test_exec_command_uses_chat_console_when_tui_is_live(self):
|
||||
cli = self._make_cli({"dn": {"type": "exec", "command": "echo daily-note"}})
|
||||
cli._app = object()
|
||||
live_console = MagicMock()
|
||||
|
||||
with patch("cli.ChatConsole", return_value=live_console):
|
||||
result = cli.process_command("/dn")
|
||||
|
||||
assert result is True
|
||||
live_console.print.assert_called_once()
|
||||
printed = self._printed_plain(live_console.print.call_args[0][0])
|
||||
assert printed == "daily-note"
|
||||
cli.console.print.assert_not_called()
|
||||
|
||||
def test_exec_command_stderr_shown_on_no_stdout(self):
|
||||
cli = self._make_cli({"err": {"type": "exec", "command": "echo error >&2"}})
|
||||
result = cli.process_command("/err")
|
||||
|
|
|
|||
|
|
@ -186,6 +186,31 @@ _HERMES_BEHAVIORAL_VARS = frozenset({
|
|||
"HERMES_HOME_MODE",
|
||||
"BROWSER_CDP_URL",
|
||||
"CAMOFOX_URL",
|
||||
# Platform allowlists — not credentials, but if set from any source
|
||||
# (user shell, earlier leaky test, CI env), they change gateway auth
|
||||
# behavior and flake button-authorization tests.
|
||||
"TELEGRAM_ALLOWED_USERS",
|
||||
"DISCORD_ALLOWED_USERS",
|
||||
"WHATSAPP_ALLOWED_USERS",
|
||||
"SLACK_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS",
|
||||
"SIGNAL_GROUP_ALLOWED_USERS",
|
||||
"EMAIL_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS",
|
||||
"MATTERMOST_ALLOWED_USERS",
|
||||
"MATRIX_ALLOWED_USERS",
|
||||
"DINGTALK_ALLOWED_USERS",
|
||||
"FEISHU_ALLOWED_USERS",
|
||||
"WECOM_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOW_ALL_USERS",
|
||||
"TELEGRAM_ALLOW_ALL_USERS",
|
||||
"DISCORD_ALLOW_ALL_USERS",
|
||||
"WHATSAPP_ALLOW_ALL_USERS",
|
||||
"SLACK_ALLOW_ALL_USERS",
|
||||
"SIGNAL_ALLOW_ALL_USERS",
|
||||
"EMAIL_ALLOW_ALL_USERS",
|
||||
"SMS_ALLOW_ALL_USERS",
|
||||
})
|
||||
|
||||
|
||||
|
|
@ -258,6 +283,107 @@ def _isolate_hermes_home(_hermetic_environment):
|
|||
return None
|
||||
|
||||
|
||||
# ── Module-level state reset ───────────────────────────────────────────────
|
||||
#
|
||||
# Python modules are singletons per process, and pytest-xdist workers are
|
||||
# long-lived. Module-level dicts/sets (tool registries, approval state,
|
||||
# interrupt flags) and ContextVars persist across tests in the same worker,
|
||||
# causing tests that pass alone to fail when run with siblings.
|
||||
#
|
||||
# Each entry in this fixture clears state that belongs to a specific module.
|
||||
# New state buckets go here too — this is the single gate that prevents
|
||||
# "works alone, flakes in CI" bugs from state leakage.
|
||||
#
|
||||
# The skill `test-suite-cascade-diagnosis` documents the concrete patterns
|
||||
# this closes; the running example was `test_command_guards` failing 12/15
|
||||
# CI runs because ``tools.approval._session_approved`` carried approvals
|
||||
# from one test's session into another's.
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_module_state():
    """Clear module-level mutable state and ContextVars between tests.

    Keeps state from leaking across tests on the same xdist worker. Modules
    that don't exist yet (test collection before production import) are
    skipped silently — production import later creates fresh empty state.
    Each ``try`` block is independent so one missing module never prevents
    the others from being reset.
    """
    # --- tools.approval — the single biggest source of cross-test pollution ---
    try:
        from tools import approval as _approval_mod
        _approval_mod._session_approved.clear()
        _approval_mod._session_yolo.clear()
        _approval_mod._permanent_approved.clear()
        _approval_mod._pending.clear()
        _approval_mod._gateway_queues.clear()
        _approval_mod._gateway_notify_cbs.clear()
        # ContextVar: reset to empty string so get_current_session_key()
        # falls through to the env var / default path, matching a fresh
        # process.
        _approval_mod._approval_session_key.set("")
    except Exception:
        pass

    # --- tools.interrupt — per-thread interrupt flag set ---
    try:
        from tools import interrupt as _interrupt_mod
        with _interrupt_mod._lock:
            _interrupt_mod._interrupted_threads.clear()
    except Exception:
        pass

    # --- gateway.session_context — ContextVars that represent the active
    # gateway session. If set in one test and not reset, the next test's
    # get_session_env() reads stale values.
    try:
        from gateway import session_context as _sc_mod
        for _cv in (
            _sc_mod._SESSION_PLATFORM,
            _sc_mod._SESSION_CHAT_ID,
            _sc_mod._SESSION_CHAT_NAME,
            _sc_mod._SESSION_THREAD_ID,
            _sc_mod._SESSION_USER_ID,
            _sc_mod._SESSION_USER_NAME,
            _sc_mod._SESSION_KEY,
            _sc_mod._CRON_AUTO_DELIVER_PLATFORM,
            _sc_mod._CRON_AUTO_DELIVER_CHAT_ID,
            _sc_mod._CRON_AUTO_DELIVER_THREAD_ID,
        ):
            _cv.set(_sc_mod._UNSET)
    except Exception:
        pass

    # --- tools.env_passthrough — ContextVar<set[str]> with no default ---
    # LookupError is normal if the test never set it. Setting it to an
    # empty set unconditionally normalizes the starting state.
    try:
        from tools import env_passthrough as _envp_mod
        _envp_mod._allowed_env_vars_var.set(set())
    except Exception:
        pass

    # --- tools.credential_files — ContextVar<dict> ---
    try:
        from tools import credential_files as _credf_mod
        _credf_mod._registered_files_var.set({})
    except Exception:
        pass

    # --- tools.file_tools — per-task read history + file-ops cache ---
    # _read_tracker accumulates per-task_id read history for loop detection,
    # capped by _READ_HISTORY_CAP. If entries from a prior test persist, the
    # cap is hit faster than expected and capacity-related tests flake.
    try:
        from tools import file_tools as _ft_mod
        with _ft_mod._read_tracker_lock:
            _ft_mod._read_tracker.clear()
        with _ft_mod._file_ops_lock:
            _ft_mod._file_ops_cache.clear()
    except Exception:
        pass

    yield
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_dir(tmp_path):
|
||||
"""Provide a temporary directory that is cleaned up automatically."""
|
||||
|
|
|
|||
|
|
@ -152,7 +152,6 @@ def test_gateway_run_agent_codex_path_handles_internal_401_refresh(monkeypatch):
|
|||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._running_agents = {}
|
||||
runner._smart_model_routing = {}
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
|
|
|
|||
|
|
@ -772,9 +772,10 @@ class TestRunJobSessionPersistence:
|
|||
pass
|
||||
|
||||
def run_conversation(self, *args, **kwargs):
|
||||
seen["platform"] = os.getenv("HERMES_CRON_AUTO_DELIVER_PLATFORM")
|
||||
seen["chat_id"] = os.getenv("HERMES_CRON_AUTO_DELIVER_CHAT_ID")
|
||||
seen["thread_id"] = os.getenv("HERMES_CRON_AUTO_DELIVER_THREAD_ID")
|
||||
from gateway.session_context import get_session_env
|
||||
seen["platform"] = get_session_env("HERMES_CRON_AUTO_DELIVER_PLATFORM") or None
|
||||
seen["chat_id"] = get_session_env("HERMES_CRON_AUTO_DELIVER_CHAT_ID") or None
|
||||
seen["thread_id"] = get_session_env("HERMES_CRON_AUTO_DELIVER_THREAD_ID") or None
|
||||
return {"final_response": "ok"}
|
||||
|
||||
with patch("cron.scheduler._hermes_home", tmp_path), \
|
||||
|
|
@ -1024,7 +1025,7 @@ class TestRunJobSkillBacked:
|
|||
"id": "multi-skill-job",
|
||||
"name": "multi skill test",
|
||||
"prompt": "Combine the results.",
|
||||
"skills": ["blogwatcher", "find-nearby"],
|
||||
"skills": ["blogwatcher", "maps"],
|
||||
}
|
||||
|
||||
fake_db = MagicMock()
|
||||
|
|
@ -1057,12 +1058,12 @@ class TestRunJobSkillBacked:
|
|||
assert error is None
|
||||
assert final_response == "ok"
|
||||
assert skill_view_mock.call_count == 2
|
||||
assert [call.args[0] for call in skill_view_mock.call_args_list] == ["blogwatcher", "find-nearby"]
|
||||
assert [call.args[0] for call in skill_view_mock.call_args_list] == ["blogwatcher", "maps"]
|
||||
|
||||
prompt_arg = mock_agent.run_conversation.call_args.args[0]
|
||||
assert prompt_arg.index("blogwatcher") < prompt_arg.index("find-nearby")
|
||||
assert prompt_arg.index("blogwatcher") < prompt_arg.index("maps")
|
||||
assert "Instructions for blogwatcher." in prompt_arg
|
||||
assert "Instructions for find-nearby." in prompt_arg
|
||||
assert "Instructions for maps." in prompt_arg
|
||||
assert "Combine the results." in prompt_arg
|
||||
|
||||
|
||||
|
|
@ -1175,6 +1176,204 @@ class TestBuildJobPromptSilentHint:
|
|||
assert system_pos < prompt_pos
|
||||
|
||||
|
||||
class TestParseWakeGate:
    """Unit tests for _parse_wake_gate — a pure function with no side effects."""

    def test_empty_output_wakes(self):
        from cron.scheduler import _parse_wake_gate
        assert _parse_wake_gate("") is True
        assert _parse_wake_gate(None) is True

    def test_whitespace_only_wakes(self):
        from cron.scheduler import _parse_wake_gate
        assert _parse_wake_gate(" \n\n \t\n") is True

    def test_non_json_last_line_wakes(self):
        from cron.scheduler import _parse_wake_gate
        for output in ("hello world", "line 1\nline 2\nplain text"):
            assert _parse_wake_gate(output) is True

    def test_json_non_dict_wakes(self):
        """Bare arrays, numbers, strings must not be interpreted as a gate."""
        from cron.scheduler import _parse_wake_gate
        for payload in ("[1, 2, 3]", "42", '"wakeAgent"'):
            assert _parse_wake_gate(payload) is True

    def test_wake_gate_false_skips(self):
        from cron.scheduler import _parse_wake_gate
        assert _parse_wake_gate('{"wakeAgent": false}') is False

    def test_wake_gate_true_wakes(self):
        from cron.scheduler import _parse_wake_gate
        assert _parse_wake_gate('{"wakeAgent": true}') is True

    def test_wake_gate_missing_wakes(self):
        """A JSON dict without a wakeAgent key defaults to waking."""
        from cron.scheduler import _parse_wake_gate
        assert _parse_wake_gate('{"data": {"foo": "bar"}}') is True

    def test_non_boolean_false_still_wakes(self):
        """Only strict ``False`` skips — truthy/falsy shortcuts are too risky."""
        from cron.scheduler import _parse_wake_gate
        for payload in ('{"wakeAgent": 0}', '{"wakeAgent": null}', '{"wakeAgent": ""}'):
            assert _parse_wake_gate(payload) is True

    def test_only_last_non_empty_line_parsed(self):
        from cron.scheduler import _parse_wake_gate
        output = 'some log output\nmore output\n{"wakeAgent": false}'
        assert _parse_wake_gate(output) is False

    def test_trailing_blank_lines_ignored(self):
        from cron.scheduler import _parse_wake_gate
        output = '{"wakeAgent": false}\n\n\n'
        assert _parse_wake_gate(output) is False

    def test_non_last_json_line_does_not_gate(self):
        """A JSON gate on an earlier line with plain text after it does NOT trigger."""
        from cron.scheduler import _parse_wake_gate
        output = '{"wakeAgent": false}\nactually this is the real output'
        assert _parse_wake_gate(output) is True
|
||||
|
||||
|
||||
class TestRunJobWakeGate:
    """Integration tests for run_job wake-gate short-circuit."""

    @pytest.fixture(autouse=True)
    def _stub_runtime_provider(self):
        """Stub ``resolve_runtime_provider`` for wake-gate tests.

        ``run_job`` resolves the runtime provider BEFORE constructing
        ``AIAgent``, so these tests must mock ``resolve_runtime_provider``
        in addition to ``AIAgent`` — otherwise in a hermetic CI env (no
        API keys), the resolver raises and the test fails before the
        patched AIAgent is ever reached.
        """
        fake_runtime = {
            "provider": "openrouter",
            "api_mode": "chat_completions",
            "base_url": "https://openrouter.ai/api/v1",
            "api_key": "test-key",
            "source": "stub",
            "requested_provider": None,
        }
        with patch(
            "hermes_cli.runtime_provider.resolve_runtime_provider",
            return_value=fake_runtime,
        ):
            yield

    def _make_job(self, name="wake-gate-test", script="check.py"):
        """Minimal valid cron job dict for run_job."""
        return {
            "id": f"job_{name}",
            "name": name,
            "prompt": "Do a thing",
            "schedule": "*/5 * * * *",
            "script": script,
        }

    def test_wake_false_skips_agent_and_returns_silent(self, caplog):
        """When _run_job_script output ends with {wakeAgent: false}, the agent
        is not invoked and run_job returns the SILENT marker so delivery is
        suppressed."""
        from cron.scheduler import SILENT_MARKER
        import cron.scheduler as scheduler

        with patch.object(scheduler, "_run_job_script",
                          return_value=(True, '{"wakeAgent": false}')), \
             patch("run_agent.AIAgent") as agent_cls:
            success, doc, final, err = scheduler.run_job(self._make_job())

        assert success is True
        assert err is None
        assert final == SILENT_MARKER
        assert "Script gate returned `wakeAgent=false`" in doc
        # The gate short-circuited before agent construction.
        agent_cls.assert_not_called()

    def test_wake_true_runs_agent_with_injected_output(self):
        """When the script returns {wakeAgent: true, data: ...}, the agent is
        invoked and the data line still shows up in the prompt."""
        import cron.scheduler as scheduler

        script_output = '{"wakeAgent": true, "data": {"new": 3}}'
        agent = MagicMock()
        agent.run_conversation = MagicMock(return_value={
            "final_response": "ok", "messages": []
        })
        with patch.object(scheduler, "_run_job_script",
                          return_value=(True, script_output)), \
             patch("run_agent.AIAgent", return_value=agent) as agent_cls:
            success, doc, final, err = scheduler.run_job(self._make_job())

        agent_cls.assert_called_once()
        # The script output should be visible in the prompt passed to
        # run_conversation (positional or keyword, depending on call style).
        call_kwargs = agent.run_conversation.call_args
        prompt_arg = call_kwargs.args[0] if call_kwargs.args else call_kwargs.kwargs.get("user_message", "")
        assert script_output in prompt_arg
        assert success is True
        assert err is None

    def test_script_runs_only_once_on_wake(self):
        """Wake-true path must not re-run the script inside _build_job_prompt
        (script would execute twice otherwise, wasting work and risking
        double-side-effects)."""
        import cron.scheduler as scheduler

        call_count = 0
        def _script_stub(path):
            # Count every invocation so a double-execution is detectable.
            nonlocal call_count
            call_count += 1
            return (True, "regular output")

        agent = MagicMock()
        agent.run_conversation = MagicMock(return_value={
            "final_response": "ok", "messages": []
        })
        with patch.object(scheduler, "_run_job_script", side_effect=_script_stub), \
             patch("run_agent.AIAgent", return_value=agent):
            scheduler.run_job(self._make_job())

        assert call_count == 1, f"script ran {call_count}x, expected exactly 1"

    def test_script_failure_does_not_trigger_gate(self):
        """If _run_job_script returns success=False, the gate is NOT evaluated
        and the agent still runs (the failure is reported as context)."""
        import cron.scheduler as scheduler

        # Malicious or broken script whose stderr happens to contain the
        # gate JSON — we must NOT honor it because ran_ok is False.
        agent = MagicMock()
        agent.run_conversation = MagicMock(return_value={
            "final_response": "ok", "messages": []
        })
        with patch.object(scheduler, "_run_job_script",
                          return_value=(False, '{"wakeAgent": false}')), \
             patch("run_agent.AIAgent", return_value=agent) as agent_cls:
            success, doc, final, err = scheduler.run_job(self._make_job())

        agent_cls.assert_called_once()  # Agent DID wake despite the gate-like text

    def test_no_script_path_runs_agent_normally(self):
        """Regression: jobs without a script still work."""
        import cron.scheduler as scheduler

        agent = MagicMock()
        agent.run_conversation = MagicMock(return_value={
            "final_response": "ok", "messages": []
        })
        job = self._make_job(script=None)
        job.pop("script", None)  # ensure the key is truly absent, not just None
        with patch.object(scheduler, "_run_job_script") as script_fn, \
             patch("run_agent.AIAgent", return_value=agent) as agent_cls:
            scheduler.run_job(job)

        script_fn.assert_not_called()
        agent_cls.assert_called_once()
|
||||
|
||||
|
||||
class TestBuildJobPromptMissingSkill:
|
||||
"""Verify that a missing skill logs a warning and does not crash the job."""
|
||||
|
||||
|
|
@ -1259,3 +1458,250 @@ class TestSendMediaViaAdapter:
|
|||
self._run_with_loop(adapter, "123", media_files, None, {"id": "j3"})
|
||||
adapter.send_voice.assert_called_once()
|
||||
adapter.send_image_file.assert_called_once()
|
||||
|
||||
|
||||
class TestParallelTick:
|
||||
"""Verify that tick() runs due jobs concurrently and isolates ContextVars."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_tick_lock(self, tmp_path):
|
||||
"""Point the tick file lock at a per-test temp dir to avoid xdist contention."""
|
||||
lock_dir = tmp_path / "cron"
|
||||
lock_dir.mkdir()
|
||||
with patch("cron.scheduler._LOCK_DIR", lock_dir), \
|
||||
patch("cron.scheduler._LOCK_FILE", lock_dir / ".tick.lock"):
|
||||
yield
|
||||
|
||||
def test_parallel_jobs_run_concurrently(self):
|
||||
"""Two jobs launched in the same tick should overlap in time."""
|
||||
import threading
|
||||
import time
|
||||
|
||||
barrier = threading.Barrier(2, timeout=5)
|
||||
call_order = []
|
||||
|
||||
def mock_run_job(job):
|
||||
"""Each job hits a barrier — both must be active simultaneously."""
|
||||
call_order.append(("start", job["id"]))
|
||||
barrier.wait() # blocks until both threads reach here
|
||||
call_order.append(("end", job["id"]))
|
||||
return (True, "output", "response", None)
|
||||
|
||||
jobs = [
|
||||
{"id": "job-a", "name": "a", "deliver": "local"},
|
||||
{"id": "job-b", "name": "b", "deliver": "local"},
|
||||
]
|
||||
|
||||
with patch("cron.scheduler.get_due_jobs", return_value=jobs), \
|
||||
patch("cron.scheduler.advance_next_run"), \
|
||||
patch("cron.scheduler.run_job", side_effect=mock_run_job), \
|
||||
patch("cron.scheduler.save_job_output", return_value="/tmp/out.md"), \
|
||||
patch("cron.scheduler._deliver_result", return_value=None), \
|
||||
patch("cron.scheduler.mark_job_run"):
|
||||
from cron.scheduler import tick
|
||||
result = tick(verbose=False)
|
||||
|
||||
assert result == 2
|
||||
# Both starts happened before both ends — proof of concurrency
|
||||
starts = [i for i, (action, _) in enumerate(call_order) if action == "start"]
|
||||
ends = [i for i, (action, _) in enumerate(call_order) if action == "end"]
|
||||
assert len(starts) == 2
|
||||
assert len(ends) == 2
|
||||
assert max(starts) < min(ends), f"Jobs not concurrent: {call_order}"
|
||||
|
||||
def test_parallel_jobs_isolated_contextvars(self):
|
||||
"""Each job's ContextVars must be isolated — no cross-contamination."""
|
||||
from gateway.session_context import get_session_env
|
||||
seen = {}
|
||||
|
||||
def mock_run_job(job):
|
||||
origin = job.get("origin", {})
|
||||
# run_job sets ContextVars — verify each job sees its own
|
||||
from gateway.session_context import set_session_vars, clear_session_vars
|
||||
tokens = set_session_vars(
|
||||
platform=origin.get("platform", ""),
|
||||
chat_id=str(origin.get("chat_id", "")),
|
||||
)
|
||||
import time
|
||||
time.sleep(0.05) # give other thread time to set its vars
|
||||
platform = get_session_env("HERMES_SESSION_PLATFORM")
|
||||
chat_id = get_session_env("HERMES_SESSION_CHAT_ID")
|
||||
seen[job["id"]] = {"platform": platform, "chat_id": chat_id}
|
||||
clear_session_vars(tokens)
|
||||
return (True, "output", "response", None)
|
||||
|
||||
jobs = [
|
||||
{"id": "tg-job", "name": "tg", "deliver": "local",
|
||||
"origin": {"platform": "telegram", "chat_id": "111"}},
|
||||
{"id": "dc-job", "name": "dc", "deliver": "local",
|
||||
"origin": {"platform": "discord", "chat_id": "222"}},
|
||||
]
|
||||
|
||||
with patch("cron.scheduler.get_due_jobs", return_value=jobs), \
|
||||
patch("cron.scheduler.advance_next_run"), \
|
||||
patch("cron.scheduler.run_job", side_effect=mock_run_job), \
|
||||
patch("cron.scheduler.save_job_output", return_value="/tmp/out.md"), \
|
||||
patch("cron.scheduler._deliver_result", return_value=None), \
|
||||
patch("cron.scheduler.mark_job_run"):
|
||||
from cron.scheduler import tick
|
||||
tick(verbose=False)
|
||||
|
||||
assert seen["tg-job"] == {"platform": "telegram", "chat_id": "111"}
|
||||
assert seen["dc-job"] == {"platform": "discord", "chat_id": "222"}
|
||||
|
||||
def test_max_parallel_env_var(self, monkeypatch):
|
||||
"""HERMES_CRON_MAX_PARALLEL=1 should restore serial behaviour."""
|
||||
monkeypatch.setenv("HERMES_CRON_MAX_PARALLEL", "1")
|
||||
call_times = []
|
||||
|
||||
def mock_run_job(job):
|
||||
import time
|
||||
call_times.append(("start", job["id"], time.monotonic()))
|
||||
time.sleep(0.05)
|
||||
call_times.append(("end", job["id"], time.monotonic()))
|
||||
return (True, "output", "response", None)
|
||||
|
||||
jobs = [
|
||||
{"id": "s1", "name": "s1", "deliver": "local"},
|
||||
{"id": "s2", "name": "s2", "deliver": "local"},
|
||||
]
|
||||
|
||||
with patch("cron.scheduler.get_due_jobs", return_value=jobs), \
|
||||
patch("cron.scheduler.advance_next_run"), \
|
||||
patch("cron.scheduler.run_job", side_effect=mock_run_job), \
|
||||
patch("cron.scheduler.save_job_output", return_value="/tmp/out.md"), \
|
||||
patch("cron.scheduler._deliver_result", return_value=None), \
|
||||
patch("cron.scheduler.mark_job_run"):
|
||||
from cron.scheduler import tick
|
||||
result = tick(verbose=False)
|
||||
|
||||
assert result == 2
|
||||
# With max_workers=1, second job starts after first ends
|
||||
end_s1 = [t for action, jid, t in call_times if action == "end" and jid == "s1"][0]
|
||||
start_s2 = [t for action, jid, t in call_times if action == "start" and jid == "s2"][0]
|
||||
assert start_s2 >= end_s1, "Jobs ran concurrently despite max_parallel=1"
|
||||
|
||||
|
||||
class TestDeliverResultTimeoutCancelsFuture:
|
||||
"""When future.result(timeout=60) raises TimeoutError in the live
|
||||
adapter delivery path, _deliver_result must cancel the orphan
|
||||
coroutine so it cannot duplicate-send after the standalone fallback.
|
||||
"""
|
||||
|
||||
def test_live_adapter_timeout_cancels_future_and_falls_back(self):
|
||||
"""End-to-end: live adapter hangs past the 60s budget, _deliver_result
|
||||
patches the timeout down to a fast value, confirms future.cancel() fires,
|
||||
and verifies the standalone fallback path still delivers."""
|
||||
from gateway.config import Platform
|
||||
from concurrent.futures import Future
|
||||
|
||||
# Live adapter whose send() coroutine never resolves within the budget
|
||||
adapter = AsyncMock()
|
||||
adapter.send.return_value = MagicMock(success=True)
|
||||
|
||||
pconfig = MagicMock()
|
||||
pconfig.enabled = True
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.platforms = {Platform.TELEGRAM: pconfig}
|
||||
|
||||
loop = MagicMock()
|
||||
loop.is_running.return_value = True
|
||||
|
||||
# A real concurrent.futures.Future so .cancel() has real semantics,
|
||||
# but we override .result() to raise TimeoutError exactly like the
|
||||
# 60s wait firing in production.
|
||||
captured_future = Future()
|
||||
cancel_calls = []
|
||||
original_cancel = captured_future.cancel
|
||||
|
||||
def tracking_cancel():
|
||||
cancel_calls.append(True)
|
||||
return original_cancel()
|
||||
|
||||
captured_future.cancel = tracking_cancel
|
||||
captured_future.result = MagicMock(side_effect=TimeoutError("timed out"))
|
||||
|
||||
def fake_run_coro(coro, _loop):
|
||||
coro.close()
|
||||
return captured_future
|
||||
|
||||
job = {
|
||||
"id": "timeout-job",
|
||||
"deliver": "origin",
|
||||
"origin": {"platform": "telegram", "chat_id": "123"},
|
||||
}
|
||||
|
||||
standalone_send = AsyncMock(return_value={"success": True})
|
||||
|
||||
with patch("gateway.config.load_gateway_config", return_value=mock_cfg), \
|
||||
patch("cron.scheduler.load_config", return_value={"cron": {"wrap_response": False}}), \
|
||||
patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro), \
|
||||
patch("tools.send_message_tool._send_to_platform", new=standalone_send):
|
||||
result = _deliver_result(
|
||||
job,
|
||||
"Hello world",
|
||||
adapters={Platform.TELEGRAM: adapter},
|
||||
loop=loop,
|
||||
)
|
||||
|
||||
# 1. The orphan future was cancelled on timeout (the bug fix)
|
||||
assert cancel_calls == [True], "future.cancel() must fire on TimeoutError"
|
||||
# 2. The standalone fallback delivered — no double send, no silent drop
|
||||
assert result is None, f"expected successful delivery, got error: {result!r}"
|
||||
standalone_send.assert_awaited_once()
|
||||
|
||||
|
||||
class TestSendMediaTimeoutCancelsFuture:
|
||||
"""Same orphan-coroutine guarantee for _send_media_via_adapter's
|
||||
future.result(timeout=30) call. If this times out mid-batch, the
|
||||
in-flight coroutine must be cancelled before the next file is tried.
|
||||
"""
|
||||
|
||||
def test_media_send_timeout_cancels_future_and_continues(self):
|
||||
"""End-to-end: _send_media_via_adapter with a future whose .result()
|
||||
raises TimeoutError. Assert cancel() fires and the loop proceeds
|
||||
to the next file rather than hanging or crashing."""
|
||||
from concurrent.futures import Future
|
||||
|
||||
adapter = MagicMock()
|
||||
adapter.send_image_file = AsyncMock()
|
||||
adapter.send_video = AsyncMock()
|
||||
|
||||
# First file: future that times out. Second file: future that resolves OK.
|
||||
timeout_future = Future()
|
||||
timeout_cancel_calls = []
|
||||
original_cancel = timeout_future.cancel
|
||||
|
||||
def tracking_cancel():
|
||||
timeout_cancel_calls.append(True)
|
||||
return original_cancel()
|
||||
|
||||
timeout_future.cancel = tracking_cancel
|
||||
timeout_future.result = MagicMock(side_effect=TimeoutError("timed out"))
|
||||
|
||||
ok_future = Future()
|
||||
ok_future.set_result(MagicMock(success=True))
|
||||
|
||||
futures_iter = iter([timeout_future, ok_future])
|
||||
|
||||
def fake_run_coro(coro, _loop):
|
||||
coro.close()
|
||||
return next(futures_iter)
|
||||
|
||||
media_files = [
|
||||
("/tmp/slow.png", False), # times out
|
||||
("/tmp/fast.mp4", False), # succeeds
|
||||
]
|
||||
|
||||
loop = MagicMock()
|
||||
job = {"id": "media-timeout"}
|
||||
|
||||
with patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
|
||||
# Should not raise — the except Exception clause swallows the timeout
|
||||
_send_media_via_adapter(adapter, "chat-1", media_files, None, loop, job)
|
||||
|
||||
# 1. The timed-out future was cancelled (the bug fix)
|
||||
assert timeout_cancel_calls == [True], "future.cancel() must fire on TimeoutError"
|
||||
# 2. Second file still got dispatched — one timeout doesn't abort the batch
|
||||
adapter.send_video.assert_called_once()
|
||||
assert adapter.send_video.call_args[1]["video_path"] == "/tmp/fast.mp4"
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ No LLM, no real platform connections.
|
|||
import asyncio
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
|
@ -22,6 +22,7 @@ from gateway.config import GatewayConfig, Platform, PlatformConfig
|
|||
from gateway.platforms.base import MessageEvent, SendResult
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
E2E_MESSAGE_SETTLE_DELAY = 0.3
|
||||
|
||||
# Platform library mocks
|
||||
|
||||
|
|
@ -113,8 +114,9 @@ _ensure_telegram_mock()
|
|||
_ensure_discord_mock()
|
||||
_ensure_slack_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
import discord # noqa: E402 — mocked above
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
import gateway.platforms.slack as _slack_mod # noqa: E402
|
||||
_slack_mod.SLACK_AVAILABLE = True
|
||||
|
|
@ -264,3 +266,140 @@ def runner(platform, session_entry):
|
|||
@pytest.fixture()
|
||||
def adapter(platform, runner):
|
||||
return make_adapter(platform, runner)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Discord helpers and fixtures
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
BOT_USER_ID = 99999
|
||||
BOT_USER_NAME = "HermesBot"
|
||||
CHANNEL_ID = 22222
|
||||
GUILD_ID = 44444
|
||||
THREAD_ID = 33333
|
||||
MESSAGE_ID_COUNTER = 0
|
||||
|
||||
|
||||
def _next_message_id() -> int:
|
||||
global MESSAGE_ID_COUNTER
|
||||
MESSAGE_ID_COUNTER += 1
|
||||
return 70000 + MESSAGE_ID_COUNTER
|
||||
|
||||
|
||||
def make_fake_bot_user():
|
||||
return SimpleNamespace(
|
||||
id=BOT_USER_ID, name=BOT_USER_NAME,
|
||||
display_name=BOT_USER_NAME, bot=True,
|
||||
)
|
||||
|
||||
|
||||
def make_fake_guild(guild_id: int = GUILD_ID, name: str = "Test Server"):
|
||||
return SimpleNamespace(id=guild_id, name=name)
|
||||
|
||||
|
||||
def make_fake_text_channel(channel_id: int = CHANNEL_ID, name: str = "general", guild=None):
|
||||
return SimpleNamespace(
|
||||
id=channel_id, name=name,
|
||||
guild=guild or make_fake_guild(),
|
||||
topic=None, type=0,
|
||||
)
|
||||
|
||||
|
||||
def make_fake_dm_channel(channel_id: int = 55555):
|
||||
ch = MagicMock(spec=[])
|
||||
ch.id = channel_id
|
||||
ch.name = "DM"
|
||||
ch.topic = None
|
||||
ch.__class__ = discord.DMChannel
|
||||
return ch
|
||||
|
||||
|
||||
def make_fake_thread(thread_id: int = THREAD_ID, name: str = "test-thread", parent=None):
|
||||
th = MagicMock(spec=[])
|
||||
th.id = thread_id
|
||||
th.name = name
|
||||
th.parent = parent or make_fake_text_channel()
|
||||
th.parent_id = th.parent.id
|
||||
th.guild = th.parent.guild
|
||||
th.topic = None
|
||||
th.type = 11
|
||||
th.__class__ = discord.Thread
|
||||
return th
|
||||
|
||||
|
||||
def make_discord_message(
|
||||
*, content: str = "hello", author=None, channel=None, mentions=None,
|
||||
attachments=None, message_id: int = None,
|
||||
):
|
||||
if message_id is None:
|
||||
message_id = _next_message_id()
|
||||
if author is None:
|
||||
author = SimpleNamespace(
|
||||
id=11111, name="testuser", display_name="testuser", bot=False,
|
||||
)
|
||||
if channel is None:
|
||||
channel = make_fake_text_channel()
|
||||
if mentions is None:
|
||||
mentions = []
|
||||
if attachments is None:
|
||||
attachments = []
|
||||
|
||||
return SimpleNamespace(
|
||||
id=message_id, content=content, author=author, channel=channel,
|
||||
mentions=mentions, attachments=attachments,
|
||||
type=getattr(discord, "MessageType", SimpleNamespace()).default,
|
||||
reference=None, created_at=datetime.now(timezone.utc),
|
||||
create_thread=AsyncMock(),
|
||||
)
|
||||
|
||||
|
||||
def get_response_text(adapter) -> str | None:
|
||||
"""Extract the response text from adapter.send() call args, or None if not called."""
|
||||
if not adapter.send.called:
|
||||
return None
|
||||
return adapter.send.call_args[1].get("content") or adapter.send.call_args[0][1]
|
||||
|
||||
|
||||
def _make_discord_adapter_wired(runner=None):
|
||||
"""Create a DiscordAdapter wired to a GatewayRunner for e2e tests."""
|
||||
if runner is None:
|
||||
runner = make_runner(Platform.DISCORD)
|
||||
|
||||
config = PlatformConfig(enabled=True, token="e2e-test-token")
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
with patch.object(ThreadParticipationTracker, "_load", return_value=set()):
|
||||
adapter = DiscordAdapter(config)
|
||||
|
||||
bot_user = make_fake_bot_user()
|
||||
adapter._client = SimpleNamespace(
|
||||
user=bot_user,
|
||||
get_channel=lambda _id: None,
|
||||
fetch_channel=AsyncMock(),
|
||||
)
|
||||
|
||||
adapter.send = AsyncMock(return_value=SendResult(success=True, message_id="e2e-resp-1"))
|
||||
adapter.send_typing = AsyncMock()
|
||||
adapter.set_message_handler(runner._handle_message)
|
||||
runner.adapters[Platform.DISCORD] = adapter
|
||||
|
||||
return adapter, runner
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def discord_setup():
|
||||
return _make_discord_adapter_wired()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def discord_adapter(discord_setup):
|
||||
return discord_setup[0]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def discord_runner(discord_setup):
|
||||
return discord_setup[1]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def bot_user():
|
||||
return make_fake_bot_user()
|
||||
|
|
|
|||
106
tests/e2e/test_discord_adapter.py
Normal file
106
tests/e2e/test_discord_adapter.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
"""Minimal e2e tests for Discord mention stripping + /command detection.
|
||||
|
||||
Covers the fix for slash commands not being recognized when sent via
|
||||
@mention in a channel, especially after auto-threading.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.e2e.conftest import (
|
||||
BOT_USER_ID,
|
||||
E2E_MESSAGE_SETTLE_DELAY,
|
||||
get_response_text,
|
||||
make_discord_message,
|
||||
make_fake_dm_channel,
|
||||
make_fake_thread,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def dispatch(adapter, msg):
|
||||
await adapter._handle_message(msg)
|
||||
await asyncio.sleep(E2E_MESSAGE_SETTLE_DELAY)
|
||||
|
||||
|
||||
class TestMentionStrippedCommandDispatch:
|
||||
async def test_mention_then_command(self, discord_adapter, bot_user):
|
||||
"""<@BOT> /help → mention stripped, /help dispatched."""
|
||||
msg = make_discord_message(
|
||||
content=f"<@{BOT_USER_ID}> /help",
|
||||
mentions=[bot_user],
|
||||
)
|
||||
await dispatch(discord_adapter, msg)
|
||||
response = get_response_text(discord_adapter)
|
||||
assert response is not None
|
||||
assert "/new" in response
|
||||
|
||||
async def test_nickname_mention_then_command(self, discord_adapter, bot_user):
|
||||
"""<@!BOT> /help → nickname mention also stripped, /help works."""
|
||||
msg = make_discord_message(
|
||||
content=f"<@!{BOT_USER_ID}> /help",
|
||||
mentions=[bot_user],
|
||||
)
|
||||
await dispatch(discord_adapter, msg)
|
||||
response = get_response_text(discord_adapter)
|
||||
assert response is not None
|
||||
assert "/new" in response
|
||||
|
||||
async def test_text_before_command_not_detected(self, discord_adapter, bot_user):
|
||||
"""'<@BOT> something else /help' → mention stripped, but 'something else /help'
|
||||
doesn't start with / so it's treated as text, not a command."""
|
||||
msg = make_discord_message(
|
||||
content=f"<@{BOT_USER_ID}> something else /help",
|
||||
mentions=[bot_user],
|
||||
)
|
||||
await dispatch(discord_adapter, msg)
|
||||
# Message is accepted (not dropped by mention gate), but since it doesn't
|
||||
# start with / it's routed as text — no command output, and no agent in this
|
||||
# mock setup means no send call either.
|
||||
response = get_response_text(discord_adapter)
|
||||
assert response is None or "/new" not in response
|
||||
|
||||
async def test_no_mention_in_channel_dropped(self, discord_adapter):
|
||||
"""Message without @mention in server channel → silently dropped."""
|
||||
msg = make_discord_message(content="/help", mentions=[])
|
||||
await dispatch(discord_adapter, msg)
|
||||
assert get_response_text(discord_adapter) is None
|
||||
|
||||
async def test_dm_no_mention_needed(self, discord_adapter):
|
||||
"""DMs don't require @mention — /help works directly."""
|
||||
dm = make_fake_dm_channel()
|
||||
msg = make_discord_message(content="/help", channel=dm, mentions=[])
|
||||
await dispatch(discord_adapter, msg)
|
||||
response = get_response_text(discord_adapter)
|
||||
assert response is not None
|
||||
assert "/new" in response
|
||||
|
||||
|
||||
class TestAutoThreadingPreservesCommand:
|
||||
async def test_command_detected_after_auto_thread(self, discord_adapter, bot_user, monkeypatch):
|
||||
"""@mention /help in channel with auto-thread → thread created AND command dispatched."""
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "true")
|
||||
fake_thread = make_fake_thread(thread_id=90001, name="help")
|
||||
msg = make_discord_message(
|
||||
content=f"<@{BOT_USER_ID}> /help",
|
||||
mentions=[bot_user],
|
||||
)
|
||||
|
||||
# Simulate discord.py restoring the original raw content (with mention)
|
||||
# after create_thread(), which undoes any prior mention stripping.
|
||||
original_content = msg.content
|
||||
|
||||
async def clobber_content(**kwargs):
|
||||
msg.content = original_content
|
||||
return fake_thread
|
||||
|
||||
msg.create_thread = AsyncMock(side_effect=clobber_content)
|
||||
await dispatch(discord_adapter, msg)
|
||||
|
||||
msg.create_thread.assert_awaited_once()
|
||||
response = get_response_text(discord_adapter)
|
||||
assert response is not None
|
||||
assert "/new" in response
|
||||
|
|
@ -108,6 +108,7 @@ def make_restart_runner(
|
|||
runner.hooks.emit = AsyncMock()
|
||||
runner.pairing_store = MagicMock()
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store._entries = {}
|
||||
runner.delivery_router = MagicMock()
|
||||
|
||||
platform_adapter = adapter or RestartTestAdapter()
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ Tests cover:
|
|||
- Error handling (invalid JSON, missing fields)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
|
|
@ -25,6 +26,7 @@ from gateway.config import GatewayConfig, Platform, PlatformConfig
|
|||
from gateway.platforms.api_server import (
|
||||
APIServerAdapter,
|
||||
ResponseStore,
|
||||
_IdempotencyCache,
|
||||
_CORS_HEADERS,
|
||||
_derive_chat_session_id,
|
||||
check_api_server_requirements,
|
||||
|
|
@ -104,6 +106,95 @@ class TestResponseStore:
|
|||
assert store.delete("resp_missing") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _IdempotencyCache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIdempotencyCache:
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_same_key_and_fingerprint_runs_once(self):
|
||||
cache = _IdempotencyCache()
|
||||
gate = asyncio.Event()
|
||||
started = asyncio.Event()
|
||||
calls = 0
|
||||
|
||||
async def compute():
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
started.set()
|
||||
await gate.wait()
|
||||
return ("response", {"total_tokens": 1})
|
||||
|
||||
first = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
|
||||
second = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
|
||||
|
||||
await started.wait()
|
||||
assert calls == 1
|
||||
|
||||
gate.set()
|
||||
first_result, second_result = await asyncio.gather(first, second)
|
||||
|
||||
assert first_result == second_result == ("response", {"total_tokens": 1})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_fingerprint_does_not_reuse_inflight_task(self):
|
||||
cache = _IdempotencyCache()
|
||||
gate = asyncio.Event()
|
||||
started = asyncio.Event()
|
||||
calls = 0
|
||||
|
||||
async def compute():
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
result = calls
|
||||
if calls == 2:
|
||||
started.set()
|
||||
await gate.wait()
|
||||
return result
|
||||
|
||||
first = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
|
||||
second = asyncio.create_task(cache.get_or_set("idem-key", "fp-2", compute))
|
||||
|
||||
await started.wait()
|
||||
assert calls == 2
|
||||
|
||||
gate.set()
|
||||
results = await asyncio.gather(first, second)
|
||||
|
||||
assert sorted(results) == [1, 2]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancelled_waiter_does_not_drop_shared_inflight_task(self):
|
||||
cache = _IdempotencyCache()
|
||||
gate = asyncio.Event()
|
||||
started = asyncio.Event()
|
||||
calls = 0
|
||||
|
||||
async def compute():
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
started.set()
|
||||
await gate.wait()
|
||||
return "response"
|
||||
|
||||
first = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
|
||||
|
||||
await started.wait()
|
||||
assert calls == 1
|
||||
|
||||
first.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await first
|
||||
|
||||
second = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
|
||||
await asyncio.sleep(0)
|
||||
assert calls == 1
|
||||
|
||||
gate.set()
|
||||
assert await second == "response"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter initialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ from aiohttp.test_utils import TestClient, TestServer
|
|||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.api_server import APIServerAdapter, cors_middleware
|
||||
|
||||
_MOD = "gateway.platforms.api_server"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
|
|
@ -83,10 +85,10 @@ class TestListJobs:
|
|||
"""GET /api/jobs returns job list."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_list", return_value=[SAMPLE_JOB]
|
||||
with patch(
|
||||
f"{_MOD}._CRON_AVAILABLE", True
|
||||
), patch(
|
||||
f"{_MOD}._cron_list", return_value=[SAMPLE_JOB]
|
||||
):
|
||||
resp = await cli.get("/api/jobs")
|
||||
assert resp.status == 200
|
||||
|
|
@ -104,10 +106,10 @@ class TestListJobs:
|
|||
app = _create_app(adapter)
|
||||
mock_list = MagicMock(return_value=[SAMPLE_JOB])
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_list", mock_list
|
||||
with patch(
|
||||
f"{_MOD}._CRON_AVAILABLE", True
|
||||
), patch(
|
||||
f"{_MOD}._cron_list", mock_list
|
||||
):
|
||||
resp = await cli.get("/api/jobs?include_disabled=true")
|
||||
assert resp.status == 200
|
||||
|
|
@ -119,10 +121,10 @@ class TestListJobs:
|
|||
app = _create_app(adapter)
|
||||
mock_list = MagicMock(return_value=[])
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_list", mock_list
|
||||
with patch(
|
||||
f"{_MOD}._CRON_AVAILABLE", True
|
||||
), patch(
|
||||
f"{_MOD}._cron_list", mock_list
|
||||
):
|
||||
resp = await cli.get("/api/jobs")
|
||||
assert resp.status == 200
|
||||
|
|
@ -140,10 +142,10 @@ class TestCreateJob:
|
|||
app = _create_app(adapter)
|
||||
mock_create = MagicMock(return_value=SAMPLE_JOB)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_create", mock_create
|
||||
with patch(
|
||||
f"{_MOD}._CRON_AVAILABLE", True
|
||||
), patch(
|
||||
f"{_MOD}._cron_create", mock_create
|
||||
):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test-job",
|
||||
|
|
@ -164,7 +166,7 @@ class TestCreateJob:
|
|||
"""POST /api/jobs without name returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"schedule": "*/5 * * * *",
|
||||
"prompt": "do something",
|
||||
|
|
@ -178,7 +180,7 @@ class TestCreateJob:
|
|||
"""POST /api/jobs with name > 200 chars returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "x" * 201,
|
||||
"schedule": "*/5 * * * *",
|
||||
|
|
@ -192,7 +194,7 @@ class TestCreateJob:
|
|||
"""POST /api/jobs with prompt > 5000 chars returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test-job",
|
||||
"schedule": "*/5 * * * *",
|
||||
|
|
@ -207,7 +209,7 @@ class TestCreateJob:
|
|||
"""POST /api/jobs with repeat=0 returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test-job",
|
||||
"schedule": "*/5 * * * *",
|
||||
|
|
@ -222,7 +224,7 @@ class TestCreateJob:
|
|||
"""POST /api/jobs without schedule returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test-job",
|
||||
})
|
||||
|
|
@ -242,10 +244,10 @@ class TestGetJob:
|
|||
app = _create_app(adapter)
|
||||
mock_get = MagicMock(return_value=SAMPLE_JOB)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_get", mock_get
|
||||
with patch(
|
||||
f"{_MOD}._CRON_AVAILABLE", True
|
||||
), patch(
|
||||
f"{_MOD}._cron_get", mock_get
|
||||
):
|
||||
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 200
|
||||
|
|
@ -259,10 +261,10 @@ class TestGetJob:
|
|||
app = _create_app(adapter)
|
||||
mock_get = MagicMock(return_value=None)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_get", mock_get
|
||||
with patch(
|
||||
f"{_MOD}._CRON_AVAILABLE", True
|
||||
), patch(
|
||||
f"{_MOD}._cron_get", mock_get
|
||||
):
|
||||
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 404
|
||||
|
|
@ -272,7 +274,7 @@ class TestGetJob:
|
|||
"""GET /api/jobs/{id} with non-hex id returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", True):
|
||||
resp = await cli.get("/api/jobs/not-a-valid-hex!")
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
|
|
@ -291,10 +293,10 @@ class TestUpdateJob:
|
|||
updated_job = {**SAMPLE_JOB, "name": "updated-name"}
|
||||
mock_update = MagicMock(return_value=updated_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_update", mock_update
|
||||
with patch(
|
||||
f"{_MOD}._CRON_AVAILABLE", True
|
||||
), patch(
|
||||
f"{_MOD}._cron_update", mock_update
|
||||
):
|
||||
resp = await cli.patch(
|
||||
f"/api/jobs/{VALID_JOB_ID}",
|
||||
|
|
@ -317,10 +319,10 @@ class TestUpdateJob:
|
|||
updated_job = {**SAMPLE_JOB, "name": "new-name"}
|
||||
mock_update = MagicMock(return_value=updated_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_update", mock_update
|
||||
with patch(
|
||||
f"{_MOD}._CRON_AVAILABLE", True
|
||||
), patch(
|
||||
f"{_MOD}._cron_update", mock_update
|
||||
):
|
||||
resp = await cli.patch(
|
||||
f"/api/jobs/{VALID_JOB_ID}",
|
||||
|
|
@ -342,7 +344,7 @@ class TestUpdateJob:
|
|||
"""PATCH /api/jobs/{id} with only unknown fields returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", True):
|
||||
resp = await cli.patch(
|
||||
f"/api/jobs/{VALID_JOB_ID}",
|
||||
json={"evil_field": "malicious"},
|
||||
|
|
@ -363,10 +365,10 @@ class TestDeleteJob:
|
|||
app = _create_app(adapter)
|
||||
mock_remove = MagicMock(return_value=True)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_remove", mock_remove
|
||||
with patch(
|
||||
f"{_MOD}._CRON_AVAILABLE", True
|
||||
), patch(
|
||||
f"{_MOD}._cron_remove", mock_remove
|
||||
):
|
||||
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 200
|
||||
|
|
@ -380,10 +382,10 @@ class TestDeleteJob:
|
|||
app = _create_app(adapter)
|
||||
mock_remove = MagicMock(return_value=False)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_remove", mock_remove
|
||||
with patch(
|
||||
f"{_MOD}._CRON_AVAILABLE", True
|
||||
), patch(
|
||||
f"{_MOD}._cron_remove", mock_remove
|
||||
):
|
||||
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 404
|
||||
|
|
@ -401,10 +403,10 @@ class TestPauseJob:
|
|||
paused_job = {**SAMPLE_JOB, "enabled": False}
|
||||
mock_pause = MagicMock(return_value=paused_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_pause", mock_pause
|
||||
with patch(
|
||||
f"{_MOD}._CRON_AVAILABLE", True
|
||||
), patch(
|
||||
f"{_MOD}._cron_pause", mock_pause
|
||||
):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/pause")
|
||||
assert resp.status == 200
|
||||
|
|
@ -426,10 +428,10 @@ class TestResumeJob:
|
|||
resumed_job = {**SAMPLE_JOB, "enabled": True}
|
||||
mock_resume = MagicMock(return_value=resumed_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_resume", mock_resume
|
||||
with patch(
|
||||
f"{_MOD}._CRON_AVAILABLE", True
|
||||
), patch(
|
||||
f"{_MOD}._cron_resume", mock_resume
|
||||
):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/resume")
|
||||
assert resp.status == 200
|
||||
|
|
@ -451,10 +453,10 @@ class TestRunJob:
|
|||
triggered_job = {**SAMPLE_JOB, "last_run": "2025-01-01T00:00:00Z"}
|
||||
mock_trigger = MagicMock(return_value=triggered_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_trigger", mock_trigger
|
||||
with patch(
|
||||
f"{_MOD}._CRON_AVAILABLE", True
|
||||
), patch(
|
||||
f"{_MOD}._cron_trigger", mock_trigger
|
||||
):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/run")
|
||||
assert resp.status == 200
|
||||
|
|
@ -473,7 +475,7 @@ class TestAuthRequired:
|
|||
"""GET /api/jobs without API key returns 401 when key is set."""
|
||||
app = _create_app(auth_adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", True):
|
||||
resp = await cli.get("/api/jobs")
|
||||
assert resp.status == 401
|
||||
|
||||
|
|
@ -482,7 +484,7 @@ class TestAuthRequired:
|
|||
"""POST /api/jobs without API key returns 401 when key is set."""
|
||||
app = _create_app(auth_adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test", "schedule": "* * * * *",
|
||||
})
|
||||
|
|
@ -493,7 +495,7 @@ class TestAuthRequired:
|
|||
"""GET /api/jobs/{id} without API key returns 401 when key is set."""
|
||||
app = _create_app(auth_adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", True):
|
||||
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 401
|
||||
|
||||
|
|
@ -502,7 +504,7 @@ class TestAuthRequired:
|
|||
"""DELETE /api/jobs/{id} without API key returns 401."""
|
||||
app = _create_app(auth_adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", True):
|
||||
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 401
|
||||
|
||||
|
|
@ -512,10 +514,10 @@ class TestAuthRequired:
|
|||
app = _create_app(auth_adapter)
|
||||
mock_list = MagicMock(return_value=[])
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_list", mock_list
|
||||
with patch(
|
||||
f"{_MOD}._CRON_AVAILABLE", True
|
||||
), patch(
|
||||
f"{_MOD}._cron_list", mock_list
|
||||
):
|
||||
resp = await cli.get(
|
||||
"/api/jobs",
|
||||
|
|
@ -534,7 +536,7 @@ class TestCronUnavailable:
|
|||
"""GET /api/jobs returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", False):
|
||||
resp = await cli.get("/api/jobs")
|
||||
assert resp.status == 501
|
||||
data = await resp.json()
|
||||
|
|
@ -551,8 +553,8 @@ class TestCronUnavailable:
|
|||
return SAMPLE_JOB
|
||||
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True), patch.object(
|
||||
APIServerAdapter, "_cron_pause", staticmethod(_plain_pause)
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", True), patch(
|
||||
f"{_MOD}._cron_pause", _plain_pause
|
||||
):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/pause")
|
||||
assert resp.status == 200
|
||||
|
|
@ -571,8 +573,8 @@ class TestCronUnavailable:
|
|||
return [SAMPLE_JOB]
|
||||
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True), patch.object(
|
||||
APIServerAdapter, "_cron_list", staticmethod(_plain_list)
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", True), patch(
|
||||
f"{_MOD}._cron_list", _plain_list
|
||||
):
|
||||
resp = await cli.get("/api/jobs?include_disabled=true")
|
||||
assert resp.status == 200
|
||||
|
|
@ -593,8 +595,8 @@ class TestCronUnavailable:
|
|||
return updated_job
|
||||
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True), patch.object(
|
||||
APIServerAdapter, "_cron_update", staticmethod(_plain_update)
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", True), patch(
|
||||
f"{_MOD}._cron_update", _plain_update
|
||||
):
|
||||
resp = await cli.patch(
|
||||
f"/api/jobs/{VALID_JOB_ID}",
|
||||
|
|
@ -611,7 +613,7 @@ class TestCronUnavailable:
|
|||
"""POST /api/jobs returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", False):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test", "schedule": "* * * * *",
|
||||
})
|
||||
|
|
@ -622,7 +624,7 @@ class TestCronUnavailable:
|
|||
"""GET /api/jobs/{id} returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", False):
|
||||
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 501
|
||||
|
||||
|
|
@ -631,7 +633,7 @@ class TestCronUnavailable:
|
|||
"""DELETE /api/jobs/{id} returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", False):
|
||||
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 501
|
||||
|
||||
|
|
@ -640,7 +642,7 @@ class TestCronUnavailable:
|
|||
"""POST /api/jobs/{id}/pause returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", False):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/pause")
|
||||
assert resp.status == 501
|
||||
|
||||
|
|
@ -649,7 +651,7 @@ class TestCronUnavailable:
|
|||
"""POST /api/jobs/{id}/resume returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", False):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/resume")
|
||||
assert resp.status == 501
|
||||
|
||||
|
|
@ -658,6 +660,6 @@ class TestCronUnavailable:
|
|||
"""POST /api/jobs/{id}/run returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
with patch(f"{_MOD}._CRON_AVAILABLE", False):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/run")
|
||||
assert resp.status == 501
|
||||
|
|
|
|||
308
tests/gateway/test_api_server_multimodal.py
Normal file
308
tests/gateway/test_api_server_multimodal.py
Normal file
|
|
@ -0,0 +1,308 @@
|
|||
"""End-to-end tests for inline image inputs on /v1/chat/completions and /v1/responses.
|
||||
|
||||
Covers the multimodal normalization path added to the API server. Unlike the
|
||||
adapter-level tests that patch ``_run_agent``, these tests patch
|
||||
``AIAgent.run_conversation`` instead so the adapter's full request-handling
|
||||
path (including the ``run_agent`` prologue that used to crash on list content)
|
||||
executes against a real aiohttp app.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.api_server import (
|
||||
APIServerAdapter,
|
||||
_content_has_visible_payload,
|
||||
_normalize_multimodal_content,
|
||||
cors_middleware,
|
||||
security_headers_middleware,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure-function tests for _normalize_multimodal_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNormalizeMultimodalContent:
|
||||
def test_string_passthrough(self):
|
||||
assert _normalize_multimodal_content("hello") == "hello"
|
||||
|
||||
def test_none_returns_empty_string(self):
|
||||
assert _normalize_multimodal_content(None) == ""
|
||||
|
||||
def test_text_only_list_collapses_to_string(self):
|
||||
content = [{"type": "text", "text": "hi"}, {"type": "text", "text": "there"}]
|
||||
assert _normalize_multimodal_content(content) == "hi\nthere"
|
||||
|
||||
def test_responses_input_text_canonicalized(self):
|
||||
content = [{"type": "input_text", "text": "hello"}]
|
||||
assert _normalize_multimodal_content(content) == "hello"
|
||||
|
||||
def test_image_url_preserved_with_text(self):
|
||||
content = [
|
||||
{"type": "text", "text": "describe this"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/cat.png", "detail": "high"}},
|
||||
]
|
||||
out = _normalize_multimodal_content(content)
|
||||
assert isinstance(out, list)
|
||||
assert out == [
|
||||
{"type": "text", "text": "describe this"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/cat.png", "detail": "high"}},
|
||||
]
|
||||
|
||||
def test_input_image_converted_to_canonical_shape(self):
|
||||
content = [
|
||||
{"type": "input_text", "text": "hi"},
|
||||
{"type": "input_image", "image_url": "https://example.com/cat.png"},
|
||||
]
|
||||
out = _normalize_multimodal_content(content)
|
||||
assert out == [
|
||||
{"type": "text", "text": "hi"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
|
||||
]
|
||||
|
||||
def test_data_image_url_accepted(self):
|
||||
content = [{"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}}]
|
||||
out = _normalize_multimodal_content(content)
|
||||
assert out == [{"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}}]
|
||||
|
||||
def test_non_image_data_url_rejected(self):
|
||||
content = [{"type": "image_url", "image_url": {"url": "data:text/plain;base64,SGVsbG8="}}]
|
||||
with pytest.raises(ValueError) as exc:
|
||||
_normalize_multimodal_content(content)
|
||||
assert str(exc.value).startswith("unsupported_content_type:")
|
||||
|
||||
def test_file_part_rejected(self):
|
||||
with pytest.raises(ValueError) as exc:
|
||||
_normalize_multimodal_content([{"type": "file", "file": {"file_id": "f_1"}}])
|
||||
assert str(exc.value).startswith("unsupported_content_type:")
|
||||
|
||||
def test_input_file_part_rejected(self):
|
||||
with pytest.raises(ValueError) as exc:
|
||||
_normalize_multimodal_content([{"type": "input_file", "file_id": "f_1"}])
|
||||
assert str(exc.value).startswith("unsupported_content_type:")
|
||||
|
||||
def test_missing_url_rejected(self):
|
||||
with pytest.raises(ValueError) as exc:
|
||||
_normalize_multimodal_content([{"type": "image_url", "image_url": {}}])
|
||||
assert str(exc.value).startswith("invalid_image_url:")
|
||||
|
||||
def test_bad_scheme_rejected(self):
|
||||
with pytest.raises(ValueError) as exc:
|
||||
_normalize_multimodal_content([{"type": "image_url", "image_url": {"url": "ftp://example.com/x.png"}}])
|
||||
assert str(exc.value).startswith("invalid_image_url:")
|
||||
|
||||
def test_unknown_part_type_rejected(self):
|
||||
with pytest.raises(ValueError) as exc:
|
||||
_normalize_multimodal_content([{"type": "audio", "audio": {}}])
|
||||
assert str(exc.value).startswith("unsupported_content_type:")
|
||||
|
||||
|
||||
class TestContentHasVisiblePayload:
|
||||
def test_non_empty_string(self):
|
||||
assert _content_has_visible_payload("hello")
|
||||
|
||||
def test_whitespace_only_string(self):
|
||||
assert not _content_has_visible_payload(" ")
|
||||
|
||||
def test_list_with_image_only(self):
|
||||
assert _content_has_visible_payload([{"type": "image_url", "image_url": {"url": "x"}}])
|
||||
|
||||
def test_list_with_only_empty_text(self):
|
||||
assert not _content_has_visible_payload([{"type": "text", "text": ""}])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP integration — real aiohttp client hitting the adapter handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_adapter() -> APIServerAdapter:
|
||||
return APIServerAdapter(PlatformConfig(enabled=True))
|
||||
|
||||
|
||||
def _create_app(adapter: APIServerAdapter) -> web.Application:
|
||||
mws = [mw for mw in (cors_middleware, security_headers_middleware) if mw is not None]
|
||||
app = web.Application(middlewares=mws)
|
||||
app["api_server_adapter"] = adapter
|
||||
app.router.add_post("/v1/chat/completions", adapter._handle_chat_completions)
|
||||
app.router.add_post("/v1/responses", adapter._handle_responses)
|
||||
app.router.add_get("/v1/responses/{response_id}", adapter._handle_get_response)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adapter():
|
||||
return _make_adapter()
|
||||
|
||||
|
||||
class TestChatCompletionsMultimodalHTTP:
|
||||
@pytest.mark.asyncio
|
||||
async def test_inline_image_preserved_to_run_agent(self, adapter):
|
||||
"""Multimodal user content reaches _run_agent as a list of parts."""
|
||||
image_payload = [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/cat.png", "detail": "high"}},
|
||||
]
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
adapter,
|
||||
"_run_agent",
|
||||
new=MagicMock(),
|
||||
) as mock_run:
|
||||
async def _stub(**kwargs):
|
||||
mock_run.captured = kwargs
|
||||
return (
|
||||
{"final_response": "A cat.", "messages": [], "api_calls": 1},
|
||||
{"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
|
||||
)
|
||||
mock_run.side_effect = _stub
|
||||
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "hermes-agent",
|
||||
"messages": [{"role": "user", "content": image_payload}],
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status == 200, await resp.text()
|
||||
assert mock_run.captured["user_message"] == image_payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_only_array_collapses_to_string(self, adapter):
|
||||
"""Text-only array becomes a plain string so logging stays unchanged."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(adapter, "_run_agent", new=MagicMock()) as mock_run:
|
||||
async def _stub(**kwargs):
|
||||
mock_run.captured = kwargs
|
||||
return (
|
||||
{"final_response": "ok", "messages": [], "api_calls": 1},
|
||||
{"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
|
||||
)
|
||||
mock_run.side_effect = _stub
|
||||
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "hermes-agent",
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "hello"}]},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status == 200, await resp.text()
|
||||
assert mock_run.captured["user_message"] == "hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_part_returns_400(self, adapter):
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "hermes-agent",
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "file", "file": {"file_id": "f_1"}}]},
|
||||
],
|
||||
},
|
||||
)
|
||||
assert resp.status == 400
|
||||
body = await resp.json()
|
||||
assert body["error"]["code"] == "unsupported_content_type"
|
||||
assert body["error"]["param"] == "messages[0].content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_image_data_url_returns_400(self, adapter):
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "hermes-agent",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:text/plain;base64,SGVsbG8="},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
assert resp.status == 400
|
||||
body = await resp.json()
|
||||
assert body["error"]["code"] == "unsupported_content_type"
|
||||
|
||||
|
||||
class TestResponsesMultimodalHTTP:
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_image_canonicalized_and_forwarded(self, adapter):
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(adapter, "_run_agent", new=MagicMock()) as mock_run:
|
||||
async def _stub(**kwargs):
|
||||
mock_run.captured = kwargs
|
||||
return (
|
||||
{"final_response": "ok", "messages": [], "api_calls": 1},
|
||||
{"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
|
||||
)
|
||||
mock_run.side_effect = _stub
|
||||
|
||||
resp = await cli.post(
|
||||
"/v1/responses",
|
||||
json={
|
||||
"model": "hermes-agent",
|
||||
"input": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "input_text", "text": "Describe."},
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": "https://example.com/cat.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status == 200, await resp.text()
|
||||
expected = [
|
||||
{"type": "text", "text": "Describe."},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
|
||||
]
|
||||
assert mock_run.captured["user_message"] == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_file_returns_400(self, adapter):
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/v1/responses",
|
||||
json={
|
||||
"model": "hermes-agent",
|
||||
"input": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "input_file", "file_id": "f_1"}],
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
assert resp.status == 400
|
||||
body = await resp.json()
|
||||
assert body["error"]["code"] == "unsupported_content_type"
|
||||
148
tests/gateway/test_cancel_background_drain.py
Normal file
148
tests/gateway/test_cancel_background_drain.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
"""Regression test: cancel_background_tasks must drain late-arrival tasks.
|
||||
|
||||
During gateway shutdown, a message arriving while
|
||||
cancel_background_tasks is mid-await can spawn a fresh
|
||||
_process_message_background task via handle_message, which is added
|
||||
to self._background_tasks. Without the re-drain loop, the subsequent
|
||||
_background_tasks.clear() drops the reference; the task runs
|
||||
untracked against a disconnecting adapter.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class _StubAdapter(BasePlatformAdapter):
|
||||
async def connect(self):
|
||||
pass
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
async def send(self, chat_id, text, **kwargs):
|
||||
return None
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {}
|
||||
|
||||
|
||||
def _make_adapter():
|
||||
adapter = _StubAdapter(PlatformConfig(enabled=True, token="t"), Platform.TELEGRAM)
|
||||
adapter._send_with_retry = AsyncMock(return_value=None)
|
||||
return adapter
|
||||
|
||||
|
||||
def _event(text, cid="42"):
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=SessionSource(platform=Platform.TELEGRAM, chat_id=cid, chat_type="dm"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_background_tasks_drains_late_arrivals():
|
||||
"""A message that arrives during the gather window must be picked
|
||||
up by the re-drain loop, not leaked as an untracked task."""
|
||||
adapter = _make_adapter()
|
||||
sk = build_session_key(
|
||||
SessionSource(platform=Platform.TELEGRAM, chat_id="42", chat_type="dm")
|
||||
)
|
||||
|
||||
m1_started = asyncio.Event()
|
||||
m1_cleanup_running = asyncio.Event()
|
||||
m2_started = asyncio.Event()
|
||||
m2_cancelled = asyncio.Event()
|
||||
|
||||
async def handler(event):
|
||||
if event.text == "M1":
|
||||
m1_started.set()
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
m1_cleanup_running.set()
|
||||
# Widen the gather window with a shielded cleanup
|
||||
# delay so M2 can get injected during it.
|
||||
await asyncio.shield(asyncio.sleep(0.2))
|
||||
raise
|
||||
else: # M2 — the late arrival
|
||||
m2_started.set()
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
m2_cancelled.set()
|
||||
raise
|
||||
|
||||
adapter._message_handler = handler
|
||||
|
||||
# Spawn M1.
|
||||
await adapter.handle_message(_event("M1"))
|
||||
await asyncio.wait_for(m1_started.wait(), timeout=1.0)
|
||||
|
||||
# Kick off shutdown. This will cancel M1 and await its cleanup.
|
||||
cancel_task = asyncio.create_task(adapter.cancel_background_tasks())
|
||||
|
||||
# Wait until M1's cleanup is running (inside the shielded sleep).
|
||||
# This is the race window: cancel_task is awaiting gather, M1 is
|
||||
# shielded in cleanup, the _active_sessions entry has been cleared
|
||||
# by M1's own finally.
|
||||
await asyncio.wait_for(m1_cleanup_running.wait(), timeout=1.0)
|
||||
|
||||
# Clear the active-session entry (M1's finally hasn't fully run yet,
|
||||
# but in production the platform dispatcher would deliver a new
|
||||
# message that takes the no-active-session spawn path). For this
|
||||
# repro, make it deterministic.
|
||||
adapter._active_sessions.pop(sk, None)
|
||||
|
||||
# Inject late arrival — spawns a fresh _process_message_background
|
||||
# task and adds it to _background_tasks while cancel_task is still
|
||||
# in gather.
|
||||
await adapter.handle_message(_event("M2"))
|
||||
await asyncio.wait_for(m2_started.wait(), timeout=1.0)
|
||||
|
||||
# Let cancel_task finish. Round 1's gather completes when M1's
|
||||
# shielded cleanup finishes. Round 2 should pick up M2.
|
||||
await asyncio.wait_for(cancel_task, timeout=5.0)
|
||||
|
||||
# Assert M2 was drained, not leaked.
|
||||
assert m2_cancelled.is_set(), (
|
||||
"Late-arrival M2 was NOT cancelled by cancel_background_tasks — "
|
||||
"the re-drain loop is missing and the task leaked"
|
||||
)
|
||||
assert adapter._background_tasks == set()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_background_tasks_handles_no_tasks():
|
||||
"""Regression guard: no tasks, no hang, no error."""
|
||||
adapter = _make_adapter()
|
||||
await adapter.cancel_background_tasks()
|
||||
assert adapter._background_tasks == set()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_background_tasks_bounded_rounds():
|
||||
"""Regression guard: the drain loop is bounded — it does not spin
|
||||
forever even if late-arrival tasks keep getting spawned."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
# Single well-behaved task that cancels cleanly — baseline check
|
||||
# that the loop terminates in one round.
|
||||
async def quick():
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
task = asyncio.create_task(quick())
|
||||
adapter._background_tasks.add(task)
|
||||
|
||||
await adapter.cancel_background_tasks()
|
||||
assert task.done()
|
||||
assert adapter._background_tasks == set()
|
||||
91
tests/gateway/test_complete_path_at_filter.py
Normal file
91
tests/gateway/test_complete_path_at_filter.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
"""Regression tests for the TUI gateway's `complete.path` handler.
|
||||
|
||||
Reported during the TUI v2 blitz retest: typing `@folder:` (and `@folder`
|
||||
with no colon yet) still surfaced files alongside directories in the
|
||||
TUI composer, because the gateway-side completion lives in
|
||||
`tui_gateway/server.py` and was never touched by the earlier fix to
|
||||
`hermes_cli/commands.py`.
|
||||
|
||||
Covers:
|
||||
- `@folder:` only yields directories
|
||||
- `@file:` only yields regular files
|
||||
- Bare `@folder` / `@file` (no colon) lists cwd directly
|
||||
- Explicit prefix is preserved in the completion text
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from tui_gateway import server
|
||||
|
||||
|
||||
def _fixture(tmp_path: Path):
|
||||
(tmp_path / "readme.md").write_text("x")
|
||||
(tmp_path / ".env").write_text("x")
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "docs").mkdir()
|
||||
|
||||
|
||||
def _items(word: str):
|
||||
resp = server.handle_request({"id": "1", "method": "complete.path", "params": {"word": word}})
|
||||
|
||||
return [(it["text"], it["display"], it.get("meta", "")) for it in resp["result"]["items"]]
|
||||
|
||||
|
||||
def test_at_folder_colon_only_dirs(tmp_path, monkeypatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
_fixture(tmp_path)
|
||||
|
||||
texts = [t for t, _, _ in _items("@folder:")]
|
||||
|
||||
assert all(t.startswith("@folder:") for t in texts), texts
|
||||
assert any(t == "@folder:src/" for t in texts)
|
||||
assert any(t == "@folder:docs/" for t in texts)
|
||||
assert not any(t == "@folder:readme.md" for t in texts)
|
||||
assert not any(t == "@folder:.env" for t in texts)
|
||||
|
||||
|
||||
def test_at_file_colon_only_files(tmp_path, monkeypatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
_fixture(tmp_path)
|
||||
|
||||
texts = [t for t, _, _ in _items("@file:")]
|
||||
|
||||
assert all(t.startswith("@file:") for t in texts), texts
|
||||
assert any(t == "@file:readme.md" for t in texts)
|
||||
assert not any(t == "@file:src/" for t in texts)
|
||||
assert not any(t == "@file:docs/" for t in texts)
|
||||
|
||||
|
||||
def test_at_folder_bare_without_colon_lists_dirs(tmp_path, monkeypatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
_fixture(tmp_path)
|
||||
|
||||
texts = [t for t, _, _ in _items("@folder")]
|
||||
|
||||
assert any(t == "@folder:src/" for t in texts), texts
|
||||
assert any(t == "@folder:docs/" for t in texts), texts
|
||||
assert not any(t == "@folder:readme.md" for t in texts)
|
||||
|
||||
|
||||
def test_at_file_bare_without_colon_lists_files(tmp_path, monkeypatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
_fixture(tmp_path)
|
||||
|
||||
texts = [t for t, _, _ in _items("@file")]
|
||||
|
||||
assert any(t == "@file:readme.md" for t in texts), texts
|
||||
assert not any(t == "@file:src/" for t in texts)
|
||||
|
||||
|
||||
def test_bare_at_still_shows_static_refs(tmp_path, monkeypatch):
|
||||
"""`@` alone should list the static references so users discover the
|
||||
available prefixes. (Unchanged behaviour; regression guard.)
|
||||
"""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
texts = [t for t, _, _ in _items("@")]
|
||||
|
||||
for expected in ("@diff", "@staged", "@file:", "@folder:", "@url:", "@git:"):
|
||||
assert expected in texts, f"missing static ref {expected!r} in {texts!r}"
|
||||
|
|
@ -75,7 +75,6 @@ def _make_runner():
|
|||
runner._service_tier = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._smart_model_routing = {}
|
||||
runner._running_agents = {}
|
||||
runner._pending_model_notes = {}
|
||||
runner._session_db = None
|
||||
|
|
|
|||
79
tests/gateway/test_discord_race_polish.py
Normal file
79
tests/gateway/test_discord_race_polish.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
"""Discord adapter race polish: concurrent join_voice_channel must not
|
||||
double-invoke channel.connect() on the same guild."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
def _make_adapter():
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter._platform = Platform.DISCORD
|
||||
adapter.config = PlatformConfig(enabled=True, token="t")
|
||||
adapter._ready_event = asyncio.Event()
|
||||
adapter._allowed_user_ids = set()
|
||||
adapter._allowed_role_ids = set()
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_locks = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._client = MagicMock()
|
||||
return adapter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_joins_do_not_double_connect():
|
||||
"""Two concurrent join_voice_channel calls on the same guild must
|
||||
serialize through the per-guild lock — only ONE channel.connect()
|
||||
actually fires; the second sees the _voice_clients entry the first
|
||||
just installed."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
connect_count = [0]
|
||||
release = asyncio.Event()
|
||||
|
||||
class FakeVC:
|
||||
def __init__(self, channel):
|
||||
self.channel = channel
|
||||
|
||||
def is_connected(self):
|
||||
return True
|
||||
|
||||
async def move_to(self, _channel):
|
||||
return None
|
||||
|
||||
async def slow_connect(self):
|
||||
connect_count[0] += 1
|
||||
await release.wait()
|
||||
return FakeVC(self)
|
||||
|
||||
channel = MagicMock()
|
||||
channel.id = 111
|
||||
channel.guild.id = 42
|
||||
channel.connect = lambda: slow_connect(channel)
|
||||
|
||||
from gateway.platforms import discord as discord_mod
|
||||
with patch.object(discord_mod, "VoiceReceiver",
|
||||
MagicMock(return_value=MagicMock(start=lambda: None))):
|
||||
with patch.object(discord_mod.asyncio, "ensure_future",
|
||||
lambda _c: asyncio.create_task(asyncio.sleep(0))):
|
||||
t1 = asyncio.create_task(adapter.join_voice_channel(channel))
|
||||
t2 = asyncio.create_task(adapter.join_voice_channel(channel))
|
||||
await asyncio.sleep(0.05)
|
||||
release.set()
|
||||
r1, r2 = await asyncio.gather(t1, t2)
|
||||
|
||||
assert connect_count[0] == 1, (
|
||||
f"expected 1 channel.connect() call, got {connect_count[0]} — "
|
||||
"per-guild lock is not serializing join_voice_channel"
|
||||
)
|
||||
assert r1 is True and r2 is True
|
||||
assert 42 in adapter._voice_clients
|
||||
|
|
@ -283,6 +283,48 @@ def test_persist_dm_topic_thread_id_skips_if_already_set(tmp_path):
|
|||
# ── _get_dm_topic_info ──
|
||||
|
||||
|
||||
def test_persist_dm_topic_thread_id_preserves_config_on_write_failure(tmp_path):
|
||||
"""Failed writes should leave the original config.yaml intact."""
|
||||
import yaml
|
||||
|
||||
config_data = {
|
||||
"platforms": {
|
||||
"telegram": {
|
||||
"extra": {
|
||||
"dm_topics": [
|
||||
{
|
||||
"chat_id": 111,
|
||||
"topics": [
|
||||
{"name": "General", "icon_color": 123},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config_file = tmp_path / ".hermes" / "config.yaml"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
original_text = yaml.dump(config_data)
|
||||
config_file.write_text(original_text, encoding="utf-8")
|
||||
|
||||
adapter = _make_adapter()
|
||||
|
||||
def fail_dump(*args, **kwargs):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
with patch.object(Path, "home", return_value=tmp_path), \
|
||||
patch.dict(os.environ, {"HERMES_HOME": str(tmp_path / ".hermes")}), \
|
||||
patch("yaml.dump", side_effect=fail_dump):
|
||||
adapter._persist_dm_topic_thread_id(111, "General", 999)
|
||||
|
||||
assert config_file.read_text(encoding="utf-8") == original_text
|
||||
result = yaml.safe_load(config_file.read_text(encoding="utf-8"))
|
||||
topics = result["platforms"]["telegram"]["extra"]["dm_topics"][0]["topics"]
|
||||
assert "thread_id" not in topics[0]
|
||||
|
||||
|
||||
def test_get_dm_topic_info_finds_cached_topic():
|
||||
"""Should return topic config when thread_id is in cache."""
|
||||
adapter = _make_adapter([
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import sys
|
|||
import threading
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
|
@ -53,7 +53,6 @@ def _make_runner():
|
|||
runner._service_tier = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._smart_model_routing = {}
|
||||
runner._running_agents = {}
|
||||
runner._pending_model_notes = {}
|
||||
runner._session_db = None
|
||||
|
|
@ -97,13 +96,7 @@ def test_turn_route_injects_priority_processing_without_changing_runtime():
|
|||
"credential_pool": None,
|
||||
}
|
||||
|
||||
with patch("agent.smart_model_routing.resolve_turn_route", return_value={
|
||||
"model": "gpt-5.4",
|
||||
"runtime": dict(runtime_kwargs),
|
||||
"label": None,
|
||||
"signature": ("gpt-5.4", "openrouter", "https://openrouter.ai/api/v1", "chat_completions", None, ()),
|
||||
}):
|
||||
route = gateway_run.GatewayRunner._resolve_turn_agent_config(runner, "hi", "gpt-5.4", runtime_kwargs)
|
||||
route = gateway_run.GatewayRunner._resolve_turn_agent_config(runner, "hi", "gpt-5.4", runtime_kwargs)
|
||||
|
||||
assert route["runtime"]["provider"] == "openrouter"
|
||||
assert route["runtime"]["api_mode"] == "chat_completions"
|
||||
|
|
@ -123,13 +116,7 @@ def test_turn_route_skips_priority_processing_for_unsupported_models():
|
|||
"credential_pool": None,
|
||||
}
|
||||
|
||||
with patch("agent.smart_model_routing.resolve_turn_route", return_value={
|
||||
"model": "gpt-5.3-codex",
|
||||
"runtime": dict(runtime_kwargs),
|
||||
"label": None,
|
||||
"signature": ("gpt-5.3-codex", "openrouter", "https://openrouter.ai/api/v1", "chat_completions", None, ()),
|
||||
}):
|
||||
route = gateway_run.GatewayRunner._resolve_turn_agent_config(runner, "hi", "gpt-5.3-codex", runtime_kwargs)
|
||||
route = gateway_run.GatewayRunner._resolve_turn_agent_config(runner, "hi", "gpt-5.3-codex", runtime_kwargs)
|
||||
|
||||
assert route["request_overrides"] is None
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ from pathlib import Path
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from gateway.platforms.base import ProcessingOutcome
|
||||
|
||||
try:
|
||||
import lark_oapi
|
||||
_HAS_LARK_OAPI = True
|
||||
|
|
@ -638,83 +640,54 @@ class TestAdapterBehavior(unittest.TestCase):
|
|||
)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
@unittest.skipUnless(_HAS_LARK_OAPI, "lark-oapi not installed")
|
||||
def test_add_ack_reaction_uses_ok_emoji(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
captured = {}
|
||||
|
||||
class _ReactionAPI:
|
||||
def create(self, request):
|
||||
captured["request"] = request
|
||||
return SimpleNamespace(
|
||||
success=lambda: True,
|
||||
data=SimpleNamespace(reaction_id="r_typing"),
|
||||
)
|
||||
|
||||
adapter._client = SimpleNamespace(
|
||||
im=SimpleNamespace(v1=SimpleNamespace(message_reaction=_ReactionAPI()))
|
||||
)
|
||||
|
||||
async def _direct(func, *args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct):
|
||||
reaction_id = asyncio.run(adapter._add_ack_reaction("om_msg"))
|
||||
|
||||
self.assertEqual(reaction_id, "r_typing")
|
||||
self.assertEqual(captured["request"].request_body.reaction_type["emoji_type"], "OK")
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_add_ack_reaction_logs_warning_on_failure(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
|
||||
class _ReactionAPI:
|
||||
def create(self, request):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
adapter._client = SimpleNamespace(
|
||||
im=SimpleNamespace(v1=SimpleNamespace(message_reaction=_ReactionAPI()))
|
||||
)
|
||||
|
||||
async def _direct(func, *args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
with (
|
||||
patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct),
|
||||
self.assertLogs("gateway.platforms.feishu", level="WARNING") as logs,
|
||||
):
|
||||
reaction_id = asyncio.run(adapter._add_ack_reaction("om_msg"))
|
||||
|
||||
self.assertIsNone(reaction_id)
|
||||
self.assertTrue(
|
||||
any("Failed to add ack reaction to om_msg" in entry for entry in logs.output),
|
||||
logs.output,
|
||||
)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_ack_reaction_events_are_ignored_to_avoid_feedback_loops(self):
|
||||
def test_bot_origin_reactions_are_dropped_to_avoid_feedback_loops(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
adapter._loop = object()
|
||||
|
||||
for emoji in ("Typing", "CrossMark"):
|
||||
event = SimpleNamespace(
|
||||
message_id="om_msg",
|
||||
operator_type="bot",
|
||||
reaction_type=SimpleNamespace(emoji_type=emoji),
|
||||
)
|
||||
data = SimpleNamespace(event=event)
|
||||
with patch(
|
||||
"gateway.platforms.feishu.asyncio.run_coroutine_threadsafe"
|
||||
) as run_threadsafe:
|
||||
adapter._on_reaction_event("im.message.reaction.created_v1", data)
|
||||
run_threadsafe.assert_not_called()
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_user_reaction_with_managed_emoji_is_still_routed(self):
|
||||
# Operator-origin filter is enough to prevent feedback loops; we must
|
||||
# not additionally swallow user-origin reactions just because their
|
||||
# emoji happens to collide with a lifecycle emoji.
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
adapter._loop = SimpleNamespace(is_closed=lambda: False)
|
||||
|
||||
event = SimpleNamespace(
|
||||
message_id="om_msg",
|
||||
operator_type="user",
|
||||
reaction_type=SimpleNamespace(emoji_type="OK"),
|
||||
reaction_type=SimpleNamespace(emoji_type="Typing"),
|
||||
)
|
||||
data = SimpleNamespace(event=event)
|
||||
|
||||
with patch("gateway.platforms.feishu.asyncio.run_coroutine_threadsafe") as run_threadsafe:
|
||||
adapter._on_reaction_event("im.message.reaction.created_v1", data)
|
||||
def _close_coro_and_return_future(coro, _loop):
|
||||
coro.close()
|
||||
return SimpleNamespace(add_done_callback=lambda _: None)
|
||||
|
||||
run_threadsafe.assert_not_called()
|
||||
with patch(
|
||||
"gateway.platforms.feishu.asyncio.run_coroutine_threadsafe",
|
||||
side_effect=_close_coro_and_return_future,
|
||||
) as run_threadsafe:
|
||||
adapter._on_reaction_event("im.message.reaction.created_v1", data)
|
||||
run_threadsafe.assert_called_once()
|
||||
|
||||
@patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "open"}, clear=True)
|
||||
def test_group_message_requires_mentions_even_when_policy_open(self):
|
||||
|
|
@ -743,6 +716,57 @@ class TestAdapterBehavior(unittest.TestCase):
|
|||
|
||||
self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[other_mention]), sender_id, ""))
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"FEISHU_BOT_OPEN_ID": "ou_hermes",
|
||||
"FEISHU_BOT_USER_ID": "u_hermes",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
def test_other_bot_sender_is_not_treated_as_self_sent_message(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
event = SimpleNamespace(
|
||||
sender=SimpleNamespace(
|
||||
sender_type="bot",
|
||||
sender_id=SimpleNamespace(open_id="ou_other_bot", user_id="u_other_bot"),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertFalse(adapter._is_self_sent_bot_message(event))
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"FEISHU_BOT_OPEN_ID": "ou_hermes",
|
||||
"FEISHU_BOT_USER_ID": "u_hermes",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
def test_self_bot_sender_is_treated_as_self_sent_message(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
by_open_id = SimpleNamespace(
|
||||
sender=SimpleNamespace(
|
||||
sender_type="bot",
|
||||
sender_id=SimpleNamespace(open_id="ou_hermes", user_id="u_other"),
|
||||
)
|
||||
)
|
||||
by_user_id = SimpleNamespace(
|
||||
sender=SimpleNamespace(
|
||||
sender_type="app",
|
||||
sender_id=SimpleNamespace(open_id="ou_other", user_id="u_hermes"),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(adapter._is_self_sent_bot_message(by_open_id))
|
||||
self.assertTrue(adapter._is_self_sent_bot_message(by_user_id))
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
|
|
@ -2370,6 +2394,134 @@ class TestAdapterBehavior(unittest.TestCase):
|
|||
elements = payload["zh_cn"]["content"][0]
|
||||
self.assertEqual(elements, [{"tag": "md", "text": "可以用 **粗体** 和 *斜体*。"}])
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_send_splits_fenced_code_blocks_into_separate_post_rows(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
captured = {}
|
||||
|
||||
class _MessageAPI:
|
||||
def create(self, request):
|
||||
captured["request"] = request
|
||||
return SimpleNamespace(
|
||||
success=lambda: True,
|
||||
data=SimpleNamespace(message_id="om_codeblock"),
|
||||
)
|
||||
|
||||
adapter._client = SimpleNamespace(
|
||||
im=SimpleNamespace(
|
||||
v1=SimpleNamespace(
|
||||
message=_MessageAPI(),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def _direct(func, *args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
content = (
|
||||
"确认已入库 ✓\n"
|
||||
"文件路径:`/root/.hermes/profiles/agent_cto/cron/jobs.json`\n"
|
||||
"**解码后的内容:**\n"
|
||||
"```json\n"
|
||||
'{"cron": "list"}\n'
|
||||
"```\n"
|
||||
"后续说明仍应保留。"
|
||||
)
|
||||
|
||||
with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct):
|
||||
result = asyncio.run(
|
||||
adapter.send(
|
||||
chat_id="oc_chat",
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(result.success)
|
||||
self.assertEqual(captured["request"].request_body.msg_type, "post")
|
||||
payload = json.loads(captured["request"].request_body.content)
|
||||
rows = payload["zh_cn"]["content"]
|
||||
self.assertEqual(
|
||||
rows,
|
||||
[
|
||||
[
|
||||
{
|
||||
"tag": "md",
|
||||
"text": "确认已入库 ✓\n文件路径:`/root/.hermes/profiles/agent_cto/cron/jobs.json`\n**解码后的内容:**",
|
||||
}
|
||||
],
|
||||
[{"tag": "md", "text": "```json\n{\"cron\": \"list\"}\n```"}],
|
||||
[{"tag": "md", "text": "后续说明仍应保留。"}],
|
||||
],
|
||||
)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_build_post_payload_keeps_fence_like_code_lines_inside_code_block(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
payload = json.loads(
|
||||
adapter._build_post_payload(
|
||||
"before\n```python\n```oops\n```\nafter"
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
payload["zh_cn"]["content"],
|
||||
[
|
||||
[{"tag": "md", "text": "before"}],
|
||||
[{"tag": "md", "text": "```python\n```oops\n```"}],
|
||||
[{"tag": "md", "text": "after"}],
|
||||
],
|
||||
)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_build_post_payload_preserves_trailing_spaces_in_code_block(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
payload = json.loads(
|
||||
adapter._build_post_payload(
|
||||
"before\n```python\nline with two spaces \n```\nafter"
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
payload["zh_cn"]["content"],
|
||||
[
|
||||
[{"tag": "md", "text": "before"}],
|
||||
[{"tag": "md", "text": "```python\nline with two spaces \n```"}],
|
||||
[{"tag": "md", "text": "after"}],
|
||||
],
|
||||
)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_build_post_payload_splits_multiple_fenced_code_blocks(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
payload = json.loads(
|
||||
adapter._build_post_payload(
|
||||
"before\n```python\nprint(1)\n```\nmiddle\n```json\n{}\n```\nafter"
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
payload["zh_cn"]["content"],
|
||||
[
|
||||
[{"tag": "md", "text": "before"}],
|
||||
[{"tag": "md", "text": "```python\nprint(1)\n```"}],
|
||||
[{"tag": "md", "text": "middle"}],
|
||||
[{"tag": "md", "text": "```json\n{}\n```"}],
|
||||
[{"tag": "md", "text": "after"}],
|
||||
],
|
||||
)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_send_falls_back_to_text_when_post_payload_is_rejected(self):
|
||||
from gateway.config import PlatformConfig
|
||||
|
|
@ -2505,6 +2657,135 @@ class TestAdapterBehavior(unittest.TestCase):
|
|||
)
|
||||
|
||||
|
||||
@unittest.skipUnless(_HAS_LARK_OAPI, "lark-oapi not installed")
|
||||
class TestHydrateBotIdentity(unittest.TestCase):
|
||||
"""Hydration of bot identity via /open-apis/bot/v3/info and application info.
|
||||
|
||||
Covers the manual-setup path where FEISHU_BOT_OPEN_ID / FEISHU_BOT_USER_ID
|
||||
are not configured. Hydration must populate _bot_open_id so that
|
||||
_is_self_sent_bot_message() can filter the adapter's own outbound echoes.
|
||||
"""
|
||||
|
||||
def _make_adapter(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
return FeishuAdapter(PlatformConfig())
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_hydration_populates_open_id_from_bot_info(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter._client = Mock()
|
||||
payload = json.dumps(
|
||||
{
|
||||
"code": 0,
|
||||
"bot": {
|
||||
"bot_name": "Hermes Bot",
|
||||
"open_id": "ou_hermes_hydrated",
|
||||
},
|
||||
}
|
||||
).encode("utf-8")
|
||||
response = SimpleNamespace(content=payload)
|
||||
adapter._client.request = Mock(return_value=response)
|
||||
|
||||
asyncio.run(adapter._hydrate_bot_identity())
|
||||
|
||||
self.assertEqual(adapter._bot_open_id, "ou_hermes_hydrated")
|
||||
self.assertEqual(adapter._bot_name, "Hermes Bot")
|
||||
# Application-info fallback must NOT run when bot_name is already set.
|
||||
self.assertFalse(
|
||||
adapter._client.application.v6.application.get.called
|
||||
if hasattr(adapter._client, "application") else False
|
||||
)
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"FEISHU_BOT_OPEN_ID": "ou_env",
|
||||
"FEISHU_BOT_NAME": "Env Hermes",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
def test_hydration_skipped_when_env_vars_supply_both_fields(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter._client = Mock()
|
||||
adapter._client.request = Mock()
|
||||
|
||||
asyncio.run(adapter._hydrate_bot_identity())
|
||||
|
||||
# Neither probe should run — both fields are already populated.
|
||||
adapter._client.request.assert_not_called()
|
||||
self.assertEqual(adapter._bot_open_id, "ou_env")
|
||||
self.assertEqual(adapter._bot_name, "Env Hermes")
|
||||
|
||||
@patch.dict(os.environ, {"FEISHU_BOT_OPEN_ID": "ou_env"}, clear=True)
|
||||
def test_hydration_fills_only_missing_fields(self):
|
||||
"""Env-var open_id must NOT be overwritten by a different probe value."""
|
||||
adapter = self._make_adapter()
|
||||
adapter._client = Mock()
|
||||
payload = json.dumps(
|
||||
{
|
||||
"code": 0,
|
||||
"bot": {
|
||||
"bot_name": "Hermes Bot",
|
||||
"open_id": "ou_probe_DIFFERENT",
|
||||
},
|
||||
}
|
||||
).encode("utf-8")
|
||||
adapter._client.request = Mock(return_value=SimpleNamespace(content=payload))
|
||||
|
||||
asyncio.run(adapter._hydrate_bot_identity())
|
||||
|
||||
self.assertEqual(adapter._bot_open_id, "ou_env") # preserved
|
||||
self.assertEqual(adapter._bot_name, "Hermes Bot") # filled in
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_hydration_tolerates_probe_failure_and_falls_back_to_app_info(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter._client = Mock()
|
||||
adapter._client.request = Mock(side_effect=RuntimeError("network down"))
|
||||
|
||||
# Make the application-info fallback succeed for _bot_name.
|
||||
app_response = Mock()
|
||||
app_response.success = Mock(return_value=True)
|
||||
app_response.data = SimpleNamespace(app=SimpleNamespace(app_name="Fallback Bot"))
|
||||
adapter._client.application.v6.application.get = Mock(return_value=app_response)
|
||||
adapter._build_get_application_request = Mock(return_value=object())
|
||||
|
||||
asyncio.run(adapter._hydrate_bot_identity())
|
||||
|
||||
# Primary probe failed — open_id stays empty, but bot_name came from app-info.
|
||||
self.assertEqual(adapter._bot_open_id, "")
|
||||
self.assertEqual(adapter._bot_name, "Fallback Bot")
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_hydrated_open_id_enables_self_send_filter(self):
|
||||
"""E2E: after hydration, _is_self_sent_bot_message() rejects adapter's own id."""
|
||||
adapter = self._make_adapter()
|
||||
adapter._client = Mock()
|
||||
payload = json.dumps(
|
||||
{"code": 0, "bot": {"bot_name": "Hermes", "open_id": "ou_hermes"}}
|
||||
).encode("utf-8")
|
||||
adapter._client.request = Mock(return_value=SimpleNamespace(content=payload))
|
||||
|
||||
asyncio.run(adapter._hydrate_bot_identity())
|
||||
|
||||
self_event = SimpleNamespace(
|
||||
sender=SimpleNamespace(
|
||||
sender_type="bot",
|
||||
sender_id=SimpleNamespace(open_id="ou_hermes", user_id=""),
|
||||
)
|
||||
)
|
||||
peer_event = SimpleNamespace(
|
||||
sender=SimpleNamespace(
|
||||
sender_type="bot",
|
||||
sender_id=SimpleNamespace(open_id="ou_peer_bot", user_id=""),
|
||||
)
|
||||
)
|
||||
self.assertTrue(adapter._is_self_sent_bot_message(self_event))
|
||||
self.assertFalse(adapter._is_self_sent_bot_message(peer_event))
|
||||
|
||||
|
||||
@unittest.skipUnless(_HAS_LARK_OAPI, "lark-oapi not installed")
|
||||
class TestPendingInboundQueue(unittest.TestCase):
|
||||
"""Tests for the loop-not-ready race (#5499): inbound events arriving
|
||||
|
|
@ -2970,3 +3251,231 @@ class TestSenderNameResolution(unittest.TestCase):
|
|||
result = asyncio.run(adapter._resolve_sender_name_from_api("ou_broken"))
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
|
||||
@unittest.skipUnless(_HAS_LARK_OAPI, "lark-oapi not installed")
|
||||
class TestProcessingReactions(unittest.TestCase):
|
||||
"""Typing on start → removed on SUCCESS, swapped for CrossMark on FAILURE,
|
||||
removed (no replacement) on CANCELLED."""
|
||||
|
||||
@staticmethod
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
def _build_adapter(
|
||||
self,
|
||||
create_success: bool = True,
|
||||
delete_success: bool = True,
|
||||
next_reaction_id: str = "r1",
|
||||
):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
tracker = SimpleNamespace(
|
||||
create_calls=[],
|
||||
delete_calls=[],
|
||||
next_reaction_id=next_reaction_id,
|
||||
create_success=create_success,
|
||||
delete_success=delete_success,
|
||||
)
|
||||
|
||||
def _create(request):
|
||||
tracker.create_calls.append(
|
||||
request.request_body.reaction_type["emoji_type"]
|
||||
)
|
||||
if tracker.create_success:
|
||||
return SimpleNamespace(
|
||||
success=lambda: True,
|
||||
data=SimpleNamespace(reaction_id=tracker.next_reaction_id),
|
||||
)
|
||||
return SimpleNamespace(
|
||||
success=lambda: False, code=99, msg="rejected", data=None,
|
||||
)
|
||||
|
||||
def _delete(request):
|
||||
tracker.delete_calls.append(request.reaction_id)
|
||||
return SimpleNamespace(
|
||||
success=lambda: tracker.delete_success,
|
||||
code=0 if tracker.delete_success else 99,
|
||||
msg="success" if tracker.delete_success else "rejected",
|
||||
)
|
||||
|
||||
adapter._client = SimpleNamespace(
|
||||
im=SimpleNamespace(
|
||||
v1=SimpleNamespace(
|
||||
message_reaction=SimpleNamespace(create=_create, delete=_delete),
|
||||
),
|
||||
),
|
||||
)
|
||||
return adapter, tracker
|
||||
|
||||
@staticmethod
|
||||
def _event(message_id: str = "om_msg"):
|
||||
return SimpleNamespace(message_id=message_id)
|
||||
|
||||
def _patch_to_thread(self):
|
||||
async def _direct(func, *args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct)
|
||||
|
||||
# ------------------------------------------------------------------ start
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_start_adds_typing_and_caches_reaction_id(self):
|
||||
adapter, tracker = self._build_adapter(next_reaction_id="r_typing")
|
||||
with self._patch_to_thread():
|
||||
self._run(adapter.on_processing_start(self._event()))
|
||||
self.assertEqual(tracker.create_calls, ["Typing"])
|
||||
self.assertEqual(adapter._pending_processing_reactions["om_msg"], "r_typing")
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_start_is_idempotent_for_same_message_id(self):
|
||||
adapter, tracker = self._build_adapter(next_reaction_id="r_typing")
|
||||
with self._patch_to_thread():
|
||||
self._run(adapter.on_processing_start(self._event()))
|
||||
self._run(adapter.on_processing_start(self._event()))
|
||||
self.assertEqual(tracker.create_calls, ["Typing"])
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_start_does_not_cache_when_create_fails(self):
|
||||
adapter, tracker = self._build_adapter(create_success=False)
|
||||
with self._patch_to_thread():
|
||||
self._run(adapter.on_processing_start(self._event()))
|
||||
self.assertEqual(tracker.create_calls, ["Typing"])
|
||||
self.assertNotIn("om_msg", adapter._pending_processing_reactions)
|
||||
|
||||
# --------------------------------------------------------------- complete
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_success_removes_typing_and_adds_nothing(self):
|
||||
adapter, tracker = self._build_adapter(next_reaction_id="r_typing")
|
||||
with self._patch_to_thread():
|
||||
self._run(adapter.on_processing_start(self._event()))
|
||||
self._run(
|
||||
adapter.on_processing_complete(self._event(), ProcessingOutcome.SUCCESS)
|
||||
)
|
||||
self.assertEqual(tracker.create_calls, ["Typing"])
|
||||
self.assertEqual(tracker.delete_calls, ["r_typing"])
|
||||
self.assertNotIn("om_msg", adapter._pending_processing_reactions)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_failure_removes_typing_then_adds_cross_mark(self):
|
||||
adapter, tracker = self._build_adapter(next_reaction_id="r_typing")
|
||||
with self._patch_to_thread():
|
||||
self._run(adapter.on_processing_start(self._event()))
|
||||
self._run(
|
||||
adapter.on_processing_complete(self._event(), ProcessingOutcome.FAILURE)
|
||||
)
|
||||
self.assertEqual(tracker.create_calls, ["Typing", "CrossMark"])
|
||||
self.assertEqual(tracker.delete_calls, ["r_typing"])
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_cancelled_removes_typing_and_adds_nothing(self):
|
||||
adapter, tracker = self._build_adapter(next_reaction_id="r_typing")
|
||||
with self._patch_to_thread():
|
||||
self._run(adapter.on_processing_start(self._event()))
|
||||
self._run(
|
||||
adapter.on_processing_complete(self._event(), ProcessingOutcome.CANCELLED)
|
||||
)
|
||||
self.assertEqual(tracker.create_calls, ["Typing"])
|
||||
self.assertEqual(tracker.delete_calls, ["r_typing"])
|
||||
self.assertNotIn("om_msg", adapter._pending_processing_reactions)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_failure_without_preceding_start_still_adds_cross_mark(self):
|
||||
adapter, tracker = self._build_adapter()
|
||||
with self._patch_to_thread():
|
||||
self._run(
|
||||
adapter.on_processing_complete(self._event(), ProcessingOutcome.FAILURE)
|
||||
)
|
||||
self.assertEqual(tracker.create_calls, ["CrossMark"])
|
||||
self.assertEqual(tracker.delete_calls, [])
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_success_without_preceding_start_is_full_noop(self):
|
||||
adapter, tracker = self._build_adapter()
|
||||
with self._patch_to_thread():
|
||||
self._run(
|
||||
adapter.on_processing_complete(self._event(), ProcessingOutcome.SUCCESS)
|
||||
)
|
||||
self.assertEqual(tracker.create_calls, [])
|
||||
self.assertEqual(tracker.delete_calls, [])
|
||||
|
||||
# ------------------------- delete failure: don't stack badges -----------
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_delete_failure_on_failure_outcome_skips_cross_mark(self):
|
||||
# Removing Typing is best-effort — but if it fails, we must NOT
|
||||
# additionally add CrossMark, or the UI would show two contradictory
|
||||
# badges. The handle stays in the cache for LRU to clean up later.
|
||||
adapter, tracker = self._build_adapter(
|
||||
next_reaction_id="r_typing", delete_success=False,
|
||||
)
|
||||
with self._patch_to_thread():
|
||||
self._run(adapter.on_processing_start(self._event()))
|
||||
self._run(
|
||||
adapter.on_processing_complete(self._event(), ProcessingOutcome.FAILURE)
|
||||
)
|
||||
self.assertEqual(tracker.create_calls, ["Typing"]) # CrossMark NOT added
|
||||
self.assertEqual(tracker.delete_calls, ["r_typing"]) # delete was attempted
|
||||
self.assertEqual(
|
||||
adapter._pending_processing_reactions["om_msg"], "r_typing",
|
||||
) # handle retained
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_delete_failure_on_success_outcome_retains_handle(self):
|
||||
adapter, tracker = self._build_adapter(
|
||||
next_reaction_id="r_typing", delete_success=False,
|
||||
)
|
||||
with self._patch_to_thread():
|
||||
self._run(adapter.on_processing_start(self._event()))
|
||||
self._run(
|
||||
adapter.on_processing_complete(self._event(), ProcessingOutcome.SUCCESS)
|
||||
)
|
||||
self.assertEqual(tracker.create_calls, ["Typing"])
|
||||
self.assertEqual(tracker.delete_calls, ["r_typing"])
|
||||
self.assertEqual(
|
||||
adapter._pending_processing_reactions["om_msg"], "r_typing",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------- env toggle
|
||||
@patch.dict(os.environ, {"FEISHU_REACTIONS": "false"}, clear=True)
|
||||
def test_env_disable_short_circuits_both_hooks(self):
|
||||
adapter, tracker = self._build_adapter()
|
||||
with self._patch_to_thread():
|
||||
self._run(adapter.on_processing_start(self._event()))
|
||||
self._run(
|
||||
adapter.on_processing_complete(self._event(), ProcessingOutcome.FAILURE)
|
||||
)
|
||||
self.assertEqual(tracker.create_calls, [])
|
||||
self.assertEqual(tracker.delete_calls, [])
|
||||
|
||||
# ------------------------------------------------------------- LRU bounds
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_cache_evicts_oldest_entry_beyond_size_limit(self):
|
||||
from gateway.platforms.feishu import _FEISHU_PROCESSING_REACTION_CACHE_SIZE
|
||||
|
||||
adapter, _ = self._build_adapter()
|
||||
counter = {"n": 0}
|
||||
|
||||
def _create(_request):
|
||||
counter["n"] += 1
|
||||
return SimpleNamespace(
|
||||
success=lambda: True,
|
||||
data=SimpleNamespace(reaction_id=f"r{counter['n']}"),
|
||||
)
|
||||
|
||||
adapter._client.im.v1.message_reaction.create = _create
|
||||
|
||||
with self._patch_to_thread():
|
||||
for i in range(_FEISHU_PROCESSING_REACTION_CACHE_SIZE + 1):
|
||||
self._run(adapter.on_processing_start(self._event(f"om_{i}")))
|
||||
|
||||
self.assertNotIn("om_0", adapter._pending_processing_reactions)
|
||||
self.assertIn(
|
||||
f"om_{_FEISHU_PROCESSING_REACTION_CACHE_SIZE}",
|
||||
adapter._pending_processing_reactions,
|
||||
)
|
||||
self.assertEqual(
|
||||
len(adapter._pending_processing_reactions),
|
||||
_FEISHU_PROCESSING_REACTION_CACHE_SIZE,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -355,8 +355,17 @@ async def test_none_user_id_does_not_generate_pairing_code(monkeypatch, tmp_path
|
|||
async def test_non_internal_event_without_user_triggers_pairing(monkeypatch, tmp_path):
|
||||
"""Verify the normal (non-internal) path still triggers pairing for unknown users."""
|
||||
import gateway.run as gateway_run
|
||||
import gateway.pairing as pairing_mod
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
# gateway.pairing.PAIRING_DIR is a module-level constant captured at
|
||||
# import time from whichever HERMES_HOME was set then. Per-test
|
||||
# HERMES_HOME redirection in conftest doesn't retroactively move it.
|
||||
# Override directly so pairing rate-limit state lives in this test's
|
||||
# tmp_path (and so stale state from prior xdist workers can't leak in).
|
||||
pairing_dir = tmp_path / "pairing"
|
||||
pairing_dir.mkdir()
|
||||
monkeypatch.setattr(pairing_mod, "PAIRING_DIR", pairing_dir)
|
||||
(tmp_path / "config.yaml").write_text("", encoding="utf-8")
|
||||
|
||||
# Clear env vars that could let all users through (loaded by
|
||||
|
|
|
|||
|
|
@ -1,13 +1,18 @@
|
|||
"""Tests for the pending_event None guard in recursive _run_agent calls.
|
||||
"""Tests for pending follow-up extraction in recursive _run_agent calls.
|
||||
|
||||
When pending_event is None (Path B: pending comes from interrupt_message),
|
||||
accessing pending_event.channel_prompt previously raised AttributeError.
|
||||
This verifies the fix: channel_prompt is captured inside the
|
||||
`if pending_event is not None:` block and falls back to None otherwise.
|
||||
|
||||
Also verifies that internal control interrupt reasons like "Stop requested"
|
||||
do not get recycled into the pending-user-message follow-up path.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from gateway.run import _is_control_interrupt_message
|
||||
|
||||
|
||||
def _extract_channel_prompt(pending_event):
|
||||
"""Reproduce the fixed logic from gateway/run.py.
|
||||
|
|
@ -21,6 +26,15 @@ def _extract_channel_prompt(pending_event):
|
|||
return next_channel_prompt
|
||||
|
||||
|
||||
def _extract_pending_text(interrupted, pending_event, interrupt_message):
|
||||
"""Reproduce the fixed pending-text selection from gateway/run.py."""
|
||||
if interrupted and pending_event is None and interrupt_message:
|
||||
if _is_control_interrupt_message(interrupt_message):
|
||||
return None
|
||||
return interrupt_message
|
||||
return None
|
||||
|
||||
|
||||
class TestPendingEventNoneChannelPrompt:
|
||||
"""Guard against AttributeError when pending_event is None."""
|
||||
|
||||
|
|
@ -40,3 +54,19 @@ class TestPendingEventNoneChannelPrompt:
|
|||
event = SimpleNamespace()
|
||||
result = _extract_channel_prompt(event)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestControlInterruptMessages:
|
||||
"""Control interrupt reasons must not become follow-up user input."""
|
||||
|
||||
def test_stop_requested_is_not_treated_as_pending_user_message(self):
|
||||
result = _extract_pending_text(True, None, "Stop requested")
|
||||
assert result is None
|
||||
|
||||
def test_session_reset_requested_is_not_treated_as_pending_user_message(self):
|
||||
result = _extract_pending_text(True, None, "Session reset requested")
|
||||
assert result is None
|
||||
|
||||
def test_real_user_interrupt_message_still_requeues(self):
|
||||
result = _extract_pending_text(True, None, "actually use postgres instead")
|
||||
assert result == "actually use postgres instead"
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
import pytest
|
||||
|
||||
from gateway.config import Platform, StreamingConfig
|
||||
from gateway.platforms.base import resolve_proxy_url
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
|
@ -19,6 +20,7 @@ def _make_runner(proxy_url=None):
|
|||
runner.config = MagicMock()
|
||||
runner.config.streaming = StreamingConfig()
|
||||
runner._running_agents = {}
|
||||
runner._session_run_generation = {}
|
||||
runner._session_model_overrides = {}
|
||||
runner._agent_cache = {}
|
||||
runner._agent_cache_lock = None
|
||||
|
|
@ -132,6 +134,15 @@ class TestGetProxyUrl:
|
|||
assert runner._get_proxy_url() is None
|
||||
|
||||
|
||||
class TestResolveProxyUrl:
|
||||
def test_normalizes_socks_alias_from_all_proxy(self, monkeypatch):
|
||||
for key in ("HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY",
|
||||
"https_proxy", "http_proxy", "all_proxy"):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
monkeypatch.setenv("ALL_PROXY", "socks://127.0.0.1:1080/")
|
||||
assert resolve_proxy_url() == "socks5://127.0.0.1:1080/"
|
||||
|
||||
|
||||
class TestRunAgentProxyDispatch:
|
||||
"""Test that _run_agent() delegates to proxy when configured."""
|
||||
|
||||
|
|
@ -160,10 +171,12 @@ class TestRunAgentProxyDispatch:
|
|||
source=source,
|
||||
session_id="test-session-123",
|
||||
session_key="test-key",
|
||||
run_generation=7,
|
||||
)
|
||||
|
||||
assert result["final_response"] == "Hello from remote!"
|
||||
runner._run_agent_via_proxy.assert_called_once()
|
||||
assert runner._run_agent_via_proxy.call_args.kwargs["run_generation"] == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_skips_proxy_when_not_configured(self, monkeypatch):
|
||||
|
|
@ -370,6 +383,40 @@ class TestRunAgentViaProxy:
|
|||
assert "session_id" in result
|
||||
assert result["session_id"] == "sess-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_stale_generation_returns_empty_result(self, monkeypatch):
|
||||
monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642")
|
||||
monkeypatch.delenv("GATEWAY_PROXY_KEY", raising=False)
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
runner._session_run_generation["test-key"] = 2
|
||||
|
||||
resp = _FakeSSEResponse(
|
||||
status=200,
|
||||
sse_chunks=[
|
||||
'data: {"choices":[{"delta":{"content":"stale"}}]}\n\n',
|
||||
"data: [DONE]\n\n",
|
||||
],
|
||||
)
|
||||
session = _FakeSession(resp)
|
||||
|
||||
with patch("gateway.run._load_gateway_config", return_value={}):
|
||||
with _patch_aiohttp(session):
|
||||
with patch("aiohttp.ClientTimeout"):
|
||||
result = await runner._run_agent_via_proxy(
|
||||
message="hi",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="sess-123",
|
||||
session_key="test-key",
|
||||
run_generation=1,
|
||||
)
|
||||
|
||||
assert result["final_response"] == ""
|
||||
assert result["messages"] == []
|
||||
assert result["api_calls"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_auth_header_without_key(self, monkeypatch):
|
||||
monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642")
|
||||
|
|
|
|||
159
tests/gateway/test_reply_to_injection.py
Normal file
159
tests/gateway/test_reply_to_injection.py
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
"""Tests for reply-to pointer injection in _prepare_inbound_message_text.
|
||||
|
||||
The `[Replying to: "..."]` prefix is a *disambiguation pointer*, not
|
||||
deduplication. It must always be injected when the user explicitly replies
|
||||
to a prior message — even when the quoted text already exists somewhere
|
||||
in the conversation history. History can contain the same or similar text
|
||||
multiple times, and without an explicit pointer the agent has to guess
|
||||
which prior message the user is referencing.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _make_runner() -> GatewayRunner:
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake")},
|
||||
)
|
||||
runner.adapters = {}
|
||||
runner._model = "openai/gpt-4.1-mini"
|
||||
runner._base_url = None
|
||||
return runner
|
||||
|
||||
|
||||
def _source() -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_name="DM",
|
||||
chat_type="private",
|
||||
user_name="Alice",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_prefix_injected_when_text_absent_from_history():
|
||||
runner = _make_runner()
|
||||
source = _source()
|
||||
event = MessageEvent(
|
||||
text="What's the best time to go?",
|
||||
source=source,
|
||||
reply_to_message_id="42",
|
||||
reply_to_text="Japan is great for culture, food, and efficiency.",
|
||||
)
|
||||
|
||||
result = await runner._prepare_inbound_message_text(
|
||||
event=event,
|
||||
source=source,
|
||||
history=[{"role": "user", "content": "unrelated"}],
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.startswith(
|
||||
'[Replying to: "Japan is great for culture, food, and efficiency."]'
|
||||
)
|
||||
assert result.endswith("What's the best time to go?")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_prefix_still_injected_when_text_in_history():
|
||||
"""Regression test: the pointer must survive even when the quoted text
|
||||
already appears in history. Previously a `found_in_history` guard
|
||||
silently dropped the prefix, leaving the agent to guess which prior
|
||||
message the user was referencing."""
|
||||
runner = _make_runner()
|
||||
source = _source()
|
||||
quoted = "Japan is great for culture, food, and efficiency."
|
||||
event = MessageEvent(
|
||||
text="What's the best time to go?",
|
||||
source=source,
|
||||
reply_to_message_id="42",
|
||||
reply_to_text=quoted,
|
||||
)
|
||||
|
||||
history = [
|
||||
{"role": "user", "content": "I'm thinking of going to Japan or Italy."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": (
|
||||
f"{quoted} Italy is better if you prefer a relaxed pace."
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": "How long should I stay?"},
|
||||
{"role": "assistant", "content": "For Japan, 10-14 days is ideal."},
|
||||
]
|
||||
|
||||
result = await runner._prepare_inbound_message_text(
|
||||
event=event,
|
||||
source=source,
|
||||
history=history,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.startswith(f'[Replying to: "{quoted}"]')
|
||||
assert result.endswith("What's the best time to go?")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_prefix_without_reply_context():
|
||||
runner = _make_runner()
|
||||
source = _source()
|
||||
event = MessageEvent(text="hello", source=source)
|
||||
|
||||
result = await runner._prepare_inbound_message_text(
|
||||
event=event,
|
||||
source=source,
|
||||
history=[],
|
||||
)
|
||||
|
||||
assert result == "hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_prefix_when_reply_to_text_is_empty():
|
||||
"""reply_to_message_id alone without text (e.g. a reply to a media-only
|
||||
message) should not produce an empty `[Replying to: ""]` prefix."""
|
||||
runner = _make_runner()
|
||||
source = _source()
|
||||
event = MessageEvent(
|
||||
text="hi",
|
||||
source=source,
|
||||
reply_to_message_id="42",
|
||||
reply_to_text=None,
|
||||
)
|
||||
|
||||
result = await runner._prepare_inbound_message_text(
|
||||
event=event,
|
||||
source=source,
|
||||
history=[],
|
||||
)
|
||||
|
||||
assert result == "hi"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_snippet_truncated_to_500_chars():
|
||||
runner = _make_runner()
|
||||
source = _source()
|
||||
long_text = "x" * 800
|
||||
event = MessageEvent(
|
||||
text="follow-up",
|
||||
source=source,
|
||||
reply_to_message_id="42",
|
||||
reply_to_text=long_text,
|
||||
)
|
||||
|
||||
result = await runner._prepare_inbound_message_text(
|
||||
event=event,
|
||||
source=source,
|
||||
history=[],
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.startswith('[Replying to: "' + "x" * 500 + '"]')
|
||||
assert "x" * 501 not in result
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import shutil
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
|
@ -8,7 +9,7 @@ import pytest
|
|||
import gateway.run as gateway_run
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
from gateway.session import build_session_key
|
||||
from gateway.session import SessionEntry, build_session_key
|
||||
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
|
||||
|
||||
|
||||
|
|
@ -242,3 +243,31 @@ async def test_shutdown_notification_send_failure_does_not_block():
|
|||
|
||||
# Should not raise
|
||||
await runner._notify_active_sessions_of_shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_notification_uses_persisted_origin_for_colon_ids():
|
||||
"""Shutdown notifications should route from persisted origin, not reparsed keys."""
|
||||
runner, adapter = make_restart_runner()
|
||||
adapter.send = AsyncMock()
|
||||
source = make_restart_source(chat_id="!room123:example.org", chat_type="group")
|
||||
source.platform = gateway_run.Platform.MATRIX
|
||||
session_key = build_session_key(source)
|
||||
runner._running_agents[session_key] = MagicMock()
|
||||
runner.session_store._entries = {
|
||||
session_key: SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
origin=source,
|
||||
platform=source.platform,
|
||||
chat_type=source.chat_type,
|
||||
)
|
||||
}
|
||||
runner.adapters = {gateway_run.Platform.MATRIX: adapter}
|
||||
|
||||
await runner._notify_active_sessions_of_shutdown()
|
||||
|
||||
assert adapter.send.await_count == 1
|
||||
assert adapter.send.await_args.args[0] == "!room123:example.org"
|
||||
|
|
|
|||
|
|
@ -51,6 +51,9 @@ class ProgressCaptureAdapter(BasePlatformAdapter):
|
|||
async def send_typing(self, chat_id, metadata=None) -> None:
|
||||
self.typing.append({"chat_id": chat_id, "metadata": metadata})
|
||||
|
||||
async def stop_typing(self, chat_id) -> None:
|
||||
self.typing.append({"chat_id": chat_id, "metadata": {"stopped": True}})
|
||||
|
||||
async def get_chat_info(self, chat_id: str):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
|
@ -90,6 +93,40 @@ class LongPreviewAgent:
|
|||
}
|
||||
|
||||
|
||||
class DelayedProgressAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool_progress_callback = kwargs.get("tool_progress_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
self.tool_progress_callback("tool.started", "terminal", "first command", {})
|
||||
time.sleep(0.45)
|
||||
self.tool_progress_callback("tool.started", "terminal", "second command", {})
|
||||
time.sleep(0.1)
|
||||
return {
|
||||
"final_response": "done",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
class DelayedInterimAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
self.interim_assistant_callback("first interim")
|
||||
time.sleep(0.45)
|
||||
self.interim_assistant_callback("second interim")
|
||||
time.sleep(0.1)
|
||||
return {
|
||||
"final_response": "done",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
def _make_runner(adapter):
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
GatewayRunner = gateway_run.GatewayRunner
|
||||
|
|
@ -104,6 +141,7 @@ def _make_runner(adapter):
|
|||
runner._fallback_model = None
|
||||
runner._session_db = None
|
||||
runner._running_agents = {}
|
||||
runner._session_run_generation = {}
|
||||
runner.hooks = SimpleNamespace(loaded_hooks=False)
|
||||
runner.config = SimpleNamespace(
|
||||
thread_sessions_per_user=False,
|
||||
|
|
@ -744,6 +782,154 @@ async def test_base_processing_releases_post_delivery_callback_after_main_send()
|
|||
assert released == [True]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_drops_tool_progress_after_generation_invalidation(monkeypatch, tmp_path):
|
||||
import yaml
|
||||
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
yaml.dump({"display": {"tool_progress": "all"}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = DelayedProgressAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
import tools.terminal_tool # noqa: F401 - register terminal tool metadata
|
||||
|
||||
adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
|
||||
runner = _make_runner(adapter)
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="dm-1",
|
||||
chat_type="dm",
|
||||
thread_id=None,
|
||||
)
|
||||
session_key = "agent:main:discord:dm:dm-1"
|
||||
runner._session_run_generation[session_key] = 1
|
||||
|
||||
original_send = adapter.send
|
||||
invalidated = {"done": False}
|
||||
|
||||
async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None):
|
||||
result = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata)
|
||||
if "first command" in content and not invalidated["done"]:
|
||||
invalidated["done"] = True
|
||||
runner._invalidate_session_run_generation(session_key, reason="test_stop")
|
||||
return result
|
||||
|
||||
adapter.send = send_and_invalidate
|
||||
|
||||
result = await runner._run_agent(
|
||||
message="hello",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="sess-progress-stop",
|
||||
session_key=session_key,
|
||||
run_generation=1,
|
||||
)
|
||||
|
||||
all_progress_text = " ".join(call["content"] for call in adapter.sent)
|
||||
all_progress_text += " ".join(call["content"] for call in adapter.edits)
|
||||
assert result["final_response"] == "done"
|
||||
assert 'first command' in all_progress_text
|
||||
assert 'second command' not in all_progress_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_drops_interim_commentary_after_generation_invalidation(monkeypatch, tmp_path):
|
||||
import yaml
|
||||
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
yaml.dump({"display": {"tool_progress": "off", "interim_assistant_messages": True}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = DelayedInterimAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
|
||||
runner = _make_runner(adapter)
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="dm-2",
|
||||
chat_type="dm",
|
||||
thread_id=None,
|
||||
)
|
||||
session_key = "agent:main:discord:dm:dm-2"
|
||||
runner._session_run_generation[session_key] = 1
|
||||
|
||||
original_send = adapter.send
|
||||
invalidated = {"done": False}
|
||||
|
||||
async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None):
|
||||
result = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata)
|
||||
if content == "first interim" and not invalidated["done"]:
|
||||
invalidated["done"] = True
|
||||
runner._invalidate_session_run_generation(session_key, reason="test_stop")
|
||||
return result
|
||||
|
||||
adapter.send = send_and_invalidate
|
||||
|
||||
result = await runner._run_agent(
|
||||
message="hello",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="sess-commentary-stop",
|
||||
session_key=session_key,
|
||||
run_generation=1,
|
||||
)
|
||||
|
||||
sent_texts = [call["content"] for call in adapter.sent]
|
||||
assert result["final_response"] == "done"
|
||||
assert "first interim" in sent_texts
|
||||
assert "second interim" not in sent_texts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keep_typing_stops_immediately_when_interrupt_event_is_set():
|
||||
adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
task = asyncio.create_task(
|
||||
adapter._keep_typing(
|
||||
"dm-typing-stop",
|
||||
interval=30.0,
|
||||
stop_event=stop_event,
|
||||
)
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
stop_event.set()
|
||||
await asyncio.wait_for(task, timeout=0.5)
|
||||
|
||||
normal_typing_calls = [
|
||||
call for call in adapter.typing if call.get("metadata") != {"stopped": True}
|
||||
]
|
||||
stopped_calls = [
|
||||
call for call in adapter.typing if call.get("metadata") == {"stopped": True}
|
||||
]
|
||||
assert len(normal_typing_calls) == 1
|
||||
assert len(stopped_calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verbose_mode_does_not_truncate_args_by_default(monkeypatch, tmp_path):
|
||||
"""Verbose mode with default tool_preview_length (0) should NOT truncate args.
|
||||
|
|
|
|||
|
|
@ -184,8 +184,15 @@ async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_p
|
|||
async def stop(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("gateway.status.get_running_pid", lambda: 42)
|
||||
monkeypatch.setattr("gateway.status.remove_pid_file", lambda: None)
|
||||
# get_running_pid returns 42 before we kill the old gateway, then None
|
||||
# after remove_pid_file() clears the record (reflects real behavior).
|
||||
_pid_state = {"alive": True}
|
||||
def _mock_get_running_pid():
|
||||
return 42 if _pid_state["alive"] else None
|
||||
def _mock_remove_pid_file():
|
||||
_pid_state["alive"] = False
|
||||
monkeypatch.setattr("gateway.status.get_running_pid", _mock_get_running_pid)
|
||||
monkeypatch.setattr("gateway.status.remove_pid_file", _mock_remove_pid_file)
|
||||
monkeypatch.setattr("gateway.status.release_all_scoped_locks", lambda: 0)
|
||||
monkeypatch.setattr("gateway.status.terminate_pid", lambda pid, force=False: calls.append((pid, force)))
|
||||
monkeypatch.setattr("gateway.run.os.getpid", lambda: 100)
|
||||
|
|
@ -253,8 +260,13 @@ async def test_start_gateway_replace_writes_takeover_marker_before_sigterm(
|
|||
async def stop(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("gateway.status.get_running_pid", lambda: 42)
|
||||
monkeypatch.setattr("gateway.status.remove_pid_file", lambda: None)
|
||||
_pid_state = {"alive": True}
|
||||
def _mock_get_running_pid():
|
||||
return 42 if _pid_state["alive"] else None
|
||||
def _mock_remove_pid_file():
|
||||
_pid_state["alive"] = False
|
||||
monkeypatch.setattr("gateway.status.get_running_pid", _mock_get_running_pid)
|
||||
monkeypatch.setattr("gateway.status.remove_pid_file", _mock_remove_pid_file)
|
||||
monkeypatch.setattr("gateway.status.release_all_scoped_locks", lambda: 0)
|
||||
monkeypatch.setattr("gateway.status.write_takeover_marker", record_write_marker)
|
||||
monkeypatch.setattr("gateway.status.terminate_pid", record_terminate)
|
||||
|
|
@ -319,3 +331,23 @@ async def test_start_gateway_replace_clears_marker_on_permission_denied(
|
|||
assert ok is False
|
||||
# Marker must NOT be left behind
|
||||
assert not (tmp_path / ".gateway-takeover.json").exists()
|
||||
|
||||
|
||||
def test_runner_warns_when_docker_gateway_lacks_explicit_output_mount(monkeypatch, tmp_path, caplog):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("TERMINAL_ENV", "docker")
|
||||
monkeypatch.setenv("TERMINAL_DOCKER_VOLUMES", '["/etc/localtime:/etc/localtime:ro"]')
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
GatewayRunner(config)
|
||||
|
||||
assert any(
|
||||
"host-visible output mount" in record.message
|
||||
for record in caplog.records
|
||||
)
|
||||
|
|
|
|||
167
tests/gateway/test_running_agent_session_toggles.py
Normal file
167
tests/gateway/test_running_agent_session_toggles.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
"""Regression tests: /yolo and /verbose dispatch mid-agent-run.
|
||||
|
||||
When an agent is running, the gateway's running-agent guard rejects most
|
||||
slash commands with "⏳ Agent is running — /{cmd} can't run mid-turn"
|
||||
(PR #12334). A small allowlist bypasses that and actually dispatches:
|
||||
|
||||
* /yolo — toggles the session yolo flag; useful to pre-approve a
|
||||
pending approval prompt without waiting for the agent to finish.
|
||||
* /verbose — cycles the per-platform tool-progress display mode;
|
||||
affects the ongoing stream.
|
||||
|
||||
Commands whose handlers say "takes effect on next message" stay on the
|
||||
catch-all by design:
|
||||
|
||||
* /fast — writes config.yaml only
|
||||
* /reasoning — writes config.yaml only
|
||||
|
||||
These tests lock in both behaviors so the allowlist doesn't silently
|
||||
grow or shrink.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="u1",
|
||||
chat_id="c1",
|
||||
user_name="tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def _make_event(text: str) -> MessageEvent:
|
||||
return MessageEvent(text=text, source=_make_source(), message_id="m1")
|
||||
|
||||
|
||||
def _make_runner():
|
||||
"""Minimal GatewayRunner with an active running agent for this session."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||
|
||||
session_entry = SessionEntry(
|
||||
session_key=build_session_key(_make_source()),
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = session_entry
|
||||
runner.session_store.load_transcript.return_value = []
|
||||
runner.session_store.has_any_sessions.return_value = True
|
||||
runner.session_store.append_to_transcript = MagicMock()
|
||||
runner.session_store.rewrite_transcript = MagicMock()
|
||||
runner.session_store.update_session = MagicMock()
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._show_reasoning = False
|
||||
runner._service_tier = None
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner._set_session_env = lambda _context: None
|
||||
runner._should_send_voice_reply = lambda *_args, **_kwargs: False
|
||||
runner._send_voice_reply = AsyncMock()
|
||||
runner._capture_gateway_honcho_if_configured = lambda *args, **kwargs: None
|
||||
runner._emit_gateway_run_progress = AsyncMock()
|
||||
|
||||
# Simulate agent actively running for this session so the guard fires.
|
||||
# Note: the stale-eviction branch calls agent.get_activity_summary() and
|
||||
# compares seconds_since_activity against HERMES_AGENT_TIMEOUT. Return a
|
||||
# dict with recent activity so the eviction path doesn't clear our
|
||||
# fake running agent before the toggle guard runs.
|
||||
import time
|
||||
sk = build_session_key(_make_source())
|
||||
agent_mock = MagicMock()
|
||||
agent_mock.get_activity_summary.return_value = {
|
||||
"seconds_since_activity": 0.0,
|
||||
"last_activity_desc": "api_call",
|
||||
"api_call_count": 1,
|
||||
"max_iterations": 60,
|
||||
}
|
||||
runner._running_agents[sk] = agent_mock
|
||||
runner._running_agents_ts[sk] = time.time()
|
||||
return runner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_yolo_dispatches_mid_run(monkeypatch):
|
||||
"""/yolo mid-run must dispatch to its handler, not hit the catch-all."""
|
||||
runner = _make_runner()
|
||||
runner._handle_yolo_command = AsyncMock(return_value="⚡ YOLO mode **ON** for this session")
|
||||
|
||||
result = await runner._handle_message(_make_event("/yolo"))
|
||||
|
||||
runner._handle_yolo_command.assert_awaited_once()
|
||||
assert result == "⚡ YOLO mode **ON** for this session"
|
||||
assert "can't run mid-turn" not in (result or "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verbose_dispatches_mid_run(monkeypatch):
|
||||
"""/verbose mid-run must dispatch to its handler, not hit the catch-all."""
|
||||
runner = _make_runner()
|
||||
runner._handle_verbose_command = AsyncMock(return_value="tool progress: new")
|
||||
|
||||
result = await runner._handle_message(_make_event("/verbose"))
|
||||
|
||||
runner._handle_verbose_command.assert_awaited_once()
|
||||
assert result == "tool progress: new"
|
||||
assert "can't run mid-turn" not in (result or "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_rejected_mid_run():
|
||||
"""/fast mid-run must hit the busy catch-all — config-only, next message."""
|
||||
runner = _make_runner()
|
||||
runner._handle_fast_command = AsyncMock(
|
||||
side_effect=AssertionError("/fast should not dispatch mid-run")
|
||||
)
|
||||
|
||||
result = await runner._handle_message(_make_event("/fast"))
|
||||
|
||||
runner._handle_fast_command.assert_not_awaited()
|
||||
assert result is not None
|
||||
assert "can't run mid-turn" in result
|
||||
assert "/fast" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reasoning_rejected_mid_run():
|
||||
"""/reasoning mid-run must hit the busy catch-all — config-only, next message."""
|
||||
runner = _make_runner()
|
||||
runner._handle_reasoning_command = AsyncMock(
|
||||
side_effect=AssertionError("/reasoning should not dispatch mid-run")
|
||||
)
|
||||
|
||||
result = await runner._handle_message(_make_event("/reasoning high"))
|
||||
|
||||
runner._handle_reasoning_command.assert_not_awaited()
|
||||
assert result is not None
|
||||
assert "can't run mid-turn" in result
|
||||
assert "/reasoning" in result
|
||||
|
|
@ -356,6 +356,28 @@ class TestBuildSessionContextPrompt:
|
|||
assert "**User:** Alice" in prompt
|
||||
assert "Multi-user thread" not in prompt
|
||||
|
||||
def test_shared_non_thread_group_prompt_hides_single_user(self):
|
||||
"""Shared non-thread group sessions should avoid pinning one user."""
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake"),
|
||||
},
|
||||
group_sessions_per_user=False,
|
||||
)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1002285219667",
|
||||
chat_name="Test Group",
|
||||
chat_type="group",
|
||||
user_name="Alice",
|
||||
)
|
||||
ctx = build_session_context(source, config)
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "Multi-user session" in prompt
|
||||
assert "[sender name]" in prompt
|
||||
assert "**User:** Alice" not in prompt
|
||||
|
||||
def test_dm_thread_shows_user_not_multi(self):
|
||||
"""DM threads are single-user and should show User, not multi-user note."""
|
||||
config = GatewayConfig(
|
||||
|
|
@ -1037,6 +1059,7 @@ class TestRewriteTranscriptPreservesReasoning:
|
|||
role="assistant",
|
||||
content="The answer is 42.",
|
||||
reasoning="I need to think step by step.",
|
||||
reasoning_content="provider scratchpad",
|
||||
reasoning_details=[{"type": "summary", "text": "step by step"}],
|
||||
codex_reasoning_items=[{"id": "r1", "type": "reasoning"}],
|
||||
)
|
||||
|
|
@ -1044,6 +1067,7 @@ class TestRewriteTranscriptPreservesReasoning:
|
|||
# Verify all three were stored
|
||||
before = db.get_messages_as_conversation(session_id)
|
||||
assert before[0].get("reasoning") == "I need to think step by step."
|
||||
assert before[0].get("reasoning_content") == "provider scratchpad"
|
||||
assert before[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}]
|
||||
assert before[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}]
|
||||
|
||||
|
|
@ -1060,5 +1084,6 @@ class TestRewriteTranscriptPreservesReasoning:
|
|||
# Load again — all three reasoning fields must survive
|
||||
after = db.get_messages_as_conversation(session_id)
|
||||
assert after[0].get("reasoning") == "I need to think step by step."
|
||||
assert after[0].get("reasoning_content") == "provider scratchpad"
|
||||
assert after[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}]
|
||||
assert after[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}]
|
||||
|
|
|
|||
76
tests/gateway/test_session_list_allowed_sources.py
Normal file
76
tests/gateway/test_session_list_allowed_sources.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
"""Regression tests for the TUI gateway's ``session.list`` handler.
|
||||
|
||||
Reported during TUI v2 blitz retest: the ``/resume`` modal inside a TUI
|
||||
session only surfaced ``tui``/``cli`` rows, hiding telegram sessions users
|
||||
could still resume directly via ``hermes --tui --resume <id>``.
|
||||
|
||||
The fix widens the picker to a curated allowlist of user-facing sources
|
||||
(tui/cli + chat adapters) while still filtering internal/system sources.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from tui_gateway import server
|
||||
|
||||
|
||||
class _StubDB:
|
||||
def __init__(self, rows):
|
||||
self.rows = rows
|
||||
self.calls: list[dict] = []
|
||||
|
||||
def list_sessions_rich(self, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
return list(self.rows)
|
||||
|
||||
|
||||
def _call(limit: int = 20):
|
||||
return server.handle_request({
|
||||
"id": "1",
|
||||
"method": "session.list",
|
||||
"params": {"limit": limit},
|
||||
})
|
||||
|
||||
|
||||
def test_session_list_includes_telegram_but_filters_internal_sources(monkeypatch):
|
||||
rows = [
|
||||
{"id": "tui-1", "source": "tui", "started_at": 9},
|
||||
{"id": "tool-1", "source": "tool", "started_at": 8},
|
||||
{"id": "tg-1", "source": "telegram", "started_at": 7},
|
||||
{"id": "acp-1", "source": "acp", "started_at": 6},
|
||||
{"id": "cli-1", "source": "cli", "started_at": 5},
|
||||
]
|
||||
db = _StubDB(rows)
|
||||
monkeypatch.setattr(server, "_get_db", lambda: db)
|
||||
|
||||
resp = _call(limit=10)
|
||||
sessions = resp["result"]["sessions"]
|
||||
ids = [s["id"] for s in sessions]
|
||||
|
||||
assert "tg-1" in ids and "tui-1" in ids and "cli-1" in ids, ids
|
||||
assert "tool-1" not in ids and "acp-1" not in ids, ids
|
||||
|
||||
|
||||
def test_session_list_fetches_wider_window_before_filtering(monkeypatch):
|
||||
db = _StubDB([{"id": "x", "source": "cli", "started_at": 1}])
|
||||
monkeypatch.setattr(server, "_get_db", lambda: db)
|
||||
|
||||
_call(limit=10)
|
||||
|
||||
assert len(db.calls) == 1
|
||||
assert db.calls[0].get("source") is None, db.calls[0]
|
||||
assert db.calls[0].get("limit") == 100, db.calls[0]
|
||||
|
||||
|
||||
def test_session_list_preserves_ordering_after_filter(monkeypatch):
|
||||
rows = [
|
||||
{"id": "newest", "source": "telegram", "started_at": 5},
|
||||
{"id": "internal", "source": "tool", "started_at": 4},
|
||||
{"id": "middle", "source": "tui", "started_at": 3},
|
||||
{"id": "oldest", "source": "discord", "started_at": 1},
|
||||
]
|
||||
monkeypatch.setattr(server, "_get_db", lambda: _StubDB(rows))
|
||||
|
||||
resp = _call()
|
||||
ids = [s["id"] for s in resp["result"]["sessions"]]
|
||||
|
||||
assert ids == ["newest", "middle", "oldest"]
|
||||
|
|
@ -24,10 +24,18 @@ class _FakeAdapter:
|
|||
|
||||
def __init__(self):
|
||||
self._pending_messages = {}
|
||||
self._active_sessions = {}
|
||||
self.interrupted_sessions = []
|
||||
|
||||
async def send(self, chat_id, text, **kwargs):
|
||||
pass
|
||||
|
||||
async def interrupt_session_activity(self, session_key, chat_id):
|
||||
self.interrupted_sessions.append((session_key, chat_id))
|
||||
event = self._active_sessions.get(session_key)
|
||||
if event is not None:
|
||||
event.set()
|
||||
|
||||
|
||||
def _make_runner():
|
||||
runner = object.__new__(GatewayRunner)
|
||||
|
|
@ -37,6 +45,7 @@ def _make_runner():
|
|||
runner.adapters = {Platform.TELEGRAM: _FakeAdapter()}
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._session_run_generation = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._voice_mode = {}
|
||||
|
|
@ -81,7 +90,7 @@ async def test_sentinel_placed_before_agent_setup():
|
|||
# Patch _handle_message_with_agent to capture state at entry
|
||||
sentinel_was_set = False
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
async def mock_inner(self_inner, ev, src, qk, generation):
|
||||
nonlocal sentinel_was_set
|
||||
sentinel_was_set = runner._running_agents.get(qk) is _AGENT_PENDING_SENTINEL
|
||||
return "ok"
|
||||
|
|
@ -105,7 +114,7 @@ async def test_sentinel_cleaned_up_after_handler_returns():
|
|||
event = _make_event()
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
async def mock_inner(self_inner, ev, src, qk, generation):
|
||||
return "ok"
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
||||
|
|
@ -127,7 +136,7 @@ async def test_sentinel_cleaned_up_on_exception():
|
|||
event = _make_event()
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
async def mock_inner(self_inner, ev, src, qk, generation):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
||||
|
|
@ -154,7 +163,7 @@ async def test_second_message_during_sentinel_queued_not_duplicate():
|
|||
|
||||
barrier = asyncio.Event()
|
||||
|
||||
async def slow_inner(self_inner, ev, src, qk):
|
||||
async def slow_inner(self_inner, ev, src, qk, generation):
|
||||
# Simulate slow setup — wait until test tells us to proceed
|
||||
await barrier.wait()
|
||||
return "ok"
|
||||
|
|
@ -333,7 +342,7 @@ async def test_stop_during_sentinel_force_cleans_session():
|
|||
|
||||
barrier = asyncio.Event()
|
||||
|
||||
async def slow_inner(self_inner, ev, src, qk):
|
||||
async def slow_inner(self_inner, ev, src, qk, generation):
|
||||
await barrier.wait()
|
||||
return "ok"
|
||||
|
||||
|
|
@ -381,6 +390,7 @@ async def test_stop_hard_kills_running_agent():
|
|||
fake_agent = MagicMock()
|
||||
fake_agent.get_activity_summary.return_value = {"seconds_since_activity": 0}
|
||||
runner._running_agents[session_key] = fake_agent
|
||||
runner.adapters[Platform.TELEGRAM]._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
# Send /stop
|
||||
stop_event = _make_event(text="/stop")
|
||||
|
|
@ -393,6 +403,10 @@ async def test_stop_hard_kills_running_agent():
|
|||
assert session_key not in runner._running_agents, (
|
||||
"/stop must remove the agent from _running_agents so the session is unlocked"
|
||||
)
|
||||
assert runner.adapters[Platform.TELEGRAM].interrupted_sessions == [
|
||||
(session_key, "12345")
|
||||
]
|
||||
assert runner.adapters[Platform.TELEGRAM]._active_sessions[session_key].is_set()
|
||||
|
||||
# Must return a confirmation
|
||||
assert result is not None
|
||||
|
|
|
|||
|
|
@ -117,11 +117,20 @@ class TestPruneBasics:
|
|||
assert "idle" not in store._entries
|
||||
|
||||
def test_prune_skips_entries_with_active_processes(self, tmp_path):
|
||||
"""Sessions with active bg processes aren't pruned even if old."""
|
||||
active_session_ids = {"sid_active"}
|
||||
"""Sessions with active bg processes aren't pruned even if old.
|
||||
|
||||
def _has_active(session_id: str) -> bool:
|
||||
return session_id in active_session_ids
|
||||
The callback is keyed by session_key — matching what
|
||||
process_registry.has_active_for_session() actually consumes in
|
||||
gateway/run.py. Prior to the fix this test passed the callback a
|
||||
session_id, which silently matched an implementation bug where
|
||||
prune_old_entries was also passing session_id; real-world usage
|
||||
(via process_registry) takes a session_key and never matched, so
|
||||
active sessions were still being pruned.
|
||||
"""
|
||||
active_session_keys = {"active"}
|
||||
|
||||
def _has_active(session_key: str) -> bool:
|
||||
return session_key in active_session_keys
|
||||
|
||||
store = _make_store(tmp_path, has_active_processes_fn=_has_active)
|
||||
store._entries["active"] = _entry(
|
||||
|
|
@ -137,6 +146,26 @@ class TestPruneBasics:
|
|||
assert "active" in store._entries
|
||||
assert "idle" not in store._entries
|
||||
|
||||
def test_prune_active_check_uses_session_key_not_session_id(self, tmp_path):
|
||||
"""Regression guard: a callback that only recognises session_ids must
|
||||
NOT protect entries during prune. This pins the fix so a future
|
||||
refactor can't silently revert to passing session_id again.
|
||||
"""
|
||||
def _recognises_only_ids(identifier: str) -> bool:
|
||||
return identifier.startswith("sid_")
|
||||
|
||||
store = _make_store(tmp_path, has_active_processes_fn=_recognises_only_ids)
|
||||
store._entries["active"] = _entry(
|
||||
"active", age_days=1000, session_id="sid_active"
|
||||
)
|
||||
|
||||
removed = store.prune_old_entries(max_age_days=90)
|
||||
|
||||
# Entry is pruned because the callback receives "active" (session_key),
|
||||
# not "sid_active" (session_id), so _recognises_only_ids returns False.
|
||||
assert removed == 1
|
||||
assert "active" not in store._entries
|
||||
|
||||
def test_prune_does_not_write_disk_when_no_removals(self, tmp_path):
|
||||
"""If nothing is evictable, _save() should NOT be called."""
|
||||
store = _make_store(tmp_path)
|
||||
|
|
|
|||
70
tests/gateway/test_shared_group_sender_prefix.py
Normal file
70
tests/gateway/test_shared_group_sender_prefix.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _make_runner(config: GatewayConfig) -> GatewayRunner:
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = config
|
||||
runner.adapters = {}
|
||||
runner._model = "openai/gpt-4.1-mini"
|
||||
runner._base_url = None
|
||||
return runner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preprocess_prefixes_sender_for_shared_non_thread_group_session():
|
||||
runner = _make_runner(
|
||||
GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake"),
|
||||
},
|
||||
group_sessions_per_user=False,
|
||||
)
|
||||
)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1002285219667",
|
||||
chat_name="Test Group",
|
||||
chat_type="group",
|
||||
user_name="Alice",
|
||||
)
|
||||
event = MessageEvent(text="hello", source=source)
|
||||
|
||||
result = await runner._prepare_inbound_message_text(
|
||||
event=event,
|
||||
source=source,
|
||||
history=[],
|
||||
)
|
||||
|
||||
assert result == "[Alice] hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preprocess_keeps_plain_text_for_default_group_sessions():
|
||||
runner = _make_runner(
|
||||
GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake"),
|
||||
},
|
||||
)
|
||||
)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1002285219667",
|
||||
chat_name="Test Group",
|
||||
chat_type="group",
|
||||
user_name="Alice",
|
||||
)
|
||||
event = MessageEvent(text="hello", source=source)
|
||||
|
||||
result = await runner._prepare_inbound_message_text(
|
||||
event=event,
|
||||
source=source,
|
||||
history=[],
|
||||
)
|
||||
|
||||
assert result == "hello"
|
||||
|
|
@ -91,6 +91,29 @@ class TestSignalAdapterInit:
|
|||
assert adapter._account_normalized == "+15551234567"
|
||||
|
||||
|
||||
class TestSignalConnectCleanup:
|
||||
"""Regression coverage for failed connect() cleanup."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_releases_lock_and_closes_client_on_healthcheck_failure(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=MagicMock(status_code=503))
|
||||
mock_client.aclose = AsyncMock()
|
||||
|
||||
with patch("gateway.platforms.signal.httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("gateway.status.acquire_scoped_lock", return_value=(True, None)), \
|
||||
patch("gateway.status.release_scoped_lock") as mock_release:
|
||||
result = await adapter.connect()
|
||||
|
||||
assert result is False
|
||||
mock_client.aclose.assert_awaited_once()
|
||||
mock_release.assert_called_once_with("signal-phone", "+15551234567")
|
||||
assert adapter.client is None
|
||||
assert adapter._platform_lock_identity is None
|
||||
|
||||
|
||||
class TestSignalHelpers:
|
||||
def test_redact_phone_long(self):
|
||||
from gateway.platforms.helpers import redact_phone
|
||||
|
|
@ -283,7 +306,13 @@ class TestSignalSessionSource:
|
|||
class TestSignalPhoneRedaction:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _ensure_redaction_enabled(self, monkeypatch):
|
||||
# agent.redact snapshots _REDACT_ENABLED at import time from the
|
||||
# HERMES_REDACT_SECRETS env var. monkeypatch.delenv is too late —
|
||||
# the module was already imported during test collection with
|
||||
# whatever value was in the env then. Force the flag directly.
|
||||
# See skill: xdist-cross-test-pollution Pattern 5.
|
||||
monkeypatch.delenv("HERMES_REDACT_SECRETS", raising=False)
|
||||
monkeypatch.setattr("agent.redact._REDACT_ENABLED", True)
|
||||
|
||||
def test_us_number(self):
|
||||
from agent.redact import redact_sensitive_text
|
||||
|
|
@ -438,6 +467,97 @@ class TestSignalSendImageFile:
|
|||
assert "failed" in result.error.lower()
|
||||
|
||||
|
||||
class TestSignalRecipientResolution:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_prefers_cached_uuid_for_direct_messages(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
adapter._stop_typing_indicator = AsyncMock()
|
||||
adapter._remember_recipient_identifiers("+15551230000", "68680952-6d86-45bc-85e0-1a4d186d53ee")
|
||||
|
||||
captured = []
|
||||
|
||||
async def mock_rpc(method, params, rpc_id=None, **kwargs):
|
||||
captured.append({"method": method, "params": dict(params)})
|
||||
return {"timestamp": 1234567890}
|
||||
|
||||
adapter._rpc = mock_rpc
|
||||
|
||||
result = await adapter.send(chat_id="+15551230000", content="hello")
|
||||
|
||||
assert result.success is True
|
||||
assert captured[0]["method"] == "send"
|
||||
assert captured[0]["params"]["recipient"] == ["68680952-6d86-45bc-85e0-1a4d186d53ee"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_looks_up_uuid_via_list_contacts(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
adapter._stop_typing_indicator = AsyncMock()
|
||||
|
||||
captured = []
|
||||
|
||||
async def mock_rpc(method, params, rpc_id=None, **kwargs):
|
||||
captured.append({"method": method, "params": dict(params)})
|
||||
if method == "listContacts":
|
||||
return [{
|
||||
"recipient": "351935789098",
|
||||
"number": "+15551230000",
|
||||
"uuid": "68680952-6d86-45bc-85e0-1a4d186d53ee",
|
||||
"isRegistered": True,
|
||||
}]
|
||||
if method == "send":
|
||||
return {"timestamp": 1234567890}
|
||||
return None
|
||||
|
||||
adapter._rpc = mock_rpc
|
||||
|
||||
result = await adapter.send(chat_id="+15551230000", content="hello")
|
||||
|
||||
assert result.success is True
|
||||
assert captured[0]["method"] == "listContacts"
|
||||
assert captured[1]["method"] == "send"
|
||||
assert captured[1]["params"]["recipient"] == ["68680952-6d86-45bc-85e0-1a4d186d53ee"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_falls_back_to_phone_when_no_uuid_found(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
adapter._stop_typing_indicator = AsyncMock()
|
||||
|
||||
captured = []
|
||||
|
||||
async def mock_rpc(method, params, rpc_id=None, **kwargs):
|
||||
captured.append({"method": method, "params": dict(params)})
|
||||
if method == "listContacts":
|
||||
return []
|
||||
if method == "send":
|
||||
return {"timestamp": 1234567890}
|
||||
return None
|
||||
|
||||
adapter._rpc = mock_rpc
|
||||
|
||||
result = await adapter.send(chat_id="+15551230000", content="hello")
|
||||
|
||||
assert result.success is True
|
||||
assert captured[1]["params"]["recipient"] == ["+15551230000"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_typing_uses_cached_uuid(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
adapter._remember_recipient_identifiers("+15551230000", "68680952-6d86-45bc-85e0-1a4d186d53ee")
|
||||
|
||||
captured = []
|
||||
|
||||
async def mock_rpc(method, params, rpc_id=None, **kwargs):
|
||||
captured.append({"method": method, "params": dict(params), "rpc_id": rpc_id})
|
||||
return {}
|
||||
|
||||
adapter._rpc = mock_rpc
|
||||
|
||||
await adapter.send_typing("+15551230000")
|
||||
|
||||
assert captured[0]["method"] == "sendTyping"
|
||||
assert captured[0]["params"]["recipient"] == ["68680952-6d86-45bc-85e0-1a4d186d53ee"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send_voice method (#5105)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -150,6 +150,31 @@ class TestAppMentionHandler:
|
|||
assert "/hermes" in registered_commands
|
||||
|
||||
|
||||
class TestSlackConnectCleanup:
|
||||
"""Regression coverage for failed connect() cleanup."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_releases_platform_lock_when_auth_fails(self):
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake")
|
||||
adapter = SlackAdapter(config)
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_web_client = AsyncMock()
|
||||
mock_web_client.auth_test = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
with patch.object(_slack_mod, "AsyncApp", return_value=mock_app), \
|
||||
patch.object(_slack_mod, "AsyncWebClient", return_value=mock_web_client), \
|
||||
patch.object(_slack_mod, "AsyncSocketModeHandler", return_value=MagicMock()), \
|
||||
patch.dict(os.environ, {"SLACK_APP_TOKEN": "xapp-fake"}), \
|
||||
patch("gateway.status.acquire_scoped_lock", return_value=(True, None)), \
|
||||
patch("gateway.status.release_scoped_lock") as mock_release:
|
||||
result = await adapter.connect()
|
||||
|
||||
assert result is False
|
||||
mock_release.assert_called_once_with("slack-app-token", "xapp-fake")
|
||||
assert adapter._platform_lock_identity is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendDocument
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -1006,7 +1031,7 @@ class TestReactions:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reactions_in_message_flow(self, adapter):
|
||||
"""Reactions should be added on receipt and swapped on completion."""
|
||||
"""Reactions should be bracketed around actual processing via hooks."""
|
||||
adapter._app.client.reactions_add = AsyncMock()
|
||||
adapter._app.client.reactions_remove = AsyncMock()
|
||||
adapter._app.client.users_info = AsyncMock(return_value={
|
||||
|
|
@ -1022,15 +1047,147 @@ class TestReactions:
|
|||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
# Should have added 👀, then removed 👀, then added ✅
|
||||
# _handle_slack_message should register the message for reactions
|
||||
assert "1234567890.000001" in adapter._reacting_message_ids
|
||||
|
||||
# Simulate the base class calling on_processing_start
|
||||
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
|
||||
from gateway.config import Platform
|
||||
source = SessionSource(
|
||||
platform=Platform.SLACK,
|
||||
chat_id="C123",
|
||||
chat_type="dm",
|
||||
user_id="U_USER",
|
||||
)
|
||||
msg_event = MessageEvent(
|
||||
text="hello",
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id="1234567890.000001",
|
||||
)
|
||||
await adapter.on_processing_start(msg_event)
|
||||
|
||||
add_calls = adapter._app.client.reactions_add.call_args_list
|
||||
assert len(add_calls) == 1
|
||||
assert add_calls[0].kwargs["name"] == "eyes"
|
||||
|
||||
# Simulate the base class calling on_processing_complete
|
||||
from gateway.platforms.base import ProcessingOutcome
|
||||
await adapter.on_processing_complete(msg_event, ProcessingOutcome.SUCCESS)
|
||||
|
||||
add_calls = adapter._app.client.reactions_add.call_args_list
|
||||
remove_calls = adapter._app.client.reactions_remove.call_args_list
|
||||
assert len(add_calls) == 2
|
||||
assert add_calls[0].kwargs["name"] == "eyes"
|
||||
assert add_calls[1].kwargs["name"] == "white_check_mark"
|
||||
assert len(remove_calls) == 1
|
||||
assert remove_calls[0].kwargs["name"] == "eyes"
|
||||
|
||||
# Message ID should be cleaned up
|
||||
assert "1234567890.000001" not in adapter._reacting_message_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reactions_failure_outcome(self, adapter):
|
||||
"""Failed processing should add :x: instead of :white_check_mark:."""
|
||||
adapter._app.client.reactions_add = AsyncMock()
|
||||
adapter._app.client.reactions_remove = AsyncMock()
|
||||
|
||||
from gateway.platforms.base import MessageEvent, MessageType, SessionSource, ProcessingOutcome
|
||||
from gateway.config import Platform
|
||||
source = SessionSource(
|
||||
platform=Platform.SLACK,
|
||||
chat_id="C123",
|
||||
chat_type="dm",
|
||||
user_id="U_USER",
|
||||
)
|
||||
adapter._reacting_message_ids.add("1234567890.000002")
|
||||
msg_event = MessageEvent(
|
||||
text="hello",
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id="1234567890.000002",
|
||||
)
|
||||
await adapter.on_processing_complete(msg_event, ProcessingOutcome.FAILURE)
|
||||
|
||||
add_calls = adapter._app.client.reactions_add.call_args_list
|
||||
remove_calls = adapter._app.client.reactions_remove.call_args_list
|
||||
assert len(add_calls) == 1
|
||||
assert add_calls[0].kwargs["name"] == "x"
|
||||
assert len(remove_calls) == 1
|
||||
assert remove_calls[0].kwargs["name"] == "eyes"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reactions_skipped_for_non_dm_non_mention(self, adapter):
|
||||
"""Non-DM, non-mention messages should not get reactions."""
|
||||
adapter._app.client.reactions_add = AsyncMock()
|
||||
adapter._app.client.reactions_remove = AsyncMock()
|
||||
adapter._app.client.users_info = AsyncMock(return_value={
|
||||
"user": {"profile": {"display_name": "Tyler"}}
|
||||
})
|
||||
|
||||
event = {
|
||||
"text": "hello",
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": "channel",
|
||||
"ts": "1234567890.000003",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
# Should NOT register for reactions when not mentioned in a channel
|
||||
assert "1234567890.000003" not in adapter._reacting_message_ids
|
||||
adapter._app.client.reactions_add.assert_not_called()
|
||||
adapter._app.client.reactions_remove.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reactions_disabled_via_env(self, adapter, monkeypatch):
|
||||
"""SLACK_REACTIONS=false should suppress all reaction lifecycle."""
|
||||
monkeypatch.setenv("SLACK_REACTIONS", "false")
|
||||
adapter._app.client.reactions_add = AsyncMock()
|
||||
adapter._app.client.reactions_remove = AsyncMock()
|
||||
adapter._app.client.users_info = AsyncMock(return_value={
|
||||
"user": {"profile": {"display_name": "Tyler"}}
|
||||
})
|
||||
|
||||
event = {
|
||||
"text": "hello",
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": "im",
|
||||
"ts": "1234567890.000004",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
# Should NOT register for reactions when toggle is off
|
||||
assert "1234567890.000004" not in adapter._reacting_message_ids
|
||||
|
||||
# Hooks should also be no-ops when disabled
|
||||
from gateway.platforms.base import MessageEvent, MessageType, SessionSource, ProcessingOutcome
|
||||
from gateway.config import Platform
|
||||
source = SessionSource(
|
||||
platform=Platform.SLACK,
|
||||
chat_id="C123",
|
||||
chat_type="dm",
|
||||
user_id="U_USER",
|
||||
)
|
||||
msg_event = MessageEvent(
|
||||
text="hello",
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id="1234567890.000004",
|
||||
)
|
||||
# Force-add to verify hooks respect the toggle independently
|
||||
adapter._reacting_message_ids.add("1234567890.000004")
|
||||
await adapter.on_processing_start(msg_event)
|
||||
await adapter.on_processing_complete(msg_event, ProcessingOutcome.SUCCESS)
|
||||
|
||||
adapter._app.client.reactions_add.assert_not_called()
|
||||
adapter._app.client.reactions_remove.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reactions_enabled_by_default(self, adapter):
|
||||
"""SLACK_REACTIONS defaults to true (matches existing behavior)."""
|
||||
assert adapter._reactions_enabled() is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestThreadReplyHandling
|
||||
|
|
|
|||
|
|
@ -19,6 +19,30 @@ class TestGatewayPidState:
|
|||
assert isinstance(payload["argv"], list)
|
||||
assert payload["argv"]
|
||||
|
||||
def test_write_pid_file_is_atomic_against_concurrent_writers(self, tmp_path, monkeypatch):
|
||||
"""Regression: two concurrent --replace invocations must not both win.
|
||||
|
||||
Without O_CREAT|O_EXCL, two processes racing through start_gateway()'s
|
||||
termination-wait would both write to gateway.pid, silently overwriting
|
||||
each other and leaving multiple gateway instances alive (#11718).
|
||||
"""
|
||||
import pytest
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
# First write wins.
|
||||
status.write_pid_file()
|
||||
assert (tmp_path / "gateway.pid").exists()
|
||||
|
||||
# Second write (simulating a racing --replace that missed the earlier
|
||||
# guards) must raise FileExistsError rather than clobber the record.
|
||||
with pytest.raises(FileExistsError):
|
||||
status.write_pid_file()
|
||||
|
||||
# Original record is preserved.
|
||||
payload = json.loads((tmp_path / "gateway.pid").read_text())
|
||||
assert payload["pid"] == os.getpid()
|
||||
|
||||
def test_get_running_pid_rejects_live_non_gateway_pid(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
pid_path = tmp_path / "gateway.pid"
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ def _make_runner(session_entry: SessionEntry):
|
|||
runner.session_store.rewrite_transcript = MagicMock()
|
||||
runner.session_store.update_session = MagicMock()
|
||||
runner._running_agents = {}
|
||||
runner._session_run_generation = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = MagicMock()
|
||||
|
|
@ -223,6 +224,121 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_discards_stale_result_after_session_invalidation(monkeypatch):
|
||||
import gateway.run as gateway_run
|
||||
|
||||
session_entry = SessionEntry(
|
||||
session_key=build_session_key(_make_source()),
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
|
||||
session_key = session_entry.session_key
|
||||
runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks = {session_key: object()}
|
||||
|
||||
async def _stale_result(**kwargs):
|
||||
runner._invalidate_session_run_generation(kwargs["session_key"], reason="test_stale_result")
|
||||
return {
|
||||
"final_response": "late reply",
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"history_offset": 0,
|
||||
"last_prompt_tokens": 80,
|
||||
"input_tokens": 120,
|
||||
"output_tokens": 45,
|
||||
"model": "openai/test-model",
|
||||
}
|
||||
|
||||
runner._run_agent = AsyncMock(side_effect=_stale_result)
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
monkeypatch.setattr(
|
||||
"agent.model_metadata.get_model_context_length",
|
||||
lambda *_args, **_kwargs: 100000,
|
||||
)
|
||||
|
||||
result = await runner._handle_message(_make_event("hello"))
|
||||
|
||||
assert result is None
|
||||
runner.session_store.append_to_transcript.assert_not_called()
|
||||
runner.session_store.update_session.assert_not_called()
|
||||
assert session_key not in runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_stale_result_keeps_newer_generation_callback(monkeypatch):
|
||||
import gateway.run as gateway_run
|
||||
|
||||
class _Adapter:
|
||||
def __init__(self):
|
||||
self._post_delivery_callbacks = {}
|
||||
|
||||
async def send(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
def pop_post_delivery_callback(self, session_key, *, generation=None):
|
||||
entry = self._post_delivery_callbacks.get(session_key)
|
||||
if entry is None:
|
||||
return None
|
||||
if isinstance(entry, tuple):
|
||||
entry_generation, callback = entry
|
||||
if generation is not None and entry_generation != generation:
|
||||
return None
|
||||
self._post_delivery_callbacks.pop(session_key, None)
|
||||
return callback
|
||||
if generation is not None:
|
||||
return None
|
||||
return self._post_delivery_callbacks.pop(session_key, None)
|
||||
|
||||
session_entry = SessionEntry(
|
||||
session_key=build_session_key(_make_source()),
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
|
||||
session_key = session_entry.session_key
|
||||
adapter = _Adapter()
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
async def _stale_result(**kwargs):
|
||||
# Simulate a newer run claiming the callback slot before the stale run unwinds.
|
||||
runner._session_run_generation[session_key] = 2
|
||||
adapter._post_delivery_callbacks[session_key] = (2, lambda: None)
|
||||
return {
|
||||
"final_response": "late reply",
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"history_offset": 0,
|
||||
"last_prompt_tokens": 80,
|
||||
"input_tokens": 120,
|
||||
"output_tokens": 45,
|
||||
"model": "openai/test-model",
|
||||
}
|
||||
|
||||
runner._run_agent = AsyncMock(side_effect=_stale_result)
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
monkeypatch.setattr(
|
||||
"agent.model_metadata.get_model_context_length",
|
||||
lambda *_args, **_kwargs: 100000,
|
||||
)
|
||||
|
||||
result = await runner._handle_message(_make_event("hello"))
|
||||
|
||||
assert result is None
|
||||
assert session_key in adapter._post_delivery_callbacks
|
||||
assert adapter._post_delivery_callbacks[session_key][0] == 2
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_command_bypasses_active_session_guard():
|
||||
|
|
|
|||
|
|
@ -133,6 +133,43 @@ class TestFinalizeCapabilityGate:
|
|||
assert picky.edit_message.call_args[1]["finalize"] is True
|
||||
|
||||
|
||||
class TestEditMessageFinalizeSignature:
|
||||
"""Every concrete platform adapter must accept the ``finalize`` kwarg.
|
||||
|
||||
stream_consumer._send_or_edit always passes ``finalize=`` to
|
||||
``adapter.edit_message(...)`` (see gateway/stream_consumer.py). An
|
||||
adapter that overrides edit_message without accepting finalize raises
|
||||
TypeError the first time streaming hits a segment break or final edit.
|
||||
Guard the contract with an explicit signature check so it cannot
|
||||
silently regress — existing tests use MagicMock which swallows any
|
||||
kwarg and cannot catch this.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"module_path,class_name",
|
||||
[
|
||||
("gateway.platforms.telegram", "TelegramAdapter"),
|
||||
("gateway.platforms.discord", "DiscordAdapter"),
|
||||
("gateway.platforms.slack", "SlackAdapter"),
|
||||
("gateway.platforms.matrix", "MatrixAdapter"),
|
||||
("gateway.platforms.mattermost", "MattermostAdapter"),
|
||||
("gateway.platforms.feishu", "FeishuAdapter"),
|
||||
("gateway.platforms.whatsapp", "WhatsAppAdapter"),
|
||||
("gateway.platforms.dingtalk", "DingTalkAdapter"),
|
||||
],
|
||||
)
|
||||
def test_edit_message_accepts_finalize(self, module_path, class_name):
|
||||
import inspect
|
||||
|
||||
module = pytest.importorskip(module_path)
|
||||
cls = getattr(module, class_name)
|
||||
params = inspect.signature(cls.edit_message).parameters
|
||||
assert "finalize" in params, (
|
||||
f"{class_name}.edit_message must accept 'finalize' kwarg; "
|
||||
f"stream_consumer._send_or_edit passes it unconditionally"
|
||||
)
|
||||
|
||||
|
||||
class TestSendOrEditMediaStripping:
|
||||
"""Verify _send_or_edit strips MEDIA: before sending to the platform."""
|
||||
|
||||
|
|
@ -502,11 +539,13 @@ class TestSegmentBreakOnToolBoundary:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_segment_break_clears_failed_edit_fallback_state(self):
|
||||
"""A tool boundary after edit failure must not duplicate the next segment."""
|
||||
"""A tool boundary after edit failure must flush the undelivered tail
|
||||
without duplicating the prefix the user already saw (#8124)."""
|
||||
adapter = MagicMock()
|
||||
send_results = [
|
||||
SimpleNamespace(success=True, message_id="msg_1"),
|
||||
SimpleNamespace(success=True, message_id="msg_2"),
|
||||
SimpleNamespace(success=True, message_id="msg_3"),
|
||||
]
|
||||
adapter.send = AsyncMock(side_effect=send_results)
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
|
||||
|
|
@ -526,7 +565,60 @@ class TestSegmentBreakOnToolBoundary:
|
|||
await task
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert sent_texts == ["Hello ▉", "Next segment"]
|
||||
# The undelivered "world" tail must reach the user, and the next
|
||||
# segment must not duplicate "Hello" that was already visible.
|
||||
assert sent_texts == ["Hello ▉", "world", "Next segment"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_segment_break_after_mid_stream_edit_failure_preserves_tail(self):
|
||||
"""Regression for #8124: when an earlier edit succeeded but later edits
|
||||
fail (persistent flood control) and a tool boundary arrives before the
|
||||
fallback threshold is reached, the pre-boundary tail must still be
|
||||
delivered — not silently dropped by the segment reset."""
|
||||
adapter = MagicMock()
|
||||
# msg_1 for the initial partial, msg_2 for the flushed tail,
|
||||
# msg_3 for the post-boundary segment.
|
||||
send_results = [
|
||||
SimpleNamespace(success=True, message_id="msg_1"),
|
||||
SimpleNamespace(success=True, message_id="msg_2"),
|
||||
SimpleNamespace(success=True, message_id="msg_3"),
|
||||
]
|
||||
adapter.send = AsyncMock(side_effect=send_results)
|
||||
|
||||
# First two edits succeed, everything after fails with flood control
|
||||
# — simulating Telegram's "edit once then get rate-limited" pattern.
|
||||
edit_results = [
|
||||
SimpleNamespace(success=True), # "Hello world ▉" — succeeds
|
||||
SimpleNamespace(success=False, error="flood_control:6.0"), # "Hello world more ▉" — flood triggered
|
||||
SimpleNamespace(success=False, error="flood_control:6.0"), # finalize edit at segment break
|
||||
SimpleNamespace(success=False, error="flood_control:6.0"), # cursor-strip attempt
|
||||
]
|
||||
adapter.edit_message = AsyncMock(side_effect=edit_results + [edit_results[-1]] * 10)
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉")
|
||||
consumer = GatewayStreamConsumer(adapter, "chat_123", config)
|
||||
|
||||
consumer.on_delta("Hello")
|
||||
task = asyncio.create_task(consumer.run())
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(" world")
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(" more")
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(None) # tool boundary
|
||||
consumer.on_delta("Here is the tool result.")
|
||||
consumer.finish()
|
||||
await task
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
# "more" must have been delivered, not dropped.
|
||||
all_text = " ".join(sent_texts)
|
||||
assert "more" in all_text, (
|
||||
f"Pre-boundary tail 'more' was silently dropped: sends={sent_texts}"
|
||||
)
|
||||
# Post-boundary text must also reach the user.
|
||||
assert "Here is the tool result." in all_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_message_id_enters_fallback_mode(self):
|
||||
|
|
@ -1161,3 +1253,87 @@ class TestBufferOnlyMode:
|
|||
# text, the consumer may send then edit, or just send once at got_done.
|
||||
# The key assertion: this doesn't break.
|
||||
assert adapter.send.call_count >= 1
|
||||
|
||||
|
||||
# ── Cursor stripping on fallback (#7183) ────────────────────────────────────
|
||||
|
||||
|
||||
class TestCursorStrippingOnFallback:
|
||||
"""Regression: cursor must be stripped when fallback continuation is empty (#7183).
|
||||
|
||||
When _send_fallback_final is called with nothing new to deliver (the visible
|
||||
partial already matches final_text), the last edit may still show the cursor
|
||||
character because fallback mode was entered after a failed edit. Before the
|
||||
fix this would leave the message permanently frozen with a visible ▉.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cursor_stripped_when_continuation_empty(self):
|
||||
"""_send_fallback_final must attempt a final edit to strip the cursor."""
|
||||
adapter = MagicMock()
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
adapter.edit_message = AsyncMock(
|
||||
return_value=SimpleNamespace(success=True, message_id="msg-1")
|
||||
)
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter, "chat-1",
|
||||
config=StreamConsumerConfig(cursor=" ▉"),
|
||||
)
|
||||
consumer._message_id = "msg-1"
|
||||
consumer._last_sent_text = "Hello world ▉"
|
||||
consumer._fallback_final_send = False
|
||||
|
||||
await consumer._send_fallback_final("Hello world")
|
||||
|
||||
adapter.edit_message.assert_called_once()
|
||||
call_args = adapter.edit_message.call_args
|
||||
assert call_args.kwargs["content"] == "Hello world"
|
||||
assert consumer._already_sent is True
|
||||
# _last_sent_text should reflect the cleaned text after a successful strip
|
||||
assert consumer._last_sent_text == "Hello world"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cursor_not_stripped_when_no_cursor_configured(self):
|
||||
"""No edit attempted when cursor is not configured."""
|
||||
adapter = MagicMock()
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
adapter.edit_message = AsyncMock()
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter, "chat-1",
|
||||
config=StreamConsumerConfig(cursor=""),
|
||||
)
|
||||
consumer._message_id = "msg-1"
|
||||
consumer._last_sent_text = "Hello world"
|
||||
consumer._fallback_final_send = False
|
||||
|
||||
await consumer._send_fallback_final("Hello world")
|
||||
|
||||
adapter.edit_message.assert_not_called()
|
||||
assert consumer._already_sent is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cursor_strip_edit_failure_handled(self):
|
||||
"""If the cursor-stripping edit itself fails, it must not crash and
|
||||
must not corrupt _last_sent_text."""
|
||||
adapter = MagicMock()
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
adapter.edit_message = AsyncMock(
|
||||
return_value=SimpleNamespace(success=False, error="flood_control")
|
||||
)
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter, "chat-1",
|
||||
config=StreamConsumerConfig(cursor=" ▉"),
|
||||
)
|
||||
consumer._message_id = "msg-1"
|
||||
consumer._last_sent_text = "Hello ▉"
|
||||
consumer._fallback_final_send = False
|
||||
|
||||
await consumer._send_fallback_final("Hello")
|
||||
|
||||
# Should still set already_sent despite the cursor-strip edit failure
|
||||
assert consumer._already_sent is True
|
||||
# _last_sent_text must NOT be updated when the edit failed
|
||||
assert consumer._last_sent_text == "Hello ▉"
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from gateway.platforms.base import (
|
|||
MessageType,
|
||||
SendResult,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
SUPPORTED_VIDEO_TYPES,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -117,6 +118,12 @@ def _make_update(msg):
|
|||
return update
|
||||
|
||||
|
||||
def _make_video(file_obj=None):
|
||||
video = MagicMock()
|
||||
video.get_file = AsyncMock(return_value=file_obj or _make_file_obj(b"video-bytes"))
|
||||
return video
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -132,10 +139,13 @@ def adapter():
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _redirect_cache(tmp_path, monkeypatch):
|
||||
"""Point document cache to tmp_path so tests don't touch ~/.hermes."""
|
||||
"""Point document/video cache to tmp_path so tests don't touch ~/.hermes."""
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.VIDEO_CACHE_DIR", tmp_path / "video_cache"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -348,6 +358,37 @@ class TestDocumentDownloadBlock:
|
|||
adapter.handle_message.assert_called_once()
|
||||
|
||||
|
||||
class TestVideoDownloadBlock:
|
||||
@pytest.mark.asyncio
|
||||
async def test_native_video_is_cached(self, adapter):
|
||||
file_obj = _make_file_obj(b"fake-mp4")
|
||||
file_obj.file_path = "videos/clip.mp4"
|
||||
msg = _make_message()
|
||||
msg.video = _make_video(file_obj)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.message_type == MessageType.VIDEO
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
assert event.media_types == [SUPPORTED_VIDEO_TYPES[".mp4"]]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mp4_document_is_treated_as_video(self, adapter):
|
||||
file_obj = _make_file_obj(b"fake-mp4-doc")
|
||||
doc = _make_document(file_name="good.mp4", mime_type="video/mp4", file_size=1024, file_obj=file_obj)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.message_type == MessageType.VIDEO
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
assert event.media_types == [SUPPORTED_VIDEO_TYPES[".mp4"]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMediaGroups — media group (album) buffering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -483,6 +524,32 @@ class TestSendDocument:
|
|||
assert "not found" in result.error.lower()
|
||||
connected_adapter._bot.send_document.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_workspace_path_has_docker_hint(self, connected_adapter):
|
||||
"""Container-local-looking paths get a more actionable Docker hint."""
|
||||
result = await connected_adapter.send_document(
|
||||
chat_id="12345",
|
||||
file_path="/workspace/report.txt",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "docker sandbox" in result.error.lower()
|
||||
assert "host-visible path" in result.error.lower()
|
||||
connected_adapter._bot.send_document.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_outputs_path_has_docker_hint(self, connected_adapter):
|
||||
"""Legacy /outputs paths also get the Docker hint."""
|
||||
result = await connected_adapter.send_document(
|
||||
chat_id="12345",
|
||||
file_path="/outputs/report.txt",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "docker sandbox" in result.error.lower()
|
||||
assert "host-visible path" in result.error.lower()
|
||||
connected_adapter._bot.send_document.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_not_connected(self, adapter):
|
||||
"""If bot is None, returns not connected error."""
|
||||
|
|
@ -665,6 +732,17 @@ class TestSendVideo:
|
|||
assert result.success is False
|
||||
assert "not found" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_workspace_path_has_docker_hint(self, connected_adapter):
|
||||
result = await connected_adapter.send_video(
|
||||
chat_id="12345",
|
||||
video_path="/workspace/video.mp4",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "docker sandbox" in result.error.lower()
|
||||
assert "host-visible path" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_not_connected(self, adapter):
|
||||
result = await adapter.send_video(
|
||||
|
|
|
|||
|
|
@ -71,7 +71,17 @@ def test_group_messages_can_require_direct_trigger_via_config():
|
|||
assert adapter._should_process_message(_group_message("hello everyone")) is False
|
||||
assert adapter._should_process_message(_group_message("hi @hermes_bot", entities=[_mention_entity("hi @hermes_bot")])) is True
|
||||
assert adapter._should_process_message(_group_message("replying", reply_to_bot=True)) is True
|
||||
assert adapter._should_process_message(_group_message("/status"), is_command=True) is True
|
||||
# Commands must also respect require_mention when it is enabled
|
||||
assert adapter._should_process_message(_group_message("/status"), is_command=True) is False
|
||||
# But commands with @mention still pass (Telegram emits a MENTION entity
|
||||
# for /cmd@botname — the bot menu and python-telegram-bot's CommandHandler
|
||||
# rely on this same mechanism)
|
||||
assert adapter._should_process_message(
|
||||
_group_message("/status@hermes_bot", entities=[_mention_entity("/status@hermes_bot")])
|
||||
) is True
|
||||
# And commands still pass unconditionally when require_mention is disabled
|
||||
adapter_no_mention = _make_adapter(require_mention=False)
|
||||
assert adapter_no_mention._should_process_message(_group_message("/status"), is_command=True) is True
|
||||
|
||||
|
||||
def test_free_response_chats_bypass_mention_requirement():
|
||||
|
|
|
|||
185
tests/gateway/test_telegram_mention_boundaries.py
Normal file
185
tests/gateway/test_telegram_mention_boundaries.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
"""Tests for Telegram bot mention detection (bug #12545).
|
||||
|
||||
The old implementation used a naive substring check
|
||||
(`f"@{bot_username}" in text.lower()`), which incorrectly matched partial
|
||||
substrings like 'foo@hermes_bot.example'.
|
||||
|
||||
Detection now relies entirely on the MessageEntity objects Telegram's server
|
||||
emits for real mentions. A bare `@username` substring in message text without
|
||||
a corresponding `MENTION` entity is NOT a mention — this correctly ignores
|
||||
@handles that appear inside URLs, code blocks, email-like strings, or quoted
|
||||
text, because Telegram's parser does not emit mention entities for any of
|
||||
those contexts.
|
||||
"""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
|
||||
|
||||
def _make_adapter():
|
||||
adapter = object.__new__(TelegramAdapter)
|
||||
adapter.platform = Platform.TELEGRAM
|
||||
adapter.config = PlatformConfig(enabled=True, token="***", extra={})
|
||||
adapter._bot = SimpleNamespace(id=999, username="hermes_bot")
|
||||
return adapter
|
||||
|
||||
|
||||
def _mention_entity(text, mention="@hermes_bot"):
|
||||
"""Build a MENTION entity pointing at a literal `@username` in `text`."""
|
||||
offset = text.index(mention)
|
||||
return SimpleNamespace(type="mention", offset=offset, length=len(mention))
|
||||
|
||||
|
||||
def _text_mention_entity(offset, length, user_id):
|
||||
"""Build a TEXT_MENTION entity (used when the target user has no public @handle)."""
|
||||
return SimpleNamespace(
|
||||
type="text_mention",
|
||||
offset=offset,
|
||||
length=length,
|
||||
user=SimpleNamespace(id=user_id),
|
||||
)
|
||||
|
||||
|
||||
def _message(text=None, caption=None, entities=None, caption_entities=None):
|
||||
return SimpleNamespace(
|
||||
text=text,
|
||||
caption=caption,
|
||||
entities=entities or [],
|
||||
caption_entities=caption_entities or [],
|
||||
message_thread_id=None,
|
||||
chat=SimpleNamespace(id=-100, type="group"),
|
||||
reply_to_message=None,
|
||||
)
|
||||
|
||||
|
||||
class TestRealMentionsAreDetected:
|
||||
"""A real Telegram mention always comes with a MENTION entity — detect those."""
|
||||
|
||||
def test_mention_at_start_of_message(self):
|
||||
adapter = _make_adapter()
|
||||
text = "@hermes_bot hello world"
|
||||
msg = _message(text=text, entities=[_mention_entity(text)])
|
||||
assert adapter._message_mentions_bot(msg) is True
|
||||
|
||||
def test_mention_mid_sentence(self):
|
||||
adapter = _make_adapter()
|
||||
text = "hey @hermes_bot, can you help?"
|
||||
msg = _message(text=text, entities=[_mention_entity(text)])
|
||||
assert adapter._message_mentions_bot(msg) is True
|
||||
|
||||
def test_mention_at_end_of_message(self):
|
||||
adapter = _make_adapter()
|
||||
text = "thanks for looking @hermes_bot"
|
||||
msg = _message(text=text, entities=[_mention_entity(text)])
|
||||
assert adapter._message_mentions_bot(msg) is True
|
||||
|
||||
def test_mention_in_caption(self):
|
||||
adapter = _make_adapter()
|
||||
caption = "photo for @hermes_bot"
|
||||
msg = _message(caption=caption, caption_entities=[_mention_entity(caption)])
|
||||
assert adapter._message_mentions_bot(msg) is True
|
||||
|
||||
def test_text_mention_entity_targets_bot(self):
|
||||
"""TEXT_MENTION is Telegram's entity type for @FirstName -> user without a public handle."""
|
||||
adapter = _make_adapter()
|
||||
msg = _message(text="hey you", entities=[_text_mention_entity(4, 3, user_id=999)])
|
||||
assert adapter._message_mentions_bot(msg) is True
|
||||
|
||||
|
||||
class TestSubstringFalsePositivesAreRejected:
|
||||
"""Bare `@bot_username` substrings without a MENTION entity must NOT match.
|
||||
|
||||
These are all inputs where the OLD substring check returned True incorrectly.
|
||||
A word-boundary regex would still over-match some of these (code blocks,
|
||||
URLs). Entity-based detection handles them all correctly because Telegram's
|
||||
parser does not emit mention entities for non-mention contexts.
|
||||
"""
|
||||
|
||||
def test_email_like_substring(self):
|
||||
"""bug #12545 exact repro: 'foo@hermes_bot.example'."""
|
||||
adapter = _make_adapter()
|
||||
msg = _message(text="email me at foo@hermes_bot.example")
|
||||
assert adapter._message_mentions_bot(msg) is False
|
||||
|
||||
def test_hostname_substring(self):
|
||||
adapter = _make_adapter()
|
||||
msg = _message(text="contact user@hermes_bot.domain.com")
|
||||
assert adapter._message_mentions_bot(msg) is False
|
||||
|
||||
def test_superstring_username(self):
|
||||
"""`@hermes_botx` is a different username; Telegram would emit a mention
|
||||
entity for `@hermes_botx`, not `@hermes_bot`."""
|
||||
adapter = _make_adapter()
|
||||
msg = _message(text="@hermes_botx hello")
|
||||
assert adapter._message_mentions_bot(msg) is False
|
||||
|
||||
def test_underscore_suffix_substring(self):
|
||||
adapter = _make_adapter()
|
||||
msg = _message(text="see @hermes_bot_admin for help")
|
||||
assert adapter._message_mentions_bot(msg) is False
|
||||
|
||||
def test_substring_inside_url_without_entity(self):
|
||||
"""@handle inside a URL produces a URL entity, not a MENTION entity."""
|
||||
adapter = _make_adapter()
|
||||
msg = _message(text="see https://example.com/@hermes_bot for details")
|
||||
assert adapter._message_mentions_bot(msg) is False
|
||||
|
||||
def test_substring_inside_code_block_without_entity(self):
|
||||
"""Telegram doesn't emit mention entities inside code/pre entities."""
|
||||
adapter = _make_adapter()
|
||||
msg = _message(text="use the string `@hermes_bot` in config")
|
||||
assert adapter._message_mentions_bot(msg) is False
|
||||
|
||||
def test_plain_text_with_no_at_sign(self):
|
||||
adapter = _make_adapter()
|
||||
msg = _message(text="just a normal group message")
|
||||
assert adapter._message_mentions_bot(msg) is False
|
||||
|
||||
def test_email_substring_in_caption(self):
|
||||
adapter = _make_adapter()
|
||||
msg = _message(caption="foo@hermes_bot.example")
|
||||
assert adapter._message_mentions_bot(msg) is False
|
||||
|
||||
|
||||
class TestEntityEdgeCases:
|
||||
"""Malformed or mismatched entities should not crash or over-match."""
|
||||
|
||||
def test_mention_entity_for_different_username(self):
|
||||
adapter = _make_adapter()
|
||||
text = "@someone_else hi"
|
||||
msg = _message(text=text, entities=[_mention_entity(text, mention="@someone_else")])
|
||||
assert adapter._message_mentions_bot(msg) is False
|
||||
|
||||
def test_text_mention_entity_for_different_user(self):
|
||||
adapter = _make_adapter()
|
||||
msg = _message(text="hi there", entities=[_text_mention_entity(0, 2, user_id=12345)])
|
||||
assert adapter._message_mentions_bot(msg) is False
|
||||
|
||||
def test_malformed_entity_with_negative_offset(self):
|
||||
adapter = _make_adapter()
|
||||
msg = _message(text="@hermes_bot hi",
|
||||
entities=[SimpleNamespace(type="mention", offset=-1, length=11)])
|
||||
assert adapter._message_mentions_bot(msg) is False
|
||||
|
||||
def test_malformed_entity_with_zero_length(self):
|
||||
adapter = _make_adapter()
|
||||
msg = _message(text="@hermes_bot hi",
|
||||
entities=[SimpleNamespace(type="mention", offset=0, length=0)])
|
||||
assert adapter._message_mentions_bot(msg) is False
|
||||
|
||||
|
||||
class TestCaseInsensitivity:
|
||||
"""Telegram usernames are case-insensitive; the slice-compare normalizes both sides."""
|
||||
|
||||
def test_uppercase_mention(self):
|
||||
adapter = _make_adapter()
|
||||
text = "hi @HERMES_BOT"
|
||||
msg = _message(text=text, entities=[_mention_entity(text, mention="@HERMES_BOT")])
|
||||
assert adapter._message_mentions_bot(msg) is True
|
||||
|
||||
def test_mixed_case_mention(self):
|
||||
adapter = _make_adapter()
|
||||
text = "hi @Hermes_Bot"
|
||||
msg = _message(text=text, entities=[_mention_entity(text, mention="@Hermes_Bot")])
|
||||
assert adapter._message_mentions_bot(msg) is True
|
||||
100
tests/gateway/test_telegram_webhook_secret.py
Normal file
100
tests/gateway/test_telegram_webhook_secret.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
"""Tests for GHSA-3vpc-7q5r-276h — Telegram webhook secret required.
|
||||
|
||||
Previously, when TELEGRAM_WEBHOOK_URL was set but TELEGRAM_WEBHOOK_SECRET
|
||||
was not, python-telegram-bot received secret_token=None and the webhook
|
||||
endpoint accepted any HTTP POST.
|
||||
|
||||
The fix refuses to start the adapter in webhook mode without the secret.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
_repo = str(Path(__file__).resolve().parents[2])
|
||||
if _repo not in sys.path:
|
||||
sys.path.insert(0, _repo)
|
||||
|
||||
|
||||
class TestTelegramWebhookSecretRequired:
|
||||
"""Direct source-level check of the webhook-secret guard.
|
||||
|
||||
The guard is embedded in TelegramAdapter.connect() and hard to isolate
|
||||
via mocks (requires a full python-telegram-bot ApplicationBuilder
|
||||
chain). These tests exercise it via source inspection — verifying the
|
||||
check exists, raises RuntimeError with the advisory link, and only
|
||||
fires in webhook mode. End-to-end validation is covered by CI +
|
||||
manual deployment tests.
|
||||
"""
|
||||
|
||||
def _get_source(self) -> str:
|
||||
path = Path(_repo) / "gateway" / "platforms" / "telegram.py"
|
||||
return path.read_text(encoding="utf-8")
|
||||
|
||||
def test_webhook_branch_checks_secret(self):
|
||||
"""The webhook-mode branch of connect() must read
|
||||
TELEGRAM_WEBHOOK_SECRET and refuse when empty."""
|
||||
src = self._get_source()
|
||||
# The guard must appear after TELEGRAM_WEBHOOK_URL is set
|
||||
assert re.search(
|
||||
r'TELEGRAM_WEBHOOK_SECRET.*?\.strip\(\)\s*\n\s*if not webhook_secret:',
|
||||
src, re.DOTALL,
|
||||
), (
|
||||
"TelegramAdapter.connect() must strip TELEGRAM_WEBHOOK_SECRET "
|
||||
"and raise when the secret is empty — see GHSA-3vpc-7q5r-276h"
|
||||
)
|
||||
|
||||
def test_guard_raises_runtime_error(self):
|
||||
"""The guard raises RuntimeError (not a silent log) so operators
|
||||
see the failure at startup."""
|
||||
src = self._get_source()
|
||||
# Between the "if not webhook_secret:" line and the next blank
|
||||
# line block, we should see a RuntimeError being raised
|
||||
guard_match = re.search(
|
||||
r'if not webhook_secret:\s*\n\s*raise\s+RuntimeError\(',
|
||||
src,
|
||||
)
|
||||
assert guard_match, (
|
||||
"Missing webhook secret must raise RuntimeError — silent "
|
||||
"fall-through was the original GHSA-3vpc-7q5r-276h bypass"
|
||||
)
|
||||
|
||||
def test_guard_message_includes_advisory_link(self):
|
||||
"""The RuntimeError message should reference the advisory so
|
||||
operators can read the full context."""
|
||||
src = self._get_source()
|
||||
assert "GHSA-3vpc-7q5r-276h" in src, (
|
||||
"Guard error message must cite the advisory for operator context"
|
||||
)
|
||||
|
||||
def test_guard_message_explains_remediation(self):
|
||||
"""The error should tell the operator how to fix it."""
|
||||
src = self._get_source()
|
||||
# Should mention how to generate a secret
|
||||
assert "openssl rand" in src or "TELEGRAM_WEBHOOK_SECRET=" in src, (
|
||||
"Guard error message should show operators how to set "
|
||||
"TELEGRAM_WEBHOOK_SECRET"
|
||||
)
|
||||
|
||||
def test_polling_branch_has_no_secret_guard(self):
|
||||
"""Polling mode (else-branch) must NOT require the webhook secret —
|
||||
polling authenticates via the bot token, not a webhook secret."""
|
||||
src = self._get_source()
|
||||
# The guard should appear inside the `if webhook_url:` branch,
|
||||
# not the `else:` polling branch. Rough check: the raise is
|
||||
# followed (within ~60 lines) by an `else:` that starts the
|
||||
# polling branch, and there's no secret-check in that polling
|
||||
# branch.
|
||||
webhook_block = re.search(
|
||||
r'if webhook_url:\s*\n(.*?)\n else:\s*\n(.*?)\n',
|
||||
src, re.DOTALL,
|
||||
)
|
||||
if webhook_block:
|
||||
webhook_body = webhook_block.group(1)
|
||||
polling_body = webhook_block.group(2)
|
||||
assert "TELEGRAM_WEBHOOK_SECRET" in webhook_body
|
||||
assert "TELEGRAM_WEBHOOK_SECRET" not in polling_body
|
||||
|
|
@ -148,6 +148,70 @@ class TestDiscordTextBatching:
|
|||
await asyncio.sleep(0.25)
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shield_protects_handle_message_from_cancel(self):
|
||||
"""Regression guard: a follow-up chunk arriving while
|
||||
handle_message is mid-flight must NOT cancel the running
|
||||
dispatch. _enqueue_text_event fires prior_task.cancel() on
|
||||
every new chunk; without asyncio.shield around handle_message
|
||||
the cancel propagates into the agent's streaming request and
|
||||
aborts the response.
|
||||
"""
|
||||
adapter = _make_discord_adapter()
|
||||
|
||||
handle_started = asyncio.Event()
|
||||
release_handle = asyncio.Event()
|
||||
first_handle_cancelled = asyncio.Event()
|
||||
first_handle_completed = asyncio.Event()
|
||||
call_count = [0]
|
||||
|
||||
async def slow_handle(event):
|
||||
call_count[0] += 1
|
||||
# Only the first call (batch 1) is the one we're protecting.
|
||||
if call_count[0] == 1:
|
||||
handle_started.set()
|
||||
try:
|
||||
await release_handle.wait()
|
||||
first_handle_completed.set()
|
||||
except asyncio.CancelledError:
|
||||
first_handle_cancelled.set()
|
||||
raise
|
||||
# Second call (batch 2) returns immediately — not the subject
|
||||
# of this test.
|
||||
|
||||
adapter.handle_message = slow_handle
|
||||
|
||||
# Prime batch 1 and wait for it to land inside handle_message.
|
||||
adapter._enqueue_text_event(_make_event("batch 1", Platform.DISCORD))
|
||||
await asyncio.wait_for(handle_started.wait(), timeout=1.0)
|
||||
|
||||
# A new chunk arrives — _enqueue_text_event fires
|
||||
# prior_task.cancel() on batch 1's flush task, which is
|
||||
# currently awaiting inside handle_message.
|
||||
adapter._enqueue_text_event(_make_event("batch 2 follow-up", Platform.DISCORD))
|
||||
|
||||
# Let the cancel propagate.
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# CRITICAL ASSERTION: batch 1's handle_message must NOT have
|
||||
# been cancelled. Without asyncio.shield this assertion fails
|
||||
# because CancelledError propagates from the flush task's
|
||||
# `await self.handle_message(event)` into slow_handle.
|
||||
assert not first_handle_cancelled.is_set(), (
|
||||
"handle_message for batch 1 was cancelled by a follow-up "
|
||||
"chunk — asyncio.shield is missing or broken"
|
||||
)
|
||||
|
||||
# Release batch 1's handle_message and let it complete.
|
||||
release_handle.set()
|
||||
await asyncio.wait_for(first_handle_completed.wait(), timeout=1.0)
|
||||
assert first_handle_completed.is_set()
|
||||
|
||||
# Cleanup
|
||||
for task in list(adapter._pending_text_batch_tasks.values()):
|
||||
task.cancel()
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Matrix text batching
|
||||
|
|
|
|||
|
|
@ -63,6 +63,12 @@ def _make_runner(platform: Platform, config: GatewayConfig):
|
|||
runner.pairing_store = MagicMock()
|
||||
runner.pairing_store.is_approved.return_value = False
|
||||
runner.pairing_store._is_rate_limited.return_value = False
|
||||
# Attributes required by _handle_message for the authorized-user path
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._update_prompts = {}
|
||||
runner.hooks = SimpleNamespace(dispatch=AsyncMock(return_value=None))
|
||||
runner._sessions = {}
|
||||
return runner, adapter
|
||||
|
||||
|
||||
|
|
@ -295,3 +301,172 @@ async def test_global_ignore_suppresses_pairing_reply(monkeypatch):
|
|||
assert result is None
|
||||
runner.pairing_store.generate_code.assert_not_called()
|
||||
adapter.send.assert_not_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Allowlist-configured platforms default to "ignore" for unauthorized users
|
||||
# (#9337: Signal gateway sends pairing spam when allowlist is configured)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signal_with_allowlist_ignores_unauthorized_dm(monkeypatch):
|
||||
"""When SIGNAL_ALLOWED_USERS is set, unauthorized DMs are silently dropped.
|
||||
|
||||
This is the primary regression test for #9337: before the fix, Signal
|
||||
would send pairing codes to ANY sender even when a strict allowlist was
|
||||
configured, spamming personal contacts with cryptic bot messages.
|
||||
"""
|
||||
_clear_auth_env(monkeypatch)
|
||||
monkeypatch.setenv("SIGNAL_ALLOWED_USERS", "+15550000001") # allowlist set
|
||||
|
||||
config = GatewayConfig(
|
||||
platforms={Platform.SIGNAL: PlatformConfig(enabled=True)},
|
||||
)
|
||||
runner, adapter = _make_runner(Platform.SIGNAL, config)
|
||||
|
||||
result = await runner._handle_message(
|
||||
_make_event(Platform.SIGNAL, "+15559999999", "+15559999999") # not in allowlist
|
||||
)
|
||||
|
||||
assert result is None
|
||||
runner.pairing_store.generate_code.assert_not_called()
|
||||
adapter.send.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_telegram_with_allowlist_ignores_unauthorized_dm(monkeypatch):
|
||||
"""Same behavior for Telegram: allowlist ⟹ ignore unauthorized DMs."""
|
||||
_clear_auth_env(monkeypatch)
|
||||
monkeypatch.setenv("TELEGRAM_ALLOWED_USERS", "111111111")
|
||||
|
||||
config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True)},
|
||||
)
|
||||
runner, adapter = _make_runner(Platform.TELEGRAM, config)
|
||||
|
||||
result = await runner._handle_message(
|
||||
_make_event(Platform.TELEGRAM, "999999999", "999999999")
|
||||
)
|
||||
|
||||
assert result is None
|
||||
runner.pairing_store.generate_code.assert_not_called()
|
||||
adapter.send.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_allowlist_ignores_unauthorized_dm(monkeypatch):
|
||||
"""GATEWAY_ALLOWED_USERS also triggers the 'ignore' behavior."""
|
||||
_clear_auth_env(monkeypatch)
|
||||
monkeypatch.setenv("GATEWAY_ALLOWED_USERS", "111111111")
|
||||
|
||||
config = GatewayConfig(
|
||||
platforms={Platform.SIGNAL: PlatformConfig(enabled=True)},
|
||||
)
|
||||
runner, adapter = _make_runner(Platform.SIGNAL, config)
|
||||
|
||||
result = await runner._handle_message(
|
||||
_make_event(Platform.SIGNAL, "+15559999999", "+15559999999")
|
||||
)
|
||||
|
||||
assert result is None
|
||||
runner.pairing_store.generate_code.assert_not_called()
|
||||
adapter.send.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_allowlist_still_pairs_by_default(monkeypatch):
|
||||
"""Without any allowlist, pairing behavior is preserved (open gateway)."""
|
||||
_clear_auth_env(monkeypatch)
|
||||
# No SIGNAL_ALLOWED_USERS, no GATEWAY_ALLOWED_USERS
|
||||
|
||||
config = GatewayConfig(
|
||||
platforms={Platform.SIGNAL: PlatformConfig(enabled=True)},
|
||||
)
|
||||
runner, adapter = _make_runner(Platform.SIGNAL, config)
|
||||
runner.pairing_store.generate_code.return_value = "PAIR1234"
|
||||
|
||||
result = await runner._handle_message(
|
||||
_make_event(Platform.SIGNAL, "+15559999999", "+15559999999")
|
||||
)
|
||||
|
||||
assert result is None
|
||||
runner.pairing_store.generate_code.assert_called_once()
|
||||
adapter.send.assert_awaited_once()
|
||||
assert "PAIR1234" in adapter.send.await_args.args[1]
|
||||
|
||||
|
||||
def test_explicit_pair_config_overrides_allowlist_default(monkeypatch):
|
||||
"""Explicit unauthorized_dm_behavior='pair' overrides the allowlist default.
|
||||
|
||||
Operators can opt back in to pairing even with an allowlist by setting
|
||||
unauthorized_dm_behavior: pair in their platform config. We test the
|
||||
_get_unauthorized_dm_behavior resolver directly to avoid the full
|
||||
_handle_message pipeline which requires extensive runner state.
|
||||
"""
|
||||
_clear_auth_env(monkeypatch)
|
||||
monkeypatch.setenv("SIGNAL_ALLOWED_USERS", "+15550000001")
|
||||
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.SIGNAL: PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"unauthorized_dm_behavior": "pair"}, # explicit override
|
||||
),
|
||||
},
|
||||
)
|
||||
runner, _adapter = _make_runner(Platform.SIGNAL, config)
|
||||
|
||||
# The per-platform explicit config should beat the allowlist-derived default
|
||||
behavior = runner._get_unauthorized_dm_behavior(Platform.SIGNAL)
|
||||
assert behavior == "pair"
|
||||
|
||||
|
||||
def test_allowlist_authorized_user_returns_ignore_for_unauthorized(monkeypatch):
|
||||
"""_get_unauthorized_dm_behavior returns 'ignore' when allowlist is set.
|
||||
|
||||
We test the resolver directly. The full _handle_message path for
|
||||
authorized users is covered by the integration tests in this module.
|
||||
"""
|
||||
_clear_auth_env(monkeypatch)
|
||||
monkeypatch.setenv("SIGNAL_ALLOWED_USERS", "+15550000001")
|
||||
|
||||
config = GatewayConfig(
|
||||
platforms={Platform.SIGNAL: PlatformConfig(enabled=True)},
|
||||
)
|
||||
runner, _adapter = _make_runner(Platform.SIGNAL, config)
|
||||
|
||||
behavior = runner._get_unauthorized_dm_behavior(Platform.SIGNAL)
|
||||
assert behavior == "ignore"
|
||||
|
||||
|
||||
def test_get_unauthorized_dm_behavior_no_allowlist_returns_pair(monkeypatch):
|
||||
"""Without any allowlist, 'pair' is still the default."""
|
||||
_clear_auth_env(monkeypatch)
|
||||
|
||||
config = GatewayConfig(
|
||||
platforms={Platform.SIGNAL: PlatformConfig(enabled=True)},
|
||||
)
|
||||
runner, _adapter = _make_runner(Platform.SIGNAL, config)
|
||||
|
||||
behavior = runner._get_unauthorized_dm_behavior(Platform.SIGNAL)
|
||||
assert behavior == "pair"
|
||||
|
||||
|
||||
def test_qqbot_with_allowlist_ignores_unauthorized_dm(monkeypatch):
|
||||
"""QQBOT is included in the allowlist-aware default (QQ_ALLOWED_USERS).
|
||||
|
||||
Regression guard: the initial #9337 fix omitted QQBOT from the env map
|
||||
inside _get_unauthorized_dm_behavior, even though _is_user_authorized
|
||||
mapped it to QQ_ALLOWED_USERS. Without QQBOT here, a QQ operator with a
|
||||
strict user allowlist would still get pairing codes sent to strangers.
|
||||
"""
|
||||
_clear_auth_env(monkeypatch)
|
||||
monkeypatch.setenv("QQ_ALLOWED_USERS", "allowed-openid-1")
|
||||
|
||||
config = GatewayConfig(
|
||||
platforms={Platform.QQBOT: PlatformConfig(enabled=True)},
|
||||
)
|
||||
runner, _adapter = _make_runner(Platform.QQBOT, config)
|
||||
|
||||
behavior = runner._get_unauthorized_dm_behavior(Platform.QQBOT)
|
||||
assert behavior == "ignore"
|
||||
|
|
|
|||
|
|
@ -175,3 +175,79 @@ class TestUsageCachedAgent:
|
|||
result = await runner._handle_usage_command(event)
|
||||
|
||||
assert "Cost: included" in result
|
||||
|
||||
|
||||
class TestUsageAccountSection:
|
||||
"""Account-limits section appended to /usage output (PR #2486)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_command_includes_account_section(self, monkeypatch):
|
||||
agent = _make_mock_agent(provider="openai-codex")
|
||||
agent.base_url = "https://chatgpt.com/backend-api/codex"
|
||||
agent.api_key = "unused"
|
||||
runner = _make_runner(SK, cached_agent=agent)
|
||||
event = MagicMock()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.run.fetch_account_usage",
|
||||
lambda provider, base_url=None, api_key=None: object(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"gateway.run.render_account_usage_lines",
|
||||
lambda snapshot, markdown=False: [
|
||||
"📈 **Account limits**",
|
||||
"Provider: openai-codex (Pro)",
|
||||
"Session: 85% remaining (15% used)",
|
||||
],
|
||||
)
|
||||
with patch("agent.rate_limit_tracker.format_rate_limit_compact", return_value="RPM: 50/60"), \
|
||||
patch("agent.usage_pricing.estimate_usage_cost") as mock_cost:
|
||||
mock_cost.return_value = MagicMock(amount_usd=None, status="included")
|
||||
result = await runner._handle_usage_command(event)
|
||||
|
||||
assert "📊 **Session Token Usage**" in result
|
||||
assert "📈 **Account limits**" in result
|
||||
assert "Provider: openai-codex (Pro)" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_command_uses_persisted_provider_when_agent_not_running(self, monkeypatch):
|
||||
runner = _make_runner(SK)
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.get_session.return_value = {
|
||||
"billing_provider": "openai-codex",
|
||||
"billing_base_url": "https://chatgpt.com/backend-api/codex",
|
||||
}
|
||||
session_entry = MagicMock()
|
||||
session_entry.session_id = "sess-1"
|
||||
runner.session_store.get_or_create_session.return_value = session_entry
|
||||
runner.session_store.load_transcript.return_value = [
|
||||
{"role": "user", "content": "earlier"},
|
||||
]
|
||||
|
||||
calls = {}
|
||||
|
||||
async def _fake_to_thread(fn, *args, **kwargs):
|
||||
calls["args"] = args
|
||||
calls["kwargs"] = kwargs
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("gateway.run.asyncio.to_thread", _fake_to_thread)
|
||||
monkeypatch.setattr(
|
||||
"gateway.run.fetch_account_usage",
|
||||
lambda provider, base_url=None, api_key=None: object(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"gateway.run.render_account_usage_lines",
|
||||
lambda snapshot, markdown=False: [
|
||||
"📈 **Account limits**",
|
||||
"Provider: openai-codex (Pro)",
|
||||
],
|
||||
)
|
||||
|
||||
event = MagicMock()
|
||||
result = await runner._handle_usage_command(event)
|
||||
|
||||
assert calls["args"] == ("openai-codex",)
|
||||
assert calls["kwargs"]["base_url"] == "https://chatgpt.com/backend-api/codex"
|
||||
assert "📊 **Session Info**" in result
|
||||
assert "📈 **Account limits**" in result
|
||||
|
|
|
|||
|
|
@ -99,22 +99,22 @@ class TestHandleVoiceCommand:
|
|||
event = _make_event("/voice on")
|
||||
result = await runner._handle_voice_command(event)
|
||||
assert "enabled" in result.lower()
|
||||
assert runner._voice_mode["123"] == "voice_only"
|
||||
assert runner._voice_mode["telegram:123"] == "voice_only"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_off(self, runner):
|
||||
runner._voice_mode["123"] = "voice_only"
|
||||
runner._voice_mode["telegram:123"] = "voice_only"
|
||||
event = _make_event("/voice off")
|
||||
result = await runner._handle_voice_command(event)
|
||||
assert "disabled" in result.lower()
|
||||
assert runner._voice_mode["123"] == "off"
|
||||
assert runner._voice_mode["telegram:123"] == "off"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_tts(self, runner):
|
||||
event = _make_event("/voice tts")
|
||||
result = await runner._handle_voice_command(event)
|
||||
assert "tts" in result.lower()
|
||||
assert runner._voice_mode["123"] == "all"
|
||||
assert runner._voice_mode["telegram:123"] == "all"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_status_off(self, runner):
|
||||
|
|
@ -124,7 +124,7 @@ class TestHandleVoiceCommand:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_status_on(self, runner):
|
||||
runner._voice_mode["123"] = "voice_only"
|
||||
runner._voice_mode["telegram:123"] = "voice_only"
|
||||
event = _make_event("/voice status")
|
||||
result = await runner._handle_voice_command(event)
|
||||
assert "voice reply" in result.lower()
|
||||
|
|
@ -134,15 +134,15 @@ class TestHandleVoiceCommand:
|
|||
event = _make_event("/voice")
|
||||
result = await runner._handle_voice_command(event)
|
||||
assert "enabled" in result.lower()
|
||||
assert runner._voice_mode["123"] == "voice_only"
|
||||
assert runner._voice_mode["telegram:123"] == "voice_only"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_toggle_on_to_off(self, runner):
|
||||
runner._voice_mode["123"] = "voice_only"
|
||||
runner._voice_mode["telegram:123"] = "voice_only"
|
||||
event = _make_event("/voice")
|
||||
result = await runner._handle_voice_command(event)
|
||||
assert "disabled" in result.lower()
|
||||
assert runner._voice_mode["123"] == "off"
|
||||
assert runner._voice_mode["telegram:123"] == "off"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistence_saved(self, runner):
|
||||
|
|
@ -150,39 +150,47 @@ class TestHandleVoiceCommand:
|
|||
await runner._handle_voice_command(event)
|
||||
assert runner._VOICE_MODE_PATH.exists()
|
||||
data = json.loads(runner._VOICE_MODE_PATH.read_text())
|
||||
assert data["123"] == "voice_only"
|
||||
assert data["telegram:123"] == "voice_only"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistence_loaded(self, runner):
|
||||
runner._VOICE_MODE_PATH.write_text(json.dumps({"456": "all"}))
|
||||
runner._VOICE_MODE_PATH.write_text(json.dumps({"telegram:456": "all"}))
|
||||
loaded = runner._load_voice_modes()
|
||||
assert loaded == {"456": "all"}
|
||||
assert loaded == {"telegram:456": "all"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistence_saved_for_off(self, runner):
|
||||
event = _make_event("/voice off")
|
||||
await runner._handle_voice_command(event)
|
||||
data = json.loads(runner._VOICE_MODE_PATH.read_text())
|
||||
assert data["123"] == "off"
|
||||
assert data["telegram:123"] == "off"
|
||||
|
||||
def test_sync_voice_mode_state_to_adapter_restores_off_chats(self, runner):
|
||||
runner._voice_mode = {"123": "off", "456": "all"}
|
||||
adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
|
||||
from gateway.config import Platform
|
||||
runner._voice_mode = {"telegram:123": "off", "telegram:456": "all"}
|
||||
adapter = SimpleNamespace(
|
||||
_auto_tts_disabled_chats=set(),
|
||||
platform=Platform.TELEGRAM,
|
||||
)
|
||||
|
||||
runner._sync_voice_mode_state_to_adapter(adapter)
|
||||
|
||||
assert adapter._auto_tts_disabled_chats == {"123"}
|
||||
|
||||
def test_restart_restores_voice_off_state(self, runner, tmp_path):
|
||||
runner._VOICE_MODE_PATH.write_text(json.dumps({"123": "off"}))
|
||||
from gateway.config import Platform
|
||||
runner._VOICE_MODE_PATH.write_text(json.dumps({"telegram:123": "off"}))
|
||||
|
||||
restored_runner = _make_runner(tmp_path)
|
||||
restored_runner._voice_mode = restored_runner._load_voice_modes()
|
||||
adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
|
||||
adapter = SimpleNamespace(
|
||||
_auto_tts_disabled_chats=set(),
|
||||
platform=Platform.TELEGRAM,
|
||||
)
|
||||
|
||||
restored_runner._sync_voice_mode_state_to_adapter(adapter)
|
||||
|
||||
assert restored_runner._voice_mode["123"] == "off"
|
||||
assert restored_runner._voice_mode["telegram:123"] == "off"
|
||||
assert adapter._auto_tts_disabled_chats == {"123"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -191,8 +199,21 @@ class TestHandleVoiceCommand:
|
|||
e2 = _make_event("/voice tts", chat_id="bbb")
|
||||
await runner._handle_voice_command(e1)
|
||||
await runner._handle_voice_command(e2)
|
||||
assert runner._voice_mode["aaa"] == "voice_only"
|
||||
assert runner._voice_mode["bbb"] == "all"
|
||||
assert runner._voice_mode["telegram:aaa"] == "voice_only"
|
||||
assert runner._voice_mode["telegram:bbb"] == "all"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_isolation(self, runner):
|
||||
"""Same chat_id on different platforms must not collide (#12542)."""
|
||||
telegram_event = _make_event("/voice on", chat_id="999")
|
||||
slack_event = _make_event("/voice off", chat_id="999")
|
||||
slack_event.source.platform.value = "slack"
|
||||
|
||||
await runner._handle_voice_command(telegram_event)
|
||||
await runner._handle_voice_command(slack_event)
|
||||
|
||||
assert runner._voice_mode["telegram:999"] == "voice_only"
|
||||
assert runner._voice_mode["slack:999"] == "off"
|
||||
|
||||
|
||||
# =====================================================================
|
||||
|
|
@ -223,9 +244,9 @@ class TestAutoVoiceReply:
|
|||
"""Call real _should_send_voice_reply on a GatewayRunner instance."""
|
||||
chat_id = "123"
|
||||
if voice_mode != "off":
|
||||
runner._voice_mode[chat_id] = voice_mode
|
||||
runner._voice_mode["telegram:" + chat_id] = voice_mode
|
||||
else:
|
||||
runner._voice_mode.pop(chat_id, None)
|
||||
runner._voice_mode.pop("telegram:" + chat_id, None)
|
||||
|
||||
event = _make_event(message_type=message_type)
|
||||
|
||||
|
|
@ -416,6 +437,7 @@ class TestDiscordPlayTtsSkip:
|
|||
adapter.platform = Platform.DISCORD
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_locks = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
|
|
@ -712,7 +734,7 @@ class TestVoiceChannelCommands:
|
|||
result = await runner._handle_voice_channel_join(event)
|
||||
assert "joined" in result.lower()
|
||||
assert "General" in result
|
||||
assert runner._voice_mode["123"] == "all"
|
||||
assert runner._voice_mode["discord:123"] == "all"
|
||||
assert mock_adapter._voice_sources[111]["chat_id"] == "123"
|
||||
assert mock_adapter._voice_sources[111]["chat_type"] == "group"
|
||||
|
||||
|
|
@ -790,10 +812,10 @@ class TestVoiceChannelCommands:
|
|||
mock_adapter.leave_voice_channel = AsyncMock()
|
||||
event = self._make_discord_event("/voice leave")
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
runner._voice_mode["123"] = "all"
|
||||
runner._voice_mode["discord:123"] = "all"
|
||||
result = await runner._handle_voice_channel_leave(event)
|
||||
assert "left" in result.lower()
|
||||
assert runner._voice_mode["123"] == "off"
|
||||
assert runner._voice_mode["discord:123"] == "off"
|
||||
mock_adapter.leave_voice_channel.assert_called_once_with(111)
|
||||
|
||||
# -- _handle_voice_channel_input --
|
||||
|
|
@ -931,6 +953,7 @@ class TestDiscordVoiceChannelMethods:
|
|||
adapter.config = config
|
||||
adapter._client = MagicMock()
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_locks = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
|
|
@ -1296,11 +1319,11 @@ class TestLeaveExceptionHandling:
|
|||
event = _make_event("/voice leave")
|
||||
event.raw_message = SimpleNamespace(guild_id=111, guild=None)
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
runner._voice_mode["123"] = "all"
|
||||
runner._voice_mode["telegram:123"] = "all"
|
||||
|
||||
result = await runner._handle_voice_channel_leave(event)
|
||||
assert "left" in result.lower()
|
||||
assert runner._voice_mode["123"] == "off"
|
||||
assert runner._voice_mode["telegram:123"] == "off"
|
||||
assert mock_adapter._voice_input_callback is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -1314,7 +1337,7 @@ class TestLeaveExceptionHandling:
|
|||
event = _make_event("/voice leave")
|
||||
event.raw_message = SimpleNamespace(guild_id=111, guild=None)
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
runner._voice_mode["123"] = "all"
|
||||
runner._voice_mode["telegram:123"] = "all"
|
||||
|
||||
await runner._handle_voice_channel_leave(event)
|
||||
assert mock_adapter._voice_input_callback is None
|
||||
|
|
@ -1712,6 +1735,7 @@ class TestVoiceTimeoutCleansRunnerState:
|
|||
adapter.platform = Platform.DISCORD
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_locks = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
|
|
@ -1760,11 +1784,11 @@ class TestVoiceTimeoutCleansRunnerState:
|
|||
async def test_runner_cleanup_method_removes_voice_mode(self, tmp_path):
|
||||
"""_handle_voice_timeout_cleanup removes voice_mode for chat."""
|
||||
runner = _make_runner(tmp_path)
|
||||
runner._voice_mode["999"] = "all"
|
||||
runner._voice_mode["discord:999"] = "all"
|
||||
|
||||
runner._handle_voice_timeout_cleanup("999")
|
||||
|
||||
assert runner._voice_mode["999"] == "off", \
|
||||
assert runner._voice_mode["discord:999"] == "off", \
|
||||
"voice_mode must persist explicit off state after timeout cleanup"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -1802,6 +1826,7 @@ class TestPlaybackTimeout:
|
|||
adapter.platform = Platform.DISCORD
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_locks = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
|
|
@ -1983,6 +2008,7 @@ class TestVoiceChannelAwareness:
|
|||
config.token = "fake-token"
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_locks = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_receivers = {}
|
||||
|
|
@ -2453,6 +2479,7 @@ class TestVoiceTTSPlayback:
|
|||
adapter.platform = Platform.DISCORD
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_locks = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_receivers = {}
|
||||
|
|
@ -2518,7 +2545,7 @@ class TestVoiceTTSPlayback:
|
|||
agent_msgs=None, already_sent=False):
|
||||
from gateway.platforms.base import MessageType, MessageEvent, SessionSource
|
||||
from gateway.config import Platform
|
||||
runner._voice_mode["ch1"] = voice_mode
|
||||
runner._voice_mode["discord:ch1"] = voice_mode
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD, chat_id="ch1",
|
||||
user_id="1", user_name="test", chat_type="channel",
|
||||
|
|
@ -2633,6 +2660,7 @@ class TestUDPKeepalive:
|
|||
adapter.platform = Platform.DISCORD
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_locks = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_receivers = {}
|
||||
|
|
|
|||
218
tests/gateway/test_voice_mode_platform_isolation.py
Normal file
218
tests/gateway/test_voice_mode_platform_isolation.py
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
"""Tests for voice mode platform isolation (bug #12542).
|
||||
|
||||
Voice mode state stored as {chat_id: mode} without a platform namespace
|
||||
caused collisions: Telegram chat '123' and Slack chat '123' shared the
|
||||
same key. The fix prefixes keys with platform value: 'telegram:123' vs
|
||||
'slack:123'.
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
class TestVoiceKeyHelper:
|
||||
"""Test the _voice_key helper method."""
|
||||
|
||||
def test_voice_key_format(self):
|
||||
"""_voice_key returns 'platform:chat_id' format."""
|
||||
runner = _make_runner()
|
||||
assert runner._voice_key(Platform.TELEGRAM, "123") == "telegram:123"
|
||||
assert runner._voice_key(Platform.SLACK, "456") == "slack:456"
|
||||
assert runner._voice_key(Platform.DISCORD, "789") == "discord:789"
|
||||
|
||||
def test_voice_key_different_platforms_same_chat_id(self):
|
||||
"""Same chat_id on different platforms yields different keys."""
|
||||
runner = _make_runner()
|
||||
key_telegram = runner._voice_key(Platform.TELEGRAM, "123")
|
||||
key_slack = runner._voice_key(Platform.SLACK, "123")
|
||||
key_discord = runner._voice_key(Platform.DISCORD, "123")
|
||||
assert key_telegram != key_slack
|
||||
assert key_slack != key_discord
|
||||
assert key_telegram == "telegram:123"
|
||||
assert key_slack == "slack:123"
|
||||
assert key_discord == "discord:123"
|
||||
|
||||
|
||||
class TestVoiceModePlatformIsolation:
|
||||
"""Test that voice mode state is isolated by platform."""
|
||||
|
||||
def test_telegram_and_slack_voice_mode_independent(self):
|
||||
"""Setting voice mode for Telegram chat '123' does not affect Slack chat '123'."""
|
||||
runner = _make_runner()
|
||||
|
||||
# Enable voice mode for Telegram chat '123'
|
||||
runner._voice_mode[runner._voice_key(Platform.TELEGRAM, "123")] = "all"
|
||||
# Enable voice mode for Slack chat '123' to a different mode
|
||||
runner._voice_mode[runner._voice_key(Platform.SLACK, "123")] = "voice_only"
|
||||
|
||||
# Verify they are independent
|
||||
assert runner._voice_mode.get(runner._voice_key(Platform.TELEGRAM, "123")) == "all"
|
||||
assert runner._voice_mode.get(runner._voice_key(Platform.SLACK, "123")) == "voice_only"
|
||||
|
||||
# Disabling Telegram should not affect Slack
|
||||
runner._voice_mode[runner._voice_key(Platform.TELEGRAM, "123")] = "off"
|
||||
assert runner._voice_mode.get(runner._voice_key(Platform.TELEGRAM, "123")) == "off"
|
||||
assert runner._voice_mode.get(runner._voice_key(Platform.SLACK, "123")) == "voice_only"
|
||||
|
||||
|
||||
class TestLegacyKeyMigration:
|
||||
"""Test migration of legacy unprefixed keys in _load_voice_modes."""
|
||||
|
||||
def test_load_voice_modes_skips_legacy_keys(self):
|
||||
"""_load_voice_modes skips keys without ':' prefix and logs a warning."""
|
||||
runner = _make_runner()
|
||||
|
||||
# Simulate legacy persisted data with unprefixed keys
|
||||
legacy_data = {
|
||||
"123": "all",
|
||||
"456": "voice_only",
|
||||
# Also includes a properly prefixed key (from after the fix)
|
||||
"telegram:789": "off",
|
||||
}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
voice_path = Path(tmpdir) / "gateway_voice_mode.json"
|
||||
voice_path.write_text(json.dumps(legacy_data))
|
||||
|
||||
with patch.object(runner, "_VOICE_MODE_PATH", voice_path):
|
||||
with patch("gateway.run.logger") as mock_logger:
|
||||
result = runner._load_voice_modes()
|
||||
|
||||
# Legacy keys without ':' should be skipped
|
||||
assert "123" not in result
|
||||
assert "456" not in result
|
||||
# Prefixed key should be preserved
|
||||
assert result.get("telegram:789") == "off"
|
||||
# Warning should be logged for each legacy key
|
||||
assert mock_logger.warning.called
|
||||
warning_calls = [str(call) for call in mock_logger.warning.call_args_list]
|
||||
assert any("Skipping legacy unprefixed voice mode key" in str(c) for c in warning_calls)
|
||||
|
||||
def test_load_voice_modes_preserves_prefixed_keys(self):
|
||||
"""_load_voice_modes correctly loads platform-prefixed keys."""
|
||||
runner = _make_runner()
|
||||
|
||||
persisted_data = {
|
||||
"telegram:123": "all",
|
||||
"slack:456": "voice_only",
|
||||
"discord:789": "off",
|
||||
}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
voice_path = Path(tmpdir) / "gateway_voice_mode.json"
|
||||
voice_path.write_text(json.dumps(persisted_data))
|
||||
|
||||
with patch.object(runner, "_VOICE_MODE_PATH", voice_path):
|
||||
result = runner._load_voice_modes()
|
||||
|
||||
assert result.get("telegram:123") == "all"
|
||||
assert result.get("slack:456") == "voice_only"
|
||||
assert result.get("discord:789") == "off"
|
||||
|
||||
def test_load_voice_modes_invalid_modes_filtered(self):
|
||||
"""_load_voice_modes filters out invalid mode values."""
|
||||
runner = _make_runner()
|
||||
|
||||
data = {
|
||||
"telegram:123": "all",
|
||||
"telegram:456": "invalid_mode",
|
||||
"telegram:789": "voice_only",
|
||||
}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
voice_path = Path(tmpdir) / "gateway_voice_mode.json"
|
||||
voice_path.write_text(json.dumps(data))
|
||||
|
||||
with patch.object(runner, "_VOICE_MODE_PATH", voice_path):
|
||||
result = runner._load_voice_modes()
|
||||
|
||||
assert result.get("telegram:123") == "all"
|
||||
assert "telegram:456" not in result
|
||||
assert result.get("telegram:789") == "voice_only"
|
||||
|
||||
|
||||
class TestSyncVoiceModeStateToAdapter:
    """Test _sync_voice_mode_state_to_adapter filters by platform."""

    def test_sync_only_includes_platform_chats(self):
        """Only chats matching the adapter's platform are synced."""
        runner = _make_runner()

        # One matching "off" entry; the rest are wrong-mode or wrong-platform.
        runner._voice_mode = {
            "telegram:123": "off",
            "telegram:456": "all",
            "slack:123": "off",
            "discord:789": "off",
        }

        adapter = MagicMock()
        adapter.platform = Platform.TELEGRAM
        adapter._auto_tts_disabled_chats = set()

        runner._sync_voice_mode_state_to_adapter(adapter)

        # Only the telegram chat whose mode is "off" lands in the set.
        assert adapter._auto_tts_disabled_chats == {"123"}

    def test_sync_clears_existing_state(self):
        """_sync_voice_mode_state_to_adapter clears existing disabled_chats first."""
        runner = _make_runner()
        runner._voice_mode = {
            "telegram:123": "off",
        }

        adapter = MagicMock()
        adapter.platform = Platform.TELEGRAM
        adapter._auto_tts_disabled_chats = {"old_chat_id", "another_old"}

        runner._sync_voice_mode_state_to_adapter(adapter)

        # Stale entries are replaced wholesale, not merged.
        assert adapter._auto_tts_disabled_chats == {"123"}

    def test_sync_returns_early_without_platform(self):
        """_sync_voice_mode_state_to_adapter returns early if adapter has no platform."""
        runner = _make_runner()
        runner._voice_mode = {"telegram:123": "off"}

        adapter = MagicMock()
        adapter.platform = None
        adapter._auto_tts_disabled_chats = {"old"}

        runner._sync_voice_mode_state_to_adapter(adapter)

        # No platform → nothing to sync, set left untouched.
        assert adapter._auto_tts_disabled_chats == {"old"}

    def test_sync_returns_early_without_auto_tts_disabled_chats(self):
        """_sync_voice_mode_state_to_adapter returns early if adapter lacks _auto_tts_disabled_chats."""
        runner = _make_runner()
        runner._voice_mode = {"telegram:123": "off"}

        # spec=[] means the mock has no _auto_tts_disabled_chats attribute.
        adapter = MagicMock(spec=[])

        # Should not raise
        runner._sync_voice_mode_state_to_adapter(adapter)
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_runner() -> GatewayRunner:
    """Create a minimal GatewayRunner for testing."""
    # __new__ skips __init__ (and its I/O); _load_voice_modes is patched
    # defensively so no construction path can touch the real state file.
    with patch("gateway.run.GatewayRunner._load_voice_modes", return_value={}):
        runner = GatewayRunner.__new__(GatewayRunner)
        runner._voice_mode = {}
        runner.adapters = {}
    return runner
|
||||
473
tests/gateway/test_webhook_deliver_only.py
Normal file
473
tests/gateway/test_webhook_deliver_only.py
Normal file
|
|
@ -0,0 +1,473 @@
|
|||
"""Tests for the webhook adapter's ``deliver_only`` route mode.
|
||||
|
||||
``deliver_only`` lets external services (Supabase webhooks, monitoring
|
||||
alerts, background jobs, other agents) push plain-text notifications to
|
||||
a user's chat via the webhook adapter WITHOUT invoking the agent. The
|
||||
rendered prompt template becomes the literal message body.
|
||||
|
||||
Covers:
|
||||
- Agent is NOT invoked (``handle_message`` never called)
|
||||
- Rendered content is delivered to the target platform adapter
|
||||
- HTTP returns 200 OK on success, 502 on delivery failure
|
||||
- Startup validation rejects ``deliver_only`` without a real delivery target
|
||||
- HMAC auth, rate limiting, and idempotency still apply
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, SendResult
|
||||
from gateway.platforms.webhook import WebhookAdapter, _INSECURE_NO_AUTH
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_adapter(routes, **extra_kw) -> WebhookAdapter:
    """Build a WebhookAdapter on an ephemeral port serving *routes*.

    Keyword overrides in *extra_kw* win over the defaults (host/port/routes).
    """
    settings = {"host": "0.0.0.0", "port": 0, "routes": routes, **extra_kw}
    return WebhookAdapter(PlatformConfig(enabled=True, extra=settings))
|
||||
|
||||
|
||||
def _create_app(adapter: WebhookAdapter) -> web.Application:
    """Wire the adapter's HTTP handlers into a fresh aiohttp Application."""
    application = web.Application()
    router = application.router
    router.add_get("/health", adapter._handle_health)
    router.add_post("/webhooks/{route_name}", adapter._handle_webhook)
    return application
|
||||
|
||||
|
||||
def _wire_mock_target(adapter: WebhookAdapter, platform_name: str = "telegram"):
    """Attach a gateway_runner with a mocked target adapter."""
    # Target platform adapter whose send() always reports success.
    target = AsyncMock()
    target.send = AsyncMock(return_value=SendResult(success=True))

    # Minimal runner: one registered platform, no home channel configured.
    runner = MagicMock()
    runner.adapters = {Platform(platform_name): target}
    runner.config.get_home_channel.return_value = None

    adapter.gateway_runner = runner
    return target
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Core behaviour: agent bypass
|
||||
# ===================================================================
|
||||
|
||||
class TestDeliverOnlyBypassesAgent:
    """The whole point of the feature — handle_message must not be called."""

    @pytest.mark.asyncio
    async def test_post_delivers_directly_without_agent(self):
        routes = {
            "match-alert": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "12345"},
                "prompt": "{payload.user} matched with {payload.other}!",
            }
        }
        adapter = _make_adapter(routes)
        target = _wire_mock_target(adapter)

        # Record every agent invocation; deliver_only must leave this empty.
        agent_calls: list[MessageEvent] = []

        async def _record(event):
            agent_calls.append(event)

        adapter.handle_message = _record

        app = _create_app(adapter)
        raw_body = json.dumps(
            {"payload": {"user": "alice", "other": "bob"}}
        ).encode()

        async with TestClient(TestServer(app)) as client:
            response = await client.post(
                "/webhooks/match-alert",
                data=raw_body,
                headers={
                    "Content-Type": "application/json",
                    "X-GitHub-Delivery": "delivery-1",
                },
            )
            assert response.status == 200
            reply = await response.json()
            assert reply["status"] == "delivered"
            assert reply["route"] == "match-alert"
            assert reply["target"] == "telegram"

            # Give any stray background tasks a moment to run first.
            await asyncio.sleep(0.05)

            # No agent call happened...
            assert agent_calls == []

            # ...but the target adapter received the rendered template.
            target.send.assert_awaited_once()
            sent_chat_id, sent_content = target.send.await_args.args[:2]
            assert sent_chat_id == "12345"
            assert sent_content == "alice matched with bob!"

    @pytest.mark.asyncio
    async def test_template_rendering_works(self):
        """Dot-notation template variables resolve in deliver_only mode."""
        routes = {
            "alert": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "chat-1"},
                "prompt": "Build {build.number} status: {build.status}",
            }
        }
        adapter = _make_adapter(routes)
        target = _wire_mock_target(adapter)
        app = _create_app(adapter)

        async with TestClient(TestServer(app)) as client:
            response = await client.post(
                "/webhooks/alert",
                json={"build": {"number": 77, "status": "FAILED"}},
                headers={"X-GitHub-Delivery": "d-render-1"},
            )
            assert response.status == 200

            target.send.assert_awaited_once()
            assert target.send.await_args.args[1] == "Build 77 status: FAILED"

    @pytest.mark.asyncio
    async def test_thread_id_passed_through(self):
        """deliver_extra.thread_id flows through to the target adapter."""
        routes = {
            "r": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1", "thread_id": "topic-42"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(routes)
        target = _wire_mock_target(adapter)

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as client:
            response = await client.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "d-thread-1"},
            )
            assert response.status == 200

            # thread_id travels in the send metadata, not the chat id.
            assert target.send.await_args.kwargs["metadata"] == {
                "thread_id": "topic-42"
            }
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# HTTP status codes
|
||||
# ===================================================================
|
||||
|
||||
class TestDeliverOnlyStatusCodes:

    @pytest.mark.asyncio
    async def test_delivery_failure_returns_502(self):
        """If the target adapter returns SendResult(success=False), 502."""
        route_cfg = {
            "r": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(route_cfg)
        target = _wire_mock_target(adapter)
        # Force a failed delivery with an adapter-internal error string.
        target.send = AsyncMock(
            return_value=SendResult(success=False, error="rate limited by tg")
        )

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as client:
            response = await client.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "d-fail-1"},
            )
            assert response.status == 502
            payload = await response.json()
            # Generic error — no adapter-level detail leaks
            assert payload["error"] == "Delivery failed"
            assert "rate limited" not in json.dumps(payload)

    @pytest.mark.asyncio
    async def test_delivery_exception_returns_502(self):
        """If adapter.send() raises, we return 502 (not 500)."""
        route_cfg = {
            "r": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(route_cfg)
        target = _wire_mock_target(adapter)
        target.send = AsyncMock(side_effect=RuntimeError("tg exploded"))

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as client:
            response = await client.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "d-exc-1"},
            )
            assert response.status == 502
            payload = await response.json()
            assert payload["error"] == "Delivery failed"
            # Exception message must not leak
            assert "exploded" not in json.dumps(payload)

    @pytest.mark.asyncio
    async def test_target_platform_not_connected_returns_502(self):
        """deliver_only to a platform the gateway doesn't have → 502."""
        route_cfg = {
            "r": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "discord",  # not configured in mock runner
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(route_cfg)
        _wire_mock_target(adapter, platform_name="telegram")  # only TG wired

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as client:
            response = await client.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "d-no-platform-1"},
            )
            assert response.status == 502
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Startup validation
|
||||
# ===================================================================
|
||||
|
||||
class TestDeliverOnlyStartupValidation:

    @pytest.mark.asyncio
    async def test_deliver_only_with_log_deliver_rejected(self):
        """deliver_only=true + deliver=log is nonsense — reject at connect()."""
        adapter = _make_adapter({
            "bad": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "log",
                "deliver_only": True,
                "prompt": "hi",
            }
        })
        with pytest.raises(ValueError, match="deliver_only=true but deliver is 'log'"):
            await adapter.connect()

    @pytest.mark.asyncio
    async def test_deliver_only_with_missing_deliver_rejected(self):
        """deliver_only=true with no deliver field defaults to 'log' → reject."""
        adapter = _make_adapter({
            "bad": {
                "secret": _INSECURE_NO_AUTH,
                # no deliver field
                "deliver_only": True,
                "prompt": "hi",
            }
        })
        with pytest.raises(ValueError, match="deliver_only=true"):
            await adapter.connect()

    @pytest.mark.asyncio
    async def test_deliver_only_with_real_target_accepted(self):
        """Sanity check — a valid deliver_only config passes validation."""
        adapter = _make_adapter({
            "good": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1"},
                "prompt": "hi",
            }
        })
        # connect() also binds a socket — we only care that validation does
        # not raise. Tear down immediately if it actually started.
        try:
            if await adapter.connect():
                await adapter.disconnect()
        except ValueError:
            pytest.fail("valid deliver_only config should not raise ValueError")
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Security + reliability invariants still hold
|
||||
# ===================================================================
|
||||
|
||||
class TestDeliverOnlySecurityInvariants:

    @pytest.mark.asyncio
    async def test_hmac_still_enforced(self):
        """deliver_only does NOT bypass HMAC validation."""
        shared_secret = "real-secret-123"
        route_cfg = {
            "r": {
                "secret": shared_secret,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(route_cfg)
        target = _wire_mock_target(adapter)

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as client:
            # Unsigned request → rejected before any delivery happens.
            response = await client.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "d-noauth-1"},
            )
            assert response.status == 401

            # Target never called
            target.send.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_idempotency_still_applies(self):
        """Same delivery_id posted twice → second is suppressed."""
        route_cfg = {
            "r": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(route_cfg)
        target = _wire_mock_target(adapter)

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as client:
            first = await client.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "dup-1"},
            )
            assert first.status == 200

            second = await client.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "dup-1"},
            )
            # Existing webhook adapter treats duplicates as 200 + status=duplicate
            assert second.status == 200
            reply = await second.json()
            assert reply["status"] == "duplicate"

            # Target was called exactly once
            assert target.send.await_count == 1

    @pytest.mark.asyncio
    async def test_rate_limit_still_applies(self):
        """Route-level rate limit caps deliver_only POSTs too."""
        route_cfg = {
            "r": {
                "secret": _INSECURE_NO_AUTH,
                "deliver": "telegram",
                "deliver_only": True,
                "deliver_extra": {"chat_id": "c-1"},
                "prompt": "hi",
            }
        }
        adapter = _make_adapter(route_cfg, rate_limit=2)
        _wire_mock_target(adapter)

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as client:
            # The first two fit inside the window.
            for i in range(2):
                ok = await client.post(
                    "/webhooks/r",
                    json={},
                    headers={"X-GitHub-Delivery": f"rl-{i}"},
                )
                assert ok.status == 200

            # Third within the window → 429
            throttled = await client.post(
                "/webhooks/r",
                json={},
                headers={"X-GitHub-Delivery": "rl-3"},
            )
            assert throttled.status == 429
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Unit: _direct_deliver dispatch
|
||||
# ===================================================================
|
||||
|
||||
class TestDirectDeliverUnit:

    @pytest.mark.asyncio
    async def test_dispatches_to_cross_platform_for_messaging_targets(self):
        adapter = _make_adapter({})
        target = _wire_mock_target(adapter, "telegram")

        outcome = await adapter._direct_deliver(
            "hello",
            {"deliver": "telegram", "deliver_extra": {"chat_id": "c-1"}},
        )
        assert outcome.success is True
        # chat_id and content positional, no metadata for a bare route.
        target.send.assert_awaited_once_with(
            "c-1", "hello", metadata=None
        )

    @pytest.mark.asyncio
    async def test_dispatches_to_github_comment(self):
        adapter = _make_adapter({})
        gh_stub = AsyncMock(return_value=SendResult(success=True))
        with patch.object(adapter, "_deliver_github_comment", new=gh_stub):
            outcome = await adapter._direct_deliver(
                "review body",
                {
                    "deliver": "github_comment",
                    "deliver_extra": {"repo": "org/r", "pr_number": "1"},
                },
            )
            assert outcome.success is True
            gh_stub.assert_awaited_once()
|
||||
289
tests/gateway/test_webhook_signature_rate_limit.py
Normal file
289
tests/gateway/test_webhook_signature_rate_limit.py
Normal file
|
|
@ -0,0 +1,289 @@
|
|||
"""Test that HMAC signature validation happens BEFORE rate limiting.
|
||||
|
||||
This verifies the fix for bug #12544: invalid signature requests must NOT
|
||||
consume rate-limit quota. Before the fix, rate limiting was applied before
|
||||
signature validation, so an attacker could exhaust a victim's rate limit
|
||||
with invalidly-signed requests and then make valid requests that get rejected
|
||||
with 429.
|
||||
|
||||
The correct order is:
|
||||
1. Read body
|
||||
2. Validate HMAC signature (reject 401 if invalid)
|
||||
3. Rate limit check (reject 429 if over limit)
|
||||
4. Process the webhook
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from gateway.platforms.webhook import WebhookAdapter
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _make_adapter(routes, rate_limit=5, **extra_kw) -> WebhookAdapter:
    """Create a WebhookAdapter with the given routes."""
    settings = {
        "host": "0.0.0.0",
        "port": 0,
        "routes": routes,
        "rate_limit": rate_limit,
        **extra_kw,  # caller overrides win
    }
    return WebhookAdapter(PlatformConfig(enabled=True, extra=settings))
|
||||
|
||||
|
||||
def _create_app(adapter: WebhookAdapter) -> web.Application:
    """Build the aiohttp Application from the adapter."""
    application = web.Application()
    router = application.router
    router.add_get("/health", adapter._handle_health)
    router.add_post("/webhooks/{route_name}", adapter._handle_webhook)
    return application
|
||||
|
||||
|
||||
def _github_signature(body: bytes, secret: str) -> str:
|
||||
"""Compute X-Hub-Signature-256 for *body* using *secret*."""
|
||||
return "sha256=" + hmac.new(
|
||||
secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
|
||||
SIMPLE_PAYLOAD = {"event": "test", "data": "hello"}
|
||||
|
||||
|
||||
class TestSignatureBeforeRateLimit:
    """Verify that invalid signatures do NOT consume rate limit quota."""

    @pytest.mark.asyncio
    async def test_invalid_signature_does_not_consume_rate_limit(self):
        """Exhaust the window with badly-signed requests, then confirm a
        properly-signed one still succeeds.

        Before the fix, rate limiting ran first, so the bad requests
        consumed the bucket and the valid one got 429. After the fix,
        signature validation rejects them with 401 before the bucket is
        touched.
        """
        secret = "test-secret-key"
        route_name = "test-route"
        routes = {
            route_name: {
                "secret": secret,
                "events": ["push"],
                "prompt": "Event: {event}",
                "deliver": "log",
            }
        }
        rate_limit = 5
        adapter = _make_adapter(routes, rate_limit=rate_limit)

        seen_events = []

        async def _record(event):
            seen_events.append(event)

        adapter.handle_message = _record
        app = _create_app(adapter)

        body = json.dumps(SIMPLE_PAYLOAD).encode()

        async with TestClient(TestServer(app)) as client:
            async def _post(signature, delivery_id):
                return await client.post(
                    f"/webhooks/{route_name}",
                    data=body,
                    headers={
                        "Content-Type": "application/json",
                        "X-GitHub-Event": "push",
                        "X-Hub-Signature-256": signature,
                        "X-GitHub-Delivery": delivery_id,
                    },
                )

            # Burn through 'rate_limit' requests with a bogus signature.
            for i in range(rate_limit):
                resp = await _post("sha256=invalid", f"bad-{i}")
                assert resp.status == 401, (
                    f"Expected 401 for invalid signature, got {resp.status}"
                )

            # The correctly-signed request must still be accepted.
            resp = await _post(_github_signature(body, secret), "good-001")
            assert resp.status == 202, (
                f"Expected 202 for valid request after invalid signatures, "
                f"got {resp.status}. Rate limit may have been consumed by "
                f"invalid requests (bug #12544 not fixed)."
            )

            data = await resp.json()
            assert data["status"] == "accepted"

            # Exactly the one valid event reached the agent.
            assert len(seen_events) == 1

    @pytest.mark.asyncio
    async def test_valid_signature_still_rate_limited(self):
        """Verify that VALID requests still respect rate limiting normally."""
        secret = "test-secret-key"
        route_name = "test-route"
        routes = {
            route_name: {
                "secret": secret,
                "events": ["push"],
                "prompt": "Event: {event}",
                "deliver": "log",
            }
        }
        rate_limit = 3
        adapter = _make_adapter(routes, rate_limit=rate_limit)

        seen_events = []

        async def _record(event):
            seen_events.append(event)

        adapter.handle_message = _record
        app = _create_app(adapter)

        body = json.dumps(SIMPLE_PAYLOAD).encode()

        async with TestClient(TestServer(app)) as client:
            async def _post(delivery_id):
                return await client.post(
                    f"/webhooks/{route_name}",
                    data=body,
                    headers={
                        "Content-Type": "application/json",
                        "X-GitHub-Event": "push",
                        "X-Hub-Signature-256": _github_signature(body, secret),
                        "X-GitHub-Delivery": delivery_id,
                    },
                )

            # The first 'rate_limit' valid requests all land.
            for i in range(rate_limit):
                resp = await _post(f"good-{i}")
                assert resp.status == 202

            # One more pushes past the bucket and is throttled.
            resp = await _post("good-over-limit")
            assert resp.status == 429, (
                f"Expected 429 when exceeding rate limit with valid requests, "
                f"got {resp.status}"
            )

    @pytest.mark.asyncio
    async def test_mixed_valid_and_invalid_signatures(self):
        """Interleave invalid and valid requests. Only valid ones count
        against the rate limit."""
        secret = "test-secret-key"
        route_name = "test-route"
        routes = {
            route_name: {
                "secret": secret,
                "events": ["push"],
                "prompt": "Event: {event}",
                "deliver": "log",
            }
        }
        rate_limit = 3
        adapter = _make_adapter(routes, rate_limit=rate_limit)

        seen_events = []

        async def _record(event):
            seen_events.append(event)

        adapter.handle_message = _record
        app = _create_app(adapter)

        body = json.dumps(SIMPLE_PAYLOAD).encode()

        async with TestClient(TestServer(app)) as client:
            async def _post(signature, delivery_id):
                return await client.post(
                    f"/webhooks/{route_name}",
                    data=body,
                    headers={
                        "Content-Type": "application/json",
                        "X-GitHub-Event": "push",
                        "X-Hub-Signature-256": signature,
                        "X-GitHub-Delivery": delivery_id,
                    },
                )

            # Two valid requests consume two slots of the three-wide bucket.
            for i in range(2):
                resp = await _post(_github_signature(body, secret), f"good-{i}")
                assert resp.status == 202

            # Ten invalid requests are all rejected with 401 — none of them
            # may touch the bucket.
            for i in range(10):
                resp = await _post("sha256=invalid", f"bad-{i}")
                assert resp.status == 401

            # Slot three is still free for a valid request.
            resp = await _post(_github_signature(body, secret), "good-3")
            assert resp.status == 202, (
                f"Expected 202 for 3rd valid request after many invalid ones, "
                f"got {resp.status}"
            )

            # A fourth valid request exceeds the three-wide window → 429.
            resp = await _post(_github_signature(body, secret), "good-4")
            assert resp.status == 429

            assert len(seen_events) == 3
|
||||
|
|
@ -211,6 +211,30 @@ class TestFileHandleClosedOnError:
|
|||
assert adapter._bridge_log_fh is None
|
||||
|
||||
|
||||
class TestConnectCleanup:
    """Verify failure paths release the scoped session lock."""

    @pytest.mark.asyncio
    async def test_releases_lock_when_npm_install_fails(self):
        adapter = _make_adapter()

        # Pretend everything on disk exists except node_modules, forcing
        # connect() down the npm-install path.
        def _fake_exists(path_obj):
            return not str(path_obj).endswith("node_modules")

        failed_install = MagicMock(returncode=1, stderr="install failed")

        with patch("gateway.platforms.whatsapp.check_whatsapp_requirements", return_value=True), \
             patch.object(Path, "exists", autospec=True, side_effect=_fake_exists), \
             patch("subprocess.run", return_value=failed_install), \
             patch("gateway.status.acquire_scoped_lock", return_value=(True, None)), \
             patch("gateway.status.release_scoped_lock") as release_spy:
            connected = await adapter.connect()

        # On a failed install, connect() reports failure AND gives back the
        # session lock so a retry (or another process) can acquire it.
        assert connected is False
        release_spy.assert_called_once_with("whatsapp-session", str(adapter._session_path))
        assert adapter._platform_lock_identity is None
|
||||
|
||||
|
||||
class TestBridgeRuntimeFailure:
|
||||
"""Verify runtime bridge death is surfaced as a fatal adapter error."""
|
||||
|
||||
|
|
@ -429,6 +453,33 @@ class TestKillPortProcess:
|
|||
class TestHttpSessionLifecycle:
|
||||
"""Verify persistent aiohttp.ClientSession is created and cleaned up."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_uses_taskkill_tree_on_windows(self):
|
||||
"""Windows disconnect should target the bridge process tree, not just the parent PID."""
|
||||
adapter = _make_adapter()
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.pid = 12345
|
||||
mock_proc.poll.side_effect = [0]
|
||||
adapter._bridge_process = mock_proc
|
||||
adapter._poll_task = None
|
||||
adapter._http_session = None
|
||||
adapter._running = True
|
||||
adapter._session_lock_identity = None
|
||||
|
||||
with patch("gateway.platforms.whatsapp._IS_WINDOWS", True), \
|
||||
patch("gateway.platforms.whatsapp.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run, \
|
||||
patch("gateway.platforms.whatsapp.asyncio.sleep", new_callable=AsyncMock):
|
||||
await adapter.disconnect()
|
||||
|
||||
mock_run.assert_called_once_with(
|
||||
["taskkill", "/PID", "12345", "/T"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
mock_proc.terminate.assert_not_called()
|
||||
mock_proc.kill.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_closed_on_disconnect(self):
|
||||
"""disconnect() should close self._http_session."""
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue