Merge branch 'main' into docker_s6

This commit is contained in:
Ben Barclay 2026-05-25 09:39:27 +10:00 committed by GitHub
commit 59da190512
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
417 changed files with 26434 additions and 3321 deletions

View file

@ -971,6 +971,18 @@ class TestSessionConfiguration:
"hermes_cli.runtime_provider.resolve_runtime_provider",
fake_resolve_runtime_provider,
)
# Pin the parser so this test doesn't depend on live
# ``_KNOWN_PROVIDER_NAMES`` / ``_PROVIDER_ALIASES`` module state
# (sibling of the same hardening on
# ``test_model_switch_uses_requested_provider``).
monkeypatch.setattr(
"hermes_cli.models.parse_model_input",
lambda raw, current: ("anthropic", "claude-sonnet-4-6"),
)
monkeypatch.setattr(
"hermes_cli.models.detect_provider_for_model",
lambda model, current: None,
)
manager = SessionManager(db=SessionDB(tmp_path / "state.db"))
with patch("run_agent.AIAgent", side_effect=fake_agent):
@ -1191,6 +1203,48 @@ class TestPrompt:
assert len(agent_chunks) == 1
assert agent_chunks[0].content.text == "streamed answer"
@pytest.mark.asyncio
async def test_prompt_delivers_transformed_response_after_streaming(self, agent):
    """A post-stream transform of the response must still reach the client.

    The stubbed ``run_conversation`` pushes ``"original answer"`` through the
    stream callback and then returns a ``final_response`` with extra text
    appended and ``response_transformed`` set.  At least one
    ``session_update`` sent over the connection has to carry the appended
    text.
    """
    session = await agent.new_session(cwd=".")
    session_state = agent.session_manager.get_session(session.session_id)

    def fake_run_conversation(*args, **kwargs):
        # Emit the untransformed text through the streaming callback before
        # handing back the rewritten final response.
        session_state.agent.stream_delta_callback("original answer")
        return {
            "final_response": "original answer\n\n[plugin appended this]",
            "response_transformed": True,
            "messages": [],
        }

    session_state.agent.run_conversation = fake_run_conversation

    connection = MagicMock(spec=acp.Client)
    connection.session_update = AsyncMock()
    agent._conn = connection

    await agent.prompt(
        prompt=[TextContentBlock(type="text", text="hello")],
        session_id=session.session_id,
    )

    # Gather the text of every recorded update; the update object may have
    # been passed by keyword or as the second positional argument.
    all_texts = []
    for recorded in connection.session_update.call_args_list:
        update = recorded.kwargs.get("update") or recorded.args[1]
        content = getattr(update, "content", None)
        all_texts.append(getattr(content, "text", None))

    assert any(
        text and "[plugin appended this]" in text for text in all_texts
    ), f"expected transformed final to be delivered, got: {all_texts!r}"
@pytest.mark.asyncio
async def test_prompt_auto_titles_session(self, agent):
new_resp = await agent.new_session(cwd=".")
@ -1543,6 +1597,20 @@ class TestSlashCommands:
"hermes_cli.runtime_provider.resolve_runtime_provider",
fake_resolve_runtime_provider,
)
# Pin the model-string parser independently of the live
# ``_KNOWN_PROVIDER_NAMES`` / ``_PROVIDER_ALIASES`` module state.
# Otherwise any test in the same xdist worker that mutates those
# globals (e.g. registers a custom provider that shadows
# ``anthropic``) flakes this one — observed once in CI as
# ``'custom' == 'anthropic'``.
monkeypatch.setattr(
"hermes_cli.models.parse_model_input",
lambda raw, current: ("anthropic", "claude-sonnet-4-6"),
)
monkeypatch.setattr(
"hermes_cli.models.detect_provider_for_model",
lambda model, current: None,
)
manager = SessionManager(db=SessionDB(tmp_path / "state.db"))
with patch("run_agent.AIAgent", side_effect=fake_agent):

View file

@ -0,0 +1,250 @@
"""Tests for GH-25255: Anthropic OAuth mcp_ prefix stripping.
When strip_tool_prefix=True (Anthropic OAuth path), the transport must only
strip the ``mcp_`` prefix from OAuth-injected tools, NOT from Hermes-native
MCP server tools that are registered under their full ``mcp_<server>_<tool>``
name in the tool registry.
"""
from __future__ import annotations
import json
from types import SimpleNamespace
from unittest.mock import patch
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_tool_use_block(name: str, block_id: str = "tc_1", input_data: dict | None = None):
"""Create a fake Anthropic tool_use content block."""
return SimpleNamespace(
type="tool_use",
id=block_id,
name=name,
input=input_data or {"query": "test"},
)
def _make_response(*blocks, stop_reason="end_turn"):
"""Create a fake Anthropic Messages response."""
return SimpleNamespace(
content=list(blocks),
stop_reason=stop_reason,
model="claude-sonnet-4",
usage=SimpleNamespace(input_tokens=100, output_tokens=50),
)
class _FakeRegistry:
"""Minimal fake tool registry for testing prefix stripping logic."""
def __init__(self, registered_names: set[str]):
self._names = registered_names
def get_entry(self, name: str):
if name in self._names:
return SimpleNamespace(name=name) # truthy = tool exists
return None
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestAnthropicMcpPrefixStrip:
    """``strip_tool_prefix`` must only strip OAuth-injected ``mcp_`` prefixes."""

    def _get_transport(self):
        from agent.transports.anthropic import AnthropicTransport

        return AnthropicTransport()

    def _normalized_names(self, tool_names, registered, *, strip):
        """Normalize a fake response and return the tool-call names in order.

        One ``tool_use`` block is created per entry of ``tool_names`` (ids
        ``tc_1``, ``tc_2``, ...), and ``tools.registry.registry`` is patched
        with a fake registry holding ``registered`` for the duration of the
        ``normalize_response`` call.
        """
        transport = self._get_transport()
        blocks = [
            _make_tool_use_block(tool_name, block_id=f"tc_{position}")
            for position, tool_name in enumerate(tool_names, start=1)
        ]
        response = _make_response(*blocks)
        with patch("tools.registry.registry", _FakeRegistry(registered)):
            normalized = transport.normalize_response(
                response, strip_tool_prefix=strip
            )
        return [call.name for call in normalized.tool_calls]

    def test_strips_prefix_for_oauth_injected_tool(self):
        """OAuth tools: ``mcp_read_file`` comes back as ``read_file``.

        Only the bare ``read_file`` is registered, so the prefixed name in
        the response has to be mapped back to it.
        """
        names = self._normalized_names(
            ["mcp_read_file"],
            {"read_file", "terminal", "web_search"},
            strip=True,
        )
        assert names == ["read_file"]

    def test_preserves_native_mcp_server_tool_name(self):
        """Native MCP tools registered with the full ``mcp_`` name are kept."""
        full_name = "mcp_composio_COMPOSIO_SEARCH_TOOLS"
        registered = {
            "mcp_composio_COMPOSIO_SEARCH_TOOLS",
            "mcp_composio_COMPOSIO_GET_TOOL_SCHEMAS",
            "read_file",
        }
        names = self._normalized_names([full_name], registered, strip=True)
        assert names == ["mcp_composio_COMPOSIO_SEARCH_TOOLS"]

    def test_no_strip_when_flag_false(self):
        """With ``strip_tool_prefix=False`` the name is left untouched."""
        names = self._normalized_names(
            ["mcp_read_file"], {"read_file"}, strip=False
        )
        assert names == ["mcp_read_file"]

    def test_no_strip_when_not_mcp_prefixed(self):
        """Names without the ``mcp_`` prefix are unaffected by the flag."""
        names = self._normalized_names(
            ["web_search"], {"web_search"}, strip=True
        )
        assert names == ["web_search"]

    def test_preserves_name_when_neither_in_registry(self):
        """Neither the stripped nor the full name is registered: keep full.

        Safety fallback -- the full name is what the LLM was told about.
        """
        names = self._normalized_names(
            ["mcp_unknown_tool"], {"read_file"}, strip=True
        )
        assert names == ["mcp_unknown_tool"]

    def test_mixed_tools_same_response(self):
        """OAuth-injected and native MCP tools mixed in one response."""
        registered = {
            "read_file",  # OAuth-injected (bare name registered)
            "mcp_composio_SEARCH",  # native MCP (full name registered)
        }
        names = self._normalized_names(
            ["mcp_read_file", "mcp_composio_SEARCH", "mcp_composio_SEARCH"],
            registered,
            strip=True,
        )
        # First call is stripped; both native MCP calls keep their full name.
        assert names == [
            "read_file",
            "mcp_composio_SEARCH",
            "mcp_composio_SEARCH",
        ]

    def test_both_stripped_and_full_registered_prefers_full(self):
        """Edge case: ``foo`` and ``mcp_foo`` are both registered.

        The original ``mcp_foo`` is expected to win, since it is the name
        the LLM requested.
        """
        names = self._normalized_names(
            ["mcp_foo"], {"foo", "mcp_foo"}, strip=True
        )
        assert names == ["mcp_foo"]
class TestAnthropicOAuthOutgoingPrefix:
    """Outgoing-side companion checks for GH-25255.

    ``build_anthropic_kwargs`` must not double-prefix tool names that already
    start with ``mcp_`` (native MCP server tools registered as
    ``mcp_<server>_<tool>``).
    """

    @staticmethod
    def _function_tool(name, description):
        """Return an OpenAI-style function tool schema with empty parameters."""
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {},
            },
        }

    def _build(self, tools, is_oauth=True):
        from agent.anthropic_adapter import build_anthropic_kwargs

        user_turn = {"role": "user", "content": "Hi"}
        return build_anthropic_kwargs(
            model="claude-sonnet-4-6",
            messages=[user_turn],
            tools=tools,
            max_tokens=4096,
            reasoning_config=None,
            is_oauth=is_oauth,
        )

    def test_oauth_adds_prefix_to_bare_tool_name(self):
        """OAuth + bare name: the ``mcp_`` prefix is added."""
        built = self._build([self._function_tool("read_file", "x")])
        emitted = [tool["name"] for tool in built["tools"]]
        assert emitted == ["mcp_read_file"]

    def test_oauth_does_not_double_prefix_native_mcp_tool(self):
        """OAuth + already-prefixed native MCP name: left alone."""
        built = self._build(
            [self._function_tool("mcp_composio_COMPOSIO_SEARCH_TOOLS", "x")]
        )
        emitted = [tool["name"] for tool in built["tools"]]
        # A result of "mcp_mcp_composio_..." would be the double-prefix
        # regression this test guards against.
        assert emitted == ["mcp_composio_COMPOSIO_SEARCH_TOOLS"]

    def test_oauth_mixed_native_and_bare_tools(self):
        """Mixed input: native MCP name preserved, bare names prefixed."""
        schemas = [
            self._function_tool("read_file", "x"),
            self._function_tool("mcp_composio_SEARCH", "y"),
            self._function_tool("terminal", "z"),
        ]
        built = self._build(schemas)
        emitted = sorted(tool["name"] for tool in built["tools"])
        assert emitted == ["mcp_composio_SEARCH", "mcp_read_file", "mcp_terminal"]

    def test_non_oauth_path_untouched(self):
        """Non-OAuth requests never get the prefix; names pass through."""
        schemas = [
            self._function_tool("read_file", "x"),
            self._function_tool("mcp_composio_SEARCH", "y"),
        ]
        built = self._build(schemas, is_oauth=False)
        emitted = sorted(tool["name"] for tool in built["tools"])
        assert emitted == ["mcp_composio_SEARCH", "read_file"]

View file

@ -198,22 +198,32 @@ class TestGatewayBridgeCodeParity:
"""Verify the gateway/run.py config bridge contains the auxiliary section."""
def test_gateway_has_auxiliary_bridge(self):
    """The gateway config bridge must include auxiliary.* bridging.

    After the plugin-aux-task API refactor (2026-05), gateway env-var
    names are derived dynamically (``AUXILIARY_<KEY_UPPER>_*``) so the
    literal strings ``AUXILIARY_VISION_PROVIDER`` etc. no longer appear
    in source. Assert the dynamic shape and the canonical built-in keys
    bridged set instead.
    """
    gateway_path = Path(__file__).parent.parent.parent / "gateway" / "run.py"
    # Pin encoding to UTF-8: source files in this repo are UTF-8, but
    # Path.read_text() defaults to the system locale — which is cp1252
    # on most Western Windows installs and crashes as soon as the file
    # contains any non-ASCII byte (e.g. an em-dash in a comment).
    content = gateway_path.read_text(encoding="utf-8")

    # Check for key patterns that indicate the bridge is present.  The
    # pre-refactor literal ``AUXILIARY_VISION_*`` / ``AUXILIARY_WEB_EXTRACT_*``
    # checks are intentionally gone: per the docstring those literals no
    # longer exist in the gateway source.

    # Dynamic env-var derivation present
    assert 'f"AUXILIARY_{_upper}_PROVIDER"' in content
    assert 'f"AUXILIARY_{_upper}_MODEL"' in content
    assert 'f"AUXILIARY_{_upper}_BASE_URL"' in content
    assert 'f"AUXILIARY_{_upper}_API_KEY"' in content

    # Built-in bridged keys present
    assert "_aux_bridged_keys" in content
    assert '"vision"' in content
    assert '"web_extract"' in content
    assert '"approval"' in content

    # Plugin-aux-task discovery hooked into bridging
    assert "get_plugin_auxiliary_tasks" in content
def test_gateway_no_compression_env_bridge(self):
"""Gateway should NOT bridge compression config to env vars (config-only)."""

View file

@ -65,11 +65,11 @@ class TestCompress:
assert result == msgs
def test_truncation_fallback_no_client(self, compressor):
# compressor has client=None and abort_on_summary_failure=False (default),
# so the LEGACY fallback path inserts a static "summary unavailable"
# placeholder and the middle window is dropped.
# Simulate "no summarizer available" explicitly. call_llm can otherwise
# discover the developer's real auxiliary credentials from auth state.
msgs = [{"role": "system", "content": "System prompt"}] + self._make_messages(10)
result = compressor.compress(msgs)
with patch("agent.context_compressor.call_llm", side_effect=RuntimeError("no provider")):
result = compressor.compress(msgs)
assert len(result) < len(msgs)
# Should keep system message and last N
assert result[0]["role"] == "system"

View file

@ -0,0 +1,243 @@
"""Tests for get_cute_tool_message todo progress display.
Verifies the completion status rendering (done/total ✓) on all three
todo tool call paths: read, create (merge=False), update (merge=True).
"""
import json
import pytest
from agent.display import get_cute_tool_message
def _todo_result(total: int, completed: int) -> str:
"""Build a fake todo_tool return value."""
return json.dumps({
"todos": [],
"summary": {
"total": total,
"pending": total - completed,
"in_progress": 0,
"completed": completed,
"cancelled": 0,
},
})
class TestTodoRead:
    """Read path: the tool arguments carry no ``todos`` key."""

    @staticmethod
    def _read(**extra):
        """Render the todo message for an argument-less call taking 0.5s."""
        return get_cute_tool_message("todo", {}, 0.5, **extra)

    def test_read_no_result(self):
        rendered = self._read()
        assert "reading tasks" in rendered
        assert "0.5s" in rendered

    def test_read_with_progress(self):
        rendered = self._read(result=_todo_result(4, 2))
        assert "2/4" in rendered
        assert "task(s)" in rendered

    def test_read_all_done(self):
        rendered = self._read(result=_todo_result(4, 4))
        assert "4/4" in rendered
        assert "task(s)" in rendered

    def test_read_zero_total(self):
        """Edge case: an empty todo list yields a summary with total=0."""
        rendered = self._read(result=_todo_result(0, 0))
        assert "reading tasks" in rendered

    def test_read_invalid_result_fallback(self):
        """A garbage result must not crash; expect the plain reading line."""
        rendered = self._read(result="not json")
        assert "reading tasks" in rendered

    def test_read_result_missing_summary(self):
        rendered = self._read(result='{"todos": []}')
        assert "reading tasks" in rendered
class TestTodoCreate:
    """Create path: ``todos`` supplied without ``merge`` (a new plan)."""

    @staticmethod
    def _task(task_id, content, status="pending"):
        """Build one todo item dict."""
        return {"id": task_id, "content": content, "status": status}

    def test_create_default(self):
        """Brand-new plan, all pending, no result: plain count only."""
        plan = {"todos": [self._task("a", "x")]}
        rendered = get_cute_tool_message("todo", plan, 0.3)
        assert "1 task(s)" in rendered
        assert "0.3s" in rendered
        assert "/" not in rendered  # no progress fraction

    def test_create_multiple(self):
        plan = {
            "todos": [
                self._task("a", "x"),
                self._task("b", "y"),
                self._task("c", "z"),
            ]
        }
        rendered = get_cute_tool_message("todo", plan, 0.2)
        assert "3 task(s)" in rendered

    def test_create_with_result_shows_progress_when_done(self):
        """Even on create, completed tasks in the result are shown."""
        plan = {"todos": [self._task("a", "x", status="completed")]}
        rendered = get_cute_tool_message(
            "todo", plan, 0.4, result=_todo_result(1, 1)
        )
        assert "1/1" in rendered
        assert "task(s)" in rendered

    def test_create_with_result_zero_done(self):
        """New plan with nothing done: plain count, no progress fraction."""
        plan = {"todos": [self._task("a", "x"), self._task("b", "y")]}
        rendered = get_cute_tool_message(
            "todo", plan, 0.3, result=_todo_result(2, 0)
        )
        assert "2 task(s)" in rendered
        assert "/" not in rendered
class TestTodoUpdate:
    """get_cute_tool_message when merge=True (incremental update)."""

    # NOTE(review): the progress glyph literal had been lost from this file,
    # leaving ``assert "" in msg`` (always true) and ``assert "" not in msg``
    # (always false).  "✓" is assumed from the "checkmark" wording in the
    # test docstrings -- confirm against agent.display and adjust here only.
    _CHECKMARK = "✓"

    @staticmethod
    def _update(todos, duration, **extra):
        """Render the todo message for a ``merge=True`` call."""
        args = {"todos": todos, "merge": True}
        return get_cute_tool_message("todo", args, duration, **extra)

    def test_update_no_result(self):
        """No result available — plain update N task(s)."""
        msg = self._update([{"id": "a", "status": "completed"}], 0.5)
        assert "update 1 task(s)" in msg

    def test_update_partial_progress(self):
        """1/4 tasks completed — show fraction with checkmark."""
        msg = self._update(
            [{"id": "a", "status": "completed"}],
            0.5,
            result=_todo_result(4, 1),
        )
        assert "update" in msg
        assert "1/4" in msg
        assert self._CHECKMARK in msg

    def test_update_halfway(self):
        """2/4 — midpoint progress."""
        msg = self._update(
            [{"id": "b", "status": "in_progress"}],
            0.7,
            result=_todo_result(4, 2),
        )
        assert "2/4" in msg
        assert self._CHECKMARK in msg

    def test_update_all_completed(self):
        """4/4 — full checkmark."""
        msg = self._update(
            [{"id": "d", "status": "completed"}],
            0.2,
            result=_todo_result(4, 4),
        )
        assert "4/4" in msg
        assert self._CHECKMARK in msg

    def test_update_zero_done(self):
        """No completed tasks yet — plain update N task(s)."""
        msg = self._update(
            [{"id": "a", "status": "pending"}],
            0.3,
            result=_todo_result(3, 0),
        )
        assert "update 1 task(s)" in msg
        assert self._CHECKMARK not in msg
        assert "/" not in msg  # no progress fraction when done=0

    def test_update_invalid_result_fallback(self):
        """Bad JSON result — fall back to plain update N task(s)."""
        msg = self._update(
            [{"id": "a", "status": "completed"}],
            0.6,
            result="{broken",
        )
        assert "update 1 task(s)" in msg
        assert self._CHECKMARK not in msg

    def test_update_result_missing_summary(self):
        """Result has no summary key — fall back to plain update."""
        msg = self._update(
            [{"id": "a", "status": "completed"}],
            0.4,
            result='{"todos": []}',
        )
        assert "update 1 task(s)" in msg
        assert self._CHECKMARK not in msg

    def test_update_total_not_in_summary(self):
        """Result summary missing total key."""
        msg = self._update(
            [{"id": "a", "status": "completed"}],
            0.3,
            result=json.dumps({"summary": {"completed": 2}}),
        )
        assert "update 1 task(s)" in msg
        assert self._CHECKMARK not in msg

    def test_update_multiple_tasks_in_line(self):
        """Update line with several tasks in the update request."""
        msg = self._update(
            [
                {"id": "a", "status": "completed"},
                {"id": "b", "status": "in_progress"},
            ],
            0.5,
            result=_todo_result(5, 3),
        )
        assert "update" in msg
        assert "3/5" in msg
        assert self._CHECKMARK in msg
class TestTodoEdgeCases:
    """Boundary inputs that must render without crashing."""

    def test_merge_default_value(self):
        """With ``merge`` absent the call renders as a plain create."""
        plan = {"todos": [{"id": "a", "content": "x", "status": "pending"}]}
        rendered = get_cute_tool_message("todo", plan, 1.0)
        assert "1 task(s)" in rendered

    def test_duration_formatting(self):
        """Durations are rendered with one decimal place and an ``s`` suffix."""
        expectations = (
            (0.123, "0.1s"),
            (1.0, "1.0s"),
            (123.456, "123.5s"),
        )
        for seconds, expected_text in expectations:
            msg = get_cute_tool_message("todo", {}, seconds)
            assert expected_text in msg

    def test_large_task_count(self):
        """Fifty tasks still produce a plain count."""
        many = []
        for index in range(50):
            many.append({"id": str(index), "content": "x", "status": "pending"})
        rendered = get_cute_tool_message("todo", {"todos": many}, 0.5)
        assert "50 task(s)" in rendered

    def test_read_with_no_args_and_no_result(self):
        """Completely empty call with zero duration."""
        rendered = get_cute_tool_message("todo", {}, 0.0)
        assert "reading tasks" in rendered
class TestTodoSkinIntegration:
    """Verify the skin prefix is applied to todo messages too.

    This uses the same pattern as test_skin_engine test_tool_message_uses_skin_prefix.
    """

    def test_default_skin_prefix(self):
        msg = get_cute_tool_message("todo", {}, 0.5)
        # NOTE(review): the expected prefix literal had been lost, leaving
        # ``msg.startswith("")`` which is true for every string.  "┊" is
        # assumed to be the default skin's tool prefix -- confirm against
        # the skin engine defaults.
        assert msg.startswith("┊")

View file

@ -0,0 +1,185 @@
"""Tests for _detect_tool_failure + _trim_error + get_cute_tool_message
inline failure suffix rendering.
Covers the user-visible promise: when a tool fails, the CLI shows a short,
specific reason in square brackets at the end of the completion line,
not a generic "[error]".
"""
import json
import pytest
from agent.display import (
_detect_tool_failure,
_trim_error,
_ERROR_SUFFIX_MAX_LEN,
get_cute_tool_message,
)
class TestTrimError:
    """``_trim_error`` shortens an error message for inline display."""

    def test_short_message_unchanged(self):
        shortened = _trim_error("nope")
        assert shortened == "nope"

    def test_whitespace_stripped(self):
        shortened = _trim_error(" bad input ")
        assert shortened == "bad input"

    def test_long_message_truncated_to_cap(self):
        oversized = "x" * 200
        shortened = _trim_error(oversized)
        assert len(shortened) <= _ERROR_SUFFIX_MAX_LEN
        assert shortened.endswith("...")

    def test_file_not_found_path_collapsed_to_filename(self):
        message = (
            "File not found: "
            "/home/teknium/.hermes/hermes-agent/very/deep/path/foo.py"
        )
        assert _trim_error(message) == "File not found: foo.py"

    def test_file_not_found_already_short_unchanged(self):
        message = "File not found: foo.py"
        assert _trim_error(message) == "File not found: foo.py"

    def test_file_not_found_relative_path_unchanged(self):
        # No slash in the message means there is no directory part to drop.
        message = "File not found: foo.py"
        assert _trim_error(message) == "File not found: foo.py"
class TestDetectToolFailureTerminal:
    """terminal: a non-zero ``exit_code`` is the failure signal under test."""

    @staticmethod
    def _check(raw_result):
        """Run failure detection for the ``terminal`` tool."""
        return _detect_tool_failure("terminal", raw_result)

    def test_success_returns_no_suffix(self):
        payload = json.dumps({"output": "ok\n", "exit_code": 0})
        assert self._check(payload) == (False, "")

    def test_nonzero_exit_with_no_error_shows_exit_code(self):
        payload = json.dumps({"output": "", "exit_code": 1})
        failed, tail = self._check(payload)
        assert failed is True
        assert tail == " [exit 1]"

    def test_nonzero_exit_with_error_shows_message(self):
        payload = json.dumps(
            {
                "output": "",
                "exit_code": 127,
                "error": "ls: cannot access 'foo': No such file or directory",
            }
        )
        failed, tail = self._check(payload)
        assert failed is True
        assert "cannot access" in tail
        # The message is wrapped in brackets with a leading space.
        assert tail.startswith(" [")
        assert tail.endswith("]")

    def test_malformed_json_returns_no_suffix(self):
        # Unparseable output must neither crash nor be reported as a failure.
        assert self._check("not json") == (False, "")

    def test_none_result_returns_no_suffix(self):
        assert self._check(None) == (False, "")
class TestDetectToolFailureMemory:
    """memory: the "full" condition is reported differently from other errors."""

    def test_memory_full_returns_full_suffix(self):
        payload = json.dumps({"success": False, "error": "would exceed the limit"})
        outcome = _detect_tool_failure("memory", payload)
        assert outcome == (True, " [full]")

    def test_memory_other_error_returns_specific_message(self):
        # An error that is not the "full" overflow should surface its own
        # message text in the suffix.
        payload = json.dumps({"success": False, "error": "invalid action: zap"})
        failed, tail = _detect_tool_failure("memory", payload)
        assert failed is True
        assert "invalid action" in tail
class TestDetectToolFailureStructured:
    """Generic path: tools whose JSON result carries ``error``/``success``."""

    @staticmethod
    def _check(tool, payload):
        """Serialize ``payload`` to JSON and run failure detection on it."""
        return _detect_tool_failure(tool, json.dumps(payload))

    def test_read_file_error_surfaced(self):
        failed, tail = self._check(
            "read_file",
            {
                "path": "/nope/missing.py",
                "success": False,
                "error": "File not found: /nope/missing.py",
            },
        )
        assert failed is True
        # The path in the message is expected to be reduced to its basename.
        assert tail == " [File not found: missing.py]"

    def test_error_without_success_key_still_flagged(self):
        # An ``error`` key alone, with no ``success`` flag, still counts.
        failed, tail = self._check("web_search", {"error": "remote unavailable"})
        assert failed is True
        assert tail == " [remote unavailable]"

    def test_message_field_only_with_success_false_flagged(self):
        # ``success: False`` with only ``message`` set surfaces that message.
        failed, tail = self._check(
            "web_search", {"success": False, "message": "rate limited"}
        )
        assert failed is True
        assert "rate limited" in tail

    def test_successful_result_not_flagged(self):
        outcome = self._check("web_search", {"success": True, "data": "hello"})
        assert outcome == (False, "")

    def test_dict_without_error_or_success_uses_generic_heuristic(self):
        # A plain dict with neither ``error`` nor ``success`` must not be
        # reported as a failure.
        failed, _ = self._check("web_search", {"data": "hello"})
        assert failed is False
class TestGetCuteToolMessageFailureSuffix:
    """End-to-end: ``get_cute_tool_message`` appends the failure suffix."""

    @staticmethod
    def _after(line, marker):
        """Return the part of ``line`` that follows the first ``marker``."""
        return line.split(marker, 1)[1]

    def test_read_file_failure_suffix_appended(self):
        failure = json.dumps(
            {
                "path": "/etc/missing",
                "success": False,
                "error": "File not found: /etc/missing",
            }
        )
        rendered = get_cute_tool_message(
            "read_file", {"path": "/etc/missing"}, 0.1, result=failure
        )
        assert "[File not found: missing]" in rendered

    def test_terminal_exit_only_suffix(self):
        failure = json.dumps({"output": "", "exit_code": 2})
        rendered = get_cute_tool_message(
            "terminal", {"command": "false"}, 0.1, result=failure
        )
        assert "[exit 2]" in rendered

    def test_terminal_with_stderr_uses_message(self):
        failure = json.dumps(
            {
                "output": "",
                "exit_code": 127,
                "error": "command not found: notathing",
            }
        )
        rendered = get_cute_tool_message(
            "terminal", {"command": "notathing"}, 0.1, result=failure
        )
        assert "command not found" in rendered
        # With a specific message present, no exit-code tag is expected.
        assert "exit 127" not in rendered

    def test_memory_full_suffix(self):
        failure = json.dumps({"success": False, "error": "would exceed the limit"})
        memory_args = {"action": "add", "target": "memory", "content": "x"}
        rendered = get_cute_tool_message("memory", memory_args, 0.05, result=failure)
        assert "[full]" in rendered

    def test_success_has_no_suffix(self):
        success = json.dumps({"success": True, "data": "hi"})
        rendered = get_cute_tool_message(
            "web_search", {"query": "hi"}, 0.2, result=success
        )
        assert "[" not in self._after(rendered, "0.2s")

    def test_no_result_has_no_suffix(self):
        # Without any result the display function must not invent a
        # failure suffix.
        rendered = get_cute_tool_message("terminal", {"command": "ls"}, 0.2)
        assert "[" not in self._after(rendered, "0.2s")

View file

@ -56,6 +56,7 @@ class TestFailoverReason:
"overloaded", "server_error", "timeout",
"context_overflow", "payload_too_large", "image_too_large",
"model_not_found", "format_error",
"multimodal_tool_content_unsupported",
"provider_policy_blocked",
"thinking_signature", "long_context_tier",
"oauth_long_context_beta_forbidden",
@ -292,6 +293,64 @@ class TestClassifyApiError:
result = classify_api_error(e)
assert result.reason == FailoverReason.overloaded
# ── 5xx that are actually request-validation errors ──
# Some OpenAI-compatible gateways (e.g. codex.nekos.me) return
# request-validation failures with a 5xx status. These are
# deterministic, so they must NOT be retried — otherwise the retry
# loop hammers the identical bad request into a flood.
def test_502_with_unknown_parameter_is_non_retryable(self):
    """A 502 carrying an unknown-parameter validation error is a format error."""
    detail = (
        "[ObjectParam] [input[617]._empty_recovery_synthetic] "
        "[unknown_parameter] Unknown parameter: "
        "'input[617]._empty_recovery_synthetic'."
    )
    err = MockAPIError(
        "Unknown parameter: 'input[617]._empty_recovery_synthetic'",
        status_code=502,
        body={"error": {"type": "invalid_request_error", "message": detail}},
    )
    classified = classify_api_error(err)
    assert classified.reason == FailoverReason.format_error
    assert classified.retryable is False
    assert classified.should_fallback is True
def test_502_with_unsupported_parameter_is_non_retryable(self):
    """A 502 carrying an unsupported-parameter error is a non-retryable format error."""
    error_body = {
        "error": {
            "type": "invalid_request_error",
            "message": "Unsupported parameter: logprobs",
        }
    }
    err = MockAPIError(
        "Unsupported parameter: logprobs",
        status_code=502,
        body=error_body,
    )
    classified = classify_api_error(err)
    assert classified.reason == FailoverReason.format_error
    assert classified.retryable is False
def test_500_with_invalid_request_error_type_is_non_retryable(self):
    """A 500 whose body type is ``invalid_request_error`` is a format error."""
    error_body = {
        "error": {"type": "invalid_request_error", "message": "bad request"}
    }
    err = MockAPIError("bad request", status_code=500, body=error_body)
    classified = classify_api_error(err)
    assert classified.reason == FailoverReason.format_error
    assert classified.retryable is False
def test_502_plain_bad_gateway_still_retryable(self):
    """A plain 502 with no request-validation signal remains retryable."""
    err = MockAPIError("Bad Gateway", status_code=502)
    classified = classify_api_error(err)
    assert classified.reason == FailoverReason.server_error
    assert classified.retryable is True
# ── Model not found ──
def test_404_model_not_found(self):
@ -1256,3 +1315,66 @@ class TestRateLimitErrorWithoutStatusCode:
e.status_code = None
result = classify_api_error(e, provider="copilot", model="gpt-4o")
assert result.reason != FailoverReason.rate_limit
# ── Test: multimodal_tool_content_unsupported pattern ───────────────────
class TestMultimodalToolContentUnsupported:
    """Issue #27344: rejected list-type tool message content.

    Such provider errors should be classified as
    ``multimodal_tool_content_unsupported`` so the retry loop can downgrade
    screenshots to text and try again.
    """

    # Shorthand for the reason every positive case below expects.
    _EXPECTED = FailoverReason.multimodal_tool_content_unsupported

    def test_xiaomi_mimo_text_is_not_set_pattern(self):
        """The Xiaomi MiMo 400 wording taken from the bug report."""
        err = MockAPIError(
            "Error code: 400 - {'error': {'code': '400', 'message': 'Param Incorrect', 'param': 'text is not set', 'type': ''}}",
            status_code=400,
        )
        classified = classify_api_error(err, provider="xiaomi", model="mimo-v2.5")
        assert classified.reason == self._EXPECTED
        assert classified.retryable is True

    def test_generic_tool_message_must_be_string(self):
        err = MockAPIError("tool message content must be a string", status_code=400)
        classified = classify_api_error(err, provider="custom", model="some-model")
        assert classified.reason == self._EXPECTED

    def test_expected_string_got_list(self):
        err = MockAPIError(
            "Schema validation failed: expected string, got list",
            status_code=400,
        )
        classified = classify_api_error(err, provider="custom", model="some-model")
        assert classified.reason == self._EXPECTED

    def test_multimodal_tool_content_takes_priority_over_context_overflow(self):
        """A message matching both patterns must resolve to tool-content.

        The message contains 'text is not set' together with a
        length-shaped phrase; tool-content recovery is cheaper than
        compression, so it has to win.
        """
        err = MockAPIError(
            "text is not set; context length exceeded",
            status_code=400,
        )
        classified = classify_api_error(err, provider="xiaomi", model="mimo-v2.5")
        assert classified.reason == self._EXPECTED

    def test_no_status_code_path_also_classifies(self):
        """An error built without a status code must still be recognised."""
        err = MockTransportError("tool_call.content must be string")
        classified = classify_api_error(
            err, provider="alibaba", model="qwen3.5-plus"
        )
        assert classified.reason == self._EXPECTED

    def test_unrelated_400_is_not_misclassified(self):
        """An ordinary 400 must not trip the tool-content patterns."""
        err = MockAPIError("bad request: missing field 'model'", status_code=400)
        classified = classify_api_error(
            err, provider="openrouter", model="anthropic/claude-sonnet-4"
        )
        assert classified.reason != self._EXPECTED

View file

@ -0,0 +1,275 @@
"""Tests for HERMES_HOME credential-file read blocking in file_safety.
Regression for https://github.com/NousResearch/hermes-agent/issues/17656
``read_file`` was previously only sandboxed against ``HERMES_HOME`` itself,
which left ``auth.json`` and ``.anthropic_oauth.json`` (plaintext provider
keys + OAuth tokens) readable by the agent. A prompt-injection reaching
``read_file`` could exfiltrate active credentials.
These tests verify that ``get_read_block_error`` returns a denial message
for the credential stores while leaving arbitrary ``HERMES_HOME`` files
readable, and that the existing ``skills/.hub`` deny still applies.
"""
from __future__ import annotations
import os
from pathlib import Path
import pytest
@pytest.fixture()
def fake_home(tmp_path, monkeypatch):
    """Create ``<tmp_path>/hermes_home`` and make
    ``agent.file_safety._hermes_home_path`` return it.

    Returns the directory so tests can place files inside it.
    """
    import agent.file_safety as fs

    home_dir = tmp_path / "hermes_home"
    home_dir.mkdir()
    # Swap the resolver itself so the checks run against the temp dir.
    monkeypatch.setattr(fs, "_hermes_home_path", lambda: home_dir)
    return home_dir
def _create(home: Path, rel: str | Path) -> Path:
    """Write a placeholder file at ``home / rel`` and return its path.

    Missing parent directories are created first. The file has to exist on
    disk so that realpath() resolves it.
    """
    target = home / rel
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text("dummy", encoding="utf-8")
    return target
def test_auth_json_blocked(fake_home):
    """Top-level ``auth.json`` is denied, and the message names both the
    credential store and the file."""
    from agent.file_safety import get_read_block_error

    denial = get_read_block_error(str(_create(fake_home, "auth.json")))
    assert denial is not None
    assert "credential store" in denial
    assert "auth.json" in denial
def test_auth_lock_blocked(fake_home):
    """Top-level ``auth.lock`` is denied with a credential-store message."""
    from agent.file_safety import get_read_block_error

    denial = get_read_block_error(str(_create(fake_home, "auth.lock")))
    assert denial is not None
    assert "credential store" in denial
def test_anthropic_oauth_json_blocked(fake_home):
    """Top-level ``.anthropic_oauth.json`` is denied with a credential-store
    message."""
    from agent.file_safety import get_read_block_error

    denial = get_read_block_error(str(_create(fake_home, ".anthropic_oauth.json")))
    assert denial is not None
    assert "credential store" in denial
def test_arbitrary_hermes_home_file_not_blocked(fake_home):
    """A non-credential file inside HERMES_HOME yields no block error."""
    from agent.file_safety import get_read_block_error

    ordinary = _create(fake_home, "session_log.txt")
    verdict = get_read_block_error(str(ordinary))
    assert verdict is None
def test_subdirectory_named_auth_json_not_blocked(fake_home):
    """Only the top-level ``auth.json`` counts as the credential store. A
    same-named file deeper in the tree (e.g. a skill mock) stays readable."""
    from agent.file_safety import get_read_block_error

    relative = Path("skills") / "my-skill" / "auth.json"
    verdict = get_read_block_error(str(_create(fake_home, relative)))
    assert verdict is None
def test_skills_hub_block_still_applies(fake_home):
    """Regression guard: files under ``skills/.hub`` are still denied as
    internal Hermes cache files."""
    from agent.file_safety import get_read_block_error

    manifest = _create(fake_home, "skills/.hub/manifest.json")
    denial = get_read_block_error(str(manifest))
    assert denial is not None
    assert "internal Hermes cache file" in denial
def test_path_traversal_resolves_to_blocked(fake_home, tmp_path):
    """A path that climbs out of a sibling directory and back into
    HERMES_HOME's ``auth.json`` must still be denied; the docstring of the
    original test attributes this to realpath resolution."""
    from agent.file_safety import get_read_block_error

    _create(fake_home, "auth.json")
    detour_dir = tmp_path / "elsewhere"
    detour_dir.mkdir()
    # <tmp>/elsewhere/../hermes_home/auth.json points at the real file.
    indirect = detour_dir / ".." / "hermes_home" / "auth.json"
    denial = get_read_block_error(str(indirect))
    assert denial is not None
    assert "credential store" in denial
def test_symlink_to_auth_json_blocked(fake_home, tmp_path):
    """A symlink outside HERMES_HOME that points at ``auth.json`` must be
    denied. The test is skipped where symlinks cannot be created."""
    from agent.file_safety import get_read_block_error

    real_file = _create(fake_home, "auth.json")
    shim = tmp_path / "shim.json"
    try:
        os.symlink(real_file, shim)
    except (OSError, NotImplementedError):
        pytest.skip("symlinks not supported on this platform/filesystem")

    denial = get_read_block_error(str(shim))
    assert denial is not None
    assert "credential store" in denial
def test_read_file_tool_blocks_relative_path_under_terminal_cwd(
    fake_home, tmp_path, monkeypatch
):
    """Bypass guard for relative paths.

    ``read_file_tool("auth.json")`` is invoked with ``TERMINAL_CWD`` set to
    HERMES_HOME while the Python process cwd is a different directory. The
    JSON result must carry a credential-store error.
    """
    import json
    import tools.file_tools as ft

    _create(fake_home, "auth.json")

    # Anchor file_tools' relative-path resolution at HERMES_HOME, keep the
    # process cwd at tmp_path, and disable live cwd tracking.
    monkeypatch.setenv("TERMINAL_CWD", str(fake_home))
    monkeypatch.chdir(tmp_path)
    monkeypatch.setattr(
        ft, "_get_live_tracking_cwd", lambda task_id="default": None
    )

    payload = json.loads(ft.read_file_tool("auth.json"))
    assert "error" in payload
    assert "credential store" in payload["error"]
# ---------------------------------------------------------------------------
# Widening: .env, webhook_subscriptions.json, mcp-tokens/
# ---------------------------------------------------------------------------
def test_dotenv_blocked(fake_home):
    """``.env`` inside HERMES_HOME (holds API keys) is denied."""
    from agent.file_safety import get_read_block_error

    denial = get_read_block_error(str(_create(fake_home, ".env")))
    assert denial is not None
    assert "credential store" in denial
def test_webhook_subscriptions_blocked(fake_home):
    """``webhook_subscriptions.json`` (per-route HMAC secrets) is denied."""
    from agent.file_safety import get_read_block_error

    denial = get_read_block_error(
        str(_create(fake_home, "webhook_subscriptions.json"))
    )
    assert denial is not None
    assert "credential store" in denial
def test_mcp_tokens_file_blocked(fake_home):
    """A file directly under ``mcp-tokens/`` (OAuth tokens) is denied with
    an MCP-token message."""
    from agent.file_safety import get_read_block_error

    token_file = _create(fake_home, Path("mcp-tokens") / "github.json")
    denial = get_read_block_error(str(token_file))
    assert denial is not None
    assert "MCP token" in denial
def test_mcp_tokens_nested_blocked(fake_home):
    """A file nested deeper inside ``mcp-tokens/`` is denied as well."""
    from agent.file_safety import get_read_block_error

    nested_rel = Path("mcp-tokens") / "providers" / "azure.json"
    denial = get_read_block_error(str(_create(fake_home, nested_rel)))
    assert denial is not None
    assert "MCP token" in denial
def test_mcp_tokens_dir_itself_blocked(fake_home):
    """The ``mcp-tokens`` directory itself is denied, because listing it
    already exfiltrates information."""
    from agent.file_safety import get_read_block_error

    token_root = fake_home / "mcp-tokens"
    token_root.mkdir(parents=True, exist_ok=True)
    denial = get_read_block_error(str(token_root))
    assert denial is not None
    assert "MCP token" in denial
def test_identically_named_files_outside_hermes_home_not_blocked(
    fake_home, tmp_path
):
    """``.env``, ``auth.json`` and ``mcp-tokens/`` in a project directory
    outside HERMES_HOME stay readable: the gate keys on location, not on
    file name."""
    from agent.file_safety import get_read_block_error

    project_dir = tmp_path / "myproject"
    project_dir.mkdir()

    for rel in (".env", "auth.json"):
        lookalike = project_dir / rel
        lookalike.write_text("not secret here", encoding="utf-8")
        assert get_read_block_error(str(lookalike)) is None, (
            f"{rel} outside HERMES_HOME should NOT be blocked"
        )

    token_dir = project_dir / "mcp-tokens"
    token_dir.mkdir()
    lookalike_token = token_dir / "token.json"
    lookalike_token.write_text("not really a token", encoding="utf-8")
    assert get_read_block_error(str(lookalike_token)) is None
def test_config_yaml_not_blocked(fake_home):
    """``config.yaml`` is not treated as a credential file, so reads stay
    allowed for debugging. (The original note says writes are denied
    separately by ``is_write_denied``.)"""
    from agent.file_safety import get_read_block_error

    config_file = _create(fake_home, "config.yaml")
    verdict = get_read_block_error(str(config_file))
    assert verdict is None
def test_profile_mode_blocks_root_credentials(tmp_path, monkeypatch):
    """With HERMES_HOME at ``<root>/profiles/<name>``, credential files at
    ``<root>`` must be denied too, since root credentials are inherited by
    every profile."""
    import agent.file_safety as fs

    root = tmp_path / "hermes"
    profile = root / "profiles" / "coder"
    profile.mkdir(parents=True)
    monkeypatch.setattr(fs, "_hermes_home_path", lambda: profile)
    monkeypatch.setattr(fs, "_hermes_root_path", lambda: root)

    from agent.file_safety import get_read_block_error

    # (path, text the denial must contain), checked in this order:
    # profile-local store, then the root-level widening cases.
    expectations = [
        (profile / "auth.json", "credential store"),
        (root / "auth.json", "credential store"),
        (root / ".env", "credential store"),
        (root / "mcp-tokens" / "gh.json", "MCP token"),
    ]
    for secret_path, needle in expectations:
        # Only mcp-tokens/ is actually missing; for the rest this is a no-op.
        secret_path.parent.mkdir(parents=True, exist_ok=True)
        secret_path.write_text("x")
        assert needle in (get_read_block_error(str(secret_path)) or "")

View file

@ -0,0 +1,219 @@
"""Tests for the cross-Hermes-profile write guard in agent/file_safety.
The guard fires when a tool tries to write into another Hermes profile's
skills/plugins/cron/memories directory. It's a soft guard — defense in
depth, NOT a security boundary — but it prevents the agent from silently
corrupting a profile that belongs to a different session.
Reference: May 2026 incident — a hermes-security profile session
accidentally edited skills under both ~/.hermes/profiles/hermes-security/skills/
AND ~/.hermes/skills/ (the default profile's skills), realizing only
afterwards that the second path belonged to a different profile.
"""
from __future__ import annotations
import os
from pathlib import Path
import pytest
# ---------------------------------------------------------------------------
# Helpers — set up a fake Hermes root with two profiles, monkeypatch the
# resolver helpers so the classifier sees the test layout.
# ---------------------------------------------------------------------------
@pytest.fixture
def fake_hermes(tmp_path, monkeypatch):
    """Lay out a fake Hermes root with a default and two named profiles.

        <tmp>/fake-hermes/
            skills/foo/SKILL.md          # default profile
            plugins/foo/
            memories/
            cron/
            profiles/
                hermes-security/
                    skills/foo/SKILL.md  # named profile
                    plugins/
                coder/
                    skills/foo/SKILL.md  # another named profile

    Returns a dict with the keys ``root``, ``default_home``,
    ``security_home`` and ``coder_home``.
    """

    def _seed_skill(home, heading):
        # Each profile gets one skill file with a distinguishing heading.
        skill_dir = home / "skills" / "foo"
        skill_dir.mkdir(parents=True)
        (skill_dir / "SKILL.md").write_text(heading)

    root = tmp_path / "fake-hermes"
    _seed_skill(root, "# default skill\n")
    for default_area in (("plugins", "foo"), ("memories",), ("cron",)):
        root.joinpath(*default_area).mkdir(parents=True)

    sec_home = root / "profiles" / "hermes-security"
    _seed_skill(sec_home, "# sec skill\n")
    (sec_home / "plugins").mkdir(parents=True)

    coder_home = root / "profiles" / "coder"
    _seed_skill(coder_home, "# coder skill\n")

    # Point both root resolvers at the fake layout. Each test selects the
    # "active" profile separately.
    import hermes_constants
    monkeypatch.setattr(hermes_constants, "get_default_hermes_root", lambda: root)
    import agent.file_safety as fs
    monkeypatch.setattr(fs, "_hermes_root_path", lambda: root)

    return {
        "root": root,
        "default_home": root,
        "security_home": sec_home,
        "coder_home": coder_home,
    }
def _set_active_home(monkeypatch, hermes_home: Path):
    """Make ``agent.file_safety._hermes_home_path`` return ``hermes_home``,
    i.e. select which profile directory counts as active."""
    import agent.file_safety as file_safety

    monkeypatch.setattr(file_safety, "_hermes_home_path", lambda: hermes_home)
# ---------------------------------------------------------------------------
# _resolve_active_profile_name
# ---------------------------------------------------------------------------
class TestResolveActiveProfileName:
    """Cases for ``agent.file_safety._resolve_active_profile_name``."""

    def test_default_when_home_is_root(self, fake_hermes, monkeypatch):
        """Home equal to the Hermes root resolves to ``"default"``."""
        _set_active_home(monkeypatch, fake_hermes["default_home"])
        from agent.file_safety import _resolve_active_profile_name

        resolved = _resolve_active_profile_name()
        assert resolved == "default"

    def test_named_profile(self, fake_hermes, monkeypatch):
        """Home under ``profiles/hermes-security`` resolves to that name."""
        _set_active_home(monkeypatch, fake_hermes["security_home"])
        from agent.file_safety import _resolve_active_profile_name

        resolved = _resolve_active_profile_name()
        assert resolved == "hermes-security"

    def test_falls_back_to_default_on_resolution_failure(self, fake_hermes, monkeypatch):
        """When resolving HERMES_HOME raises, the result is ``"default"``
        instead of an exception reaching the tool."""
        import agent.file_safety as fs

        def _always_fails():
            raise RuntimeError("simulated")

        monkeypatch.setattr(fs, "_hermes_home_path", _always_fails)
        # Must return normally despite the failing resolver.
        assert fs._resolve_active_profile_name() == "default"
# ---------------------------------------------------------------------------
# classify_cross_profile_target
# ---------------------------------------------------------------------------
class TestClassifyCrossProfileTarget:
    """Cases for ``agent.file_safety.classify_cross_profile_target``."""

    def test_same_profile_write_returns_none(self, fake_hermes, monkeypatch):
        """A target inside the active profile is not classified."""
        _set_active_home(monkeypatch, fake_hermes["security_home"])
        from agent.file_safety import classify_cross_profile_target

        own_skill = fake_hermes["security_home"] / "skills" / "foo" / "SKILL.md"
        assert classify_cross_profile_target(str(own_skill)) is None

    def test_security_writing_default_skill(self, fake_hermes, monkeypatch):
        """The May 2026 incident: hermes-security touching a default skill."""
        _set_active_home(monkeypatch, fake_hermes["security_home"])
        from agent.file_safety import classify_cross_profile_target

        foreign_skill = fake_hermes["default_home"] / "skills" / "foo" / "SKILL.md"
        verdict = classify_cross_profile_target(str(foreign_skill))
        assert verdict is not None
        assert verdict["active_profile"] == "hermes-security"
        assert verdict["target_profile"] == "default"
        assert verdict["area"] == "skills"

    def test_default_writing_security_skill(self, fake_hermes, monkeypatch):
        """Opposite direction: the default profile reaching into a named one."""
        _set_active_home(monkeypatch, fake_hermes["default_home"])
        from agent.file_safety import classify_cross_profile_target

        foreign_skill = fake_hermes["security_home"] / "skills" / "foo" / "SKILL.md"
        verdict = classify_cross_profile_target(str(foreign_skill))
        assert verdict is not None
        assert verdict["active_profile"] == "default"
        assert verdict["target_profile"] == "hermes-security"

    def test_named_to_named_cross_profile(self, fake_hermes, monkeypatch):
        """One named profile targeting another named profile."""
        _set_active_home(monkeypatch, fake_hermes["security_home"])
        from agent.file_safety import classify_cross_profile_target

        foreign_skill = fake_hermes["coder_home"] / "skills" / "foo" / "SKILL.md"
        verdict = classify_cross_profile_target(str(foreign_skill))
        assert verdict is not None
        assert verdict["target_profile"] == "coder"

    @pytest.mark.parametrize("area", ["skills", "plugins", "cron", "memories"])
    def test_all_profile_scoped_areas_classified(self, fake_hermes, monkeypatch, area):
        """Each profile-scoped area is reported under its own name."""
        _set_active_home(monkeypatch, fake_hermes["security_home"])
        from agent.file_safety import classify_cross_profile_target

        foreign_file = fake_hermes["default_home"] / area / "foo.txt"
        verdict = classify_cross_profile_target(str(foreign_file))
        assert verdict is not None
        assert verdict["area"] == area

    def test_non_hermes_path_returns_none(self, fake_hermes, monkeypatch, tmp_path):
        """A path outside the Hermes root is not classified."""
        _set_active_home(monkeypatch, fake_hermes["security_home"])
        from agent.file_safety import classify_cross_profile_target

        unrelated = tmp_path / "random.txt"
        assert classify_cross_profile_target(str(unrelated)) is None

    def test_hermes_config_not_classified_as_cross_profile(self, fake_hermes, monkeypatch):
        """``<root>/config.yaml`` is not profile-scoped, so it must not be
        classified. (The original note says such files are already covered
        by ``build_write_denied_paths``, so this avoids a double warning.)"""
        _set_active_home(monkeypatch, fake_hermes["security_home"])
        from agent.file_safety import classify_cross_profile_target

        root_config = fake_hermes["default_home"] / "config.yaml"
        verdict = classify_cross_profile_target(str(root_config))
        assert verdict is None
# ---------------------------------------------------------------------------
# get_cross_profile_warning
# ---------------------------------------------------------------------------
class TestGetCrossProfileWarning:
    """Cases for ``agent.file_safety.get_cross_profile_warning``."""

    def test_in_profile_returns_none(self, fake_hermes, monkeypatch):
        """No warning for a target inside the active profile."""
        _set_active_home(monkeypatch, fake_hermes["security_home"])
        from agent.file_safety import get_cross_profile_warning

        own_skill = fake_hermes["security_home"] / "skills" / "foo" / "SKILL.md"
        assert get_cross_profile_warning(str(own_skill)) is None

    def test_cross_profile_warning_names_both_profiles(self, fake_hermes, monkeypatch):
        """The warning names both profiles, the bypass kwarg and the area."""
        _set_active_home(monkeypatch, fake_hermes["security_home"])
        from agent.file_safety import get_cross_profile_warning

        foreign_skill = fake_hermes["default_home"] / "skills" / "foo" / "SKILL.md"
        message = get_cross_profile_warning(str(foreign_skill))
        assert message is not None
        # Both profile names, so the model can tell them apart.
        assert "default" in message
        assert "hermes-security" in message
        # The bypass kwarg.
        assert "cross_profile=True" in message
        # The affected area.
        assert "skills" in message

    def test_warning_is_defense_in_depth_not_boundary(self, fake_hermes, monkeypatch):
        """The warning text states that it is not a security boundary, so
        later reviewers do not promote it to a hard block."""
        _set_active_home(monkeypatch, fake_hermes["security_home"])
        from agent.file_safety import get_cross_profile_warning

        foreign_skill = fake_hermes["default_home"] / "skills" / "foo" / "SKILL.md"
        message = get_cross_profile_warning(str(foreign_skill))
        assert "not a security boundary" in message.lower()

View file

@ -0,0 +1,22 @@
"""Test that last_total_tokens is correctly set by ContextCompressor."""
from agent.context_compressor import ContextCompressor
def test_update_from_response_sets_total_tokens():
    """``last_total_tokens`` must be populated from the API usage payload.

    It is checked once with an explicit ``total_tokens`` and once with only
    prompt/completion counts; 130 is expected both times.
    """
    compressor = ContextCompressor(
        model="test", quiet_mode=True, config_context_length=200000
    )

    usage_with_total = {"prompt_tokens": 100, "completion_tokens": 30, "total_tokens": 130}
    compressor.update_from_response(usage_with_total)
    assert compressor.last_total_tokens == 130

    usage_without_total = {"prompt_tokens": 100, "completion_tokens": 30}
    compressor.update_from_response(usage_without_total)
    assert compressor.last_total_tokens == 130
def test_session_reset_clears_total_tokens():
    """``on_session_reset`` must bring ``last_total_tokens`` back to zero."""
    compressor = ContextCompressor(
        model="test", quiet_mode=True, config_context_length=200000
    )
    usage = {"prompt_tokens": 100, "completion_tokens": 30, "total_tokens": 130}
    compressor.update_from_response(usage)
    compressor.on_session_reset()
    assert compressor.last_total_tokens == 0

View file

@ -1060,3 +1060,191 @@ class TestHonchoCadenceTracking:
p.on_turn_start(2, "second message")
should_skip = p._injection_frequency == "first-turn" and p._turn_count > 1
assert should_skip, "Second turn (turn 2) SHOULD be skipped"
class TestMemoryToolToolsetGate:
    """Issue #5544: memory provider tools must respect platform_toolsets.

    Before the fix, the output of MemoryManager.get_all_tool_schemas() was
    appended to AIAgent.tools unconditionally in agent_init.py, bypassing
    the enabled_toolsets filter. As a result `platform_toolsets: telegram: []`
    still leaked fact_store and other memory tools into the tool surface,
    causing 10x latency on local models (Qwen3-30B: 1.7s -> 42s) and
    tool-call loops on small models.

    These tests mirror the gate around the memory provider tool injection
    block in agent/agent_init.py:

        enabled_toolsets is None        -> no filter, inject (backward compat)
        "memory" in enabled_toolsets    -> user opted in, inject
        otherwise (including [])        -> skip injection
    """

    @staticmethod
    def _run_memory_injection(enabled_toolsets, memory_manager):
        """Replay the gated memory-tool injection block from agent_init.py.

        Returns ``(tools, valid_tool_names)`` as they stand after the gate.
        """
        tools = []
        valid_tool_names = set()

        # Same short-circuit order as the mirrored condition: manager first,
        # then the tools list, then the toolset gate.
        if not memory_manager or tools is None:
            return tools, valid_tool_names
        if enabled_toolsets is not None and "memory" not in enabled_toolsets:
            return tools, valid_tool_names

        seen = set()
        for entry in tools:
            if isinstance(entry, dict):
                seen.add(entry.get("function", {}).get("name"))

        for schema in memory_manager.get_all_tool_schemas():
            tool_name = schema.get("name", "")
            duplicate = bool(tool_name) and tool_name in seen
            if duplicate:
                continue
            tools.append({"type": "function", "function": schema})
            if tool_name:
                valid_tool_names.add(tool_name)
                seen.add(tool_name)
        return tools, valid_tool_names

    def _mgr_with_tools(self, *tool_names):
        """Return a MemoryManager with one provider exposing the named tools."""
        schemas = [
            {"name": n, "description": n, "parameters": {}} for n in tool_names
        ]
        manager = MemoryManager()
        manager.add_provider(FakeMemoryProvider("ext", tools=schemas))
        return manager

    def test_none_toolsets_injects(self):
        """enabled_toolsets=None (no filter) injects memory tools."""
        manager = self._mgr_with_tools("fact_store")
        tools, names = self._run_memory_injection(None, manager)
        assert "fact_store" in names
        assert any(t["function"]["name"] == "fact_store" for t in tools)

    def test_memory_in_toolsets_injects(self):
        """A toolset list containing 'memory' injects memory tools."""
        manager = self._mgr_with_tools("fact_store")
        tools, names = self._run_memory_injection(
            ["terminal", "memory", "web"], manager
        )
        assert "fact_store" in names

    def test_empty_toolsets_blocks_injection(self):
        """`platform_toolsets: telegram: []` must suppress memory tools. (#5544)"""
        manager = self._mgr_with_tools("fact_store")
        tools, names = self._run_memory_injection([], manager)
        assert tools == []
        assert names == set()

    def test_toolsets_without_memory_blocks_injection(self):
        """A toolset list that omits 'memory' must suppress injection."""
        manager = self._mgr_with_tools("fact_store")
        tools, names = self._run_memory_injection(["terminal", "web"], manager)
        assert tools == []
        assert names == set()

    def test_no_memory_manager_no_injection(self):
        """Without a memory manager nothing is injected."""
        tools, names = self._run_memory_injection(None, None)
        assert tools == []

    def test_multiple_schemas_all_blocked_together(self):
        """With the gate closed, no memory tool leaks, not even partially."""
        manager = self._mgr_with_tools("fact_store", "memory_search", "memory_add")
        tools, names = self._run_memory_injection(["terminal"], manager)
        assert tools == []
        assert names == set()

    def test_multiple_schemas_all_injected_when_enabled(self):
        """With the gate open, every memory tool schema is injected."""
        manager = self._mgr_with_tools("fact_store", "memory_search", "memory_add")
        tools, names = self._run_memory_injection(None, manager)
        assert names == {"fact_store", "memory_search", "memory_add"}
class TestContextEngineToolsetGate:
    """Issue #5544 (sibling): context engine tools follow the same gate.

    `agent.context_compressor.get_tool_schemas()` (e.g. lcm_grep,
    lcm_describe, lcm_expand) was appended to AIAgent.tools unconditionally:
    the same blind-injection class as the memory bug, with the same
    local-model penalty. The gate name is "context_engine" (matching the
    existing plugin-system convention).

    Every case checks all three values returned by the simulated block
    (tools, valid_tool_names, engine_tool_names), so a leak through any one
    of them is caught.
    """

    @staticmethod
    def _run_context_engine_injection(enabled_toolsets, compressor):
        """Simulate the gated context-engine injection block from agent_init.py.

        Args:
            enabled_toolsets: ``None`` for "no filter", otherwise a container
                of enabled toolset names.
            compressor: object exposing ``get_tool_schemas()``, or ``None``.

        Returns:
            ``(tools, valid_tool_names, engine_tool_names)`` after the gate.
        """
        tools = []
        valid_tool_names = set()
        engine_tool_names = set()
        if (
            compressor is not None
            and tools is not None
            and (
                enabled_toolsets is None
                or "context_engine" in enabled_toolsets
            )
        ):
            # Names already present in the tool list must not be re-added.
            _existing = {
                t.get("function", {}).get("name")
                for t in tools
                if isinstance(t, dict)
            }
            for _schema in compressor.get_tool_schemas():
                _tname = _schema.get("name", "")
                if _tname and _tname in _existing:
                    continue
                tools.append({"type": "function", "function": _schema})
                if _tname:
                    valid_tool_names.add(_tname)
                    engine_tool_names.add(_tname)
                    _existing.add(_tname)
        return tools, valid_tool_names, engine_tool_names

    class _FakeCompressor:
        """Minimal stand-in exposing only ``get_tool_schemas()``."""

        def __init__(self, schemas):
            self._schemas = schemas

        def get_tool_schemas(self):
            # Hand out a copy so callers cannot mutate the fixture data.
            return list(self._schemas)

    def _compressor_with(self, *tool_names):
        """Return a fake compressor exposing one schema per given name."""
        return self._FakeCompressor(
            [{"name": n, "description": n, "parameters": {}} for n in tool_names]
        )

    def test_none_toolsets_injects(self):
        """enabled_toolsets=None injects context-engine tools — backward compat."""
        c = self._compressor_with("lcm_grep", "lcm_describe", "lcm_expand")
        tools, names, engine_names = self._run_context_engine_injection(None, c)
        expected = {"lcm_grep", "lcm_describe", "lcm_expand"}
        assert engine_names == expected
        assert names == expected
        assert {t["function"]["name"] for t in tools} == expected

    def test_context_engine_in_toolsets_injects(self):
        """enabled_toolsets including 'context_engine' injects the tools."""
        c = self._compressor_with("lcm_grep")
        tools, names, engine_names = self._run_context_engine_injection(
            ["terminal", "context_engine"], c
        )
        assert "lcm_grep" in engine_names
        assert "lcm_grep" in names
        assert [t["function"]["name"] for t in tools] == ["lcm_grep"]

    def test_empty_toolsets_blocks_injection(self):
        """`platform_toolsets: telegram: []` must suppress context-engine tools."""
        c = self._compressor_with("lcm_grep")
        tools, names, engine_names = self._run_context_engine_injection([], c)
        assert tools == []
        assert names == set()
        assert engine_names == set()

    def test_toolsets_without_context_engine_blocks_injection(self):
        """A toolset list that doesn't name 'context_engine' suppresses injection."""
        c = self._compressor_with("lcm_grep", "lcm_describe")
        tools, names, engine_names = self._run_context_engine_injection(
            ["terminal", "memory"], c
        )
        assert tools == []
        assert names == set()
        assert engine_names == set()

    def test_no_compressor_no_injection(self):
        """Gate is moot without a context_compressor."""
        tools, names, engine_names = self._run_context_engine_injection(None, None)
        assert tools == []
        assert names == set()
        assert engine_names == set()

View file

@ -164,6 +164,7 @@ class TestDefaultContextLengths:
"grok-4-1-fast": 2000000,
"grok-4-fast": 2000000,
"grok-4": 256000,
"grok-build": 256000,
"grok-code-fast": 256000,
"grok-3": 131072,
"grok-2": 131072,
@ -195,6 +196,7 @@ class TestDefaultContextLengths:
("grok-4-fast-non-reasoning", 2000000),
("grok-4", 256000),
("grok-4-0709", 256000),
("grok-build-0.1", 256000),
("grok-code-fast-1", 256000),
("grok-3", 131072),
("grok-3-mini", 131072),
@ -210,6 +212,32 @@ class TestDefaultContextLengths:
f"{model_id}: expected {expected_ctx}, got {actual}"
)
def test_xai_oauth_grok_build_uses_xai_models_dev_context(self):
    """xAI OAuth should share the xAI provider metadata path.

    The xAI /v1/models endpoint does not currently report context fields
    for grok-build-0.1. This guards against falling through to the generic
    "grok" 131k fallback when OAuth credentials are in use.
    """
    grok_build_entry = {"limit": {"context": 256000, "output": 64000}}
    registry = {"xai": {"models": {"grok-build-0.1": grok_build_entry}}}

    # Cache and Ollama probes return nothing; the registry is the only source.
    with patch("agent.model_metadata.get_cached_context_length", return_value=None):
        with patch("agent.model_metadata._query_ollama_api_show", return_value=None):
            with patch("agent.models_dev.fetch_models_dev", return_value=registry):
                resolved = get_model_context_length(
                    "grok-build-0.1",
                    provider="xai-oauth",
                    base_url="https://api.x.ai/v1",
                    api_key="oauth-token",
                )
                assert resolved == 256000
def test_deepseek_v4_models_1m_context(self):
from agent.model_metadata import get_model_context_length
from unittest.mock import patch as mock_patch

View file

@ -41,6 +41,16 @@ SAMPLE_REGISTRY = {
},
},
},
"xai": {
"id": "xai",
"name": "xAI",
"models": {
"grok-build-0.1": {
"id": "grok-build-0.1",
"limit": {"context": 256000, "output": 64000},
},
},
},
"kilo": {
"id": "kilo",
"name": "Kilo Gateway",
@ -86,6 +96,10 @@ class TestProviderMapping:
assert PROVIDER_TO_MODELS_DEV["kilocode"] == "kilo"
assert PROVIDER_TO_MODELS_DEV["ai-gateway"] == "vercel"
def test_xai_oauth_uses_xai_catalog(self):
    """Both ``xai`` and ``xai-oauth`` map to the ``xai`` catalog."""
    for provider in ("xai", "xai-oauth"):
        assert PROVIDER_TO_MODELS_DEV[provider] == "xai"
def test_unmapped_provider_not_in_dict(self):
assert "nous" not in PROVIDER_TO_MODELS_DEV
@ -144,6 +158,12 @@ class TestLookupModelsDevContext:
# GitHub Copilot: only 128K for same model
assert lookup_models_dev_context("copilot", "claude-opus-4.6") == 128000
@patch("agent.models_dev.fetch_models_dev")
def test_xai_oauth_resolves_xai_context(self, mock_fetch):
    """xAI OAuth is an auth path, not a separate catalog: the lookup must
    return the context length registered under ``xai``."""
    mock_fetch.return_value = SAMPLE_REGISTRY
    context_length = lookup_models_dev_context("xai-oauth", "grok-build-0.1")
    assert context_length == 256000
@patch("agent.models_dev.fetch_models_dev")
def test_zero_context_filtered(self, mock_fetch):
mock_fetch.return_value = SAMPLE_REGISTRY

View file

@ -451,6 +451,28 @@ class TestUrlQueryParamRedaction:
result = redact_sensitive_text(text)
assert "opaqueWsToken123" not in result
def test_http_access_log_relative_request_target_query(self):
    """Access-log line with a relative request target: the ``password``
    value is masked and the ``event`` parameter is left intact."""
    log_line = (
        'INFO aiohttp.access: 127.0.0.1 "POST '
        '/bluebubbles-webhook?password=webhookSecret123&event=new-message '
        'HTTP/1.1" 200 173 "-" "test-client"'
    )
    redacted = redact_sensitive_text(log_line)
    assert "webhookSecret123" not in redacted
    assert "password=***" in redacted
    assert "event=new-message" in redacted
def test_http_access_log_absolute_request_target_query(self):
    """Access-log line with an absolute-URL request target: the ``code``
    value is masked and the ``state`` parameter is left intact."""
    log_line = (
        'INFO aiohttp.access: 127.0.0.1 "GET '
        'https://example.com/callback?code=oauthCode123&state=csrf-ok '
        'HTTP/1.1" 200 173 "-" "test-client"'
    )
    redacted = redact_sensitive_text(log_line)
    assert "oauthCode123" not in redacted
    assert "code=***" in redacted
    assert "state=csrf-ok" in redacted
class TestUrlUserinfoRedaction:
"""URL userinfo (`scheme://user:pass@host`) for non-DB schemes."""

View file

@ -1,6 +1,12 @@
"""Tests for agent/skill_utils.py."""
from agent.skill_utils import extract_skill_conditions, iter_skill_index_files
from unittest.mock import patch
from agent.skill_utils import (
extract_skill_conditions,
iter_skill_index_files,
skill_matches_platform,
)
def test_metadata_as_dict_with_hermes():
@ -94,3 +100,100 @@ def test_iter_skill_index_files_prunes_dependency_dirs(tmp_path):
found = list(iter_skill_index_files(tmp_path, "SKILL.md"))
assert found == [real / "SKILL.md"]
# ── skill_matches_platform on Termux ──────────────────────────────────────
class TestSkillMatchesPlatformTermux:
"""Termux is Linux userland on Android. Skills tagged platforms:[linux]
must load there regardless of whether Python reports sys.platform as
"linux" (pre-3.13) or "android" (3.13+). Reported by user @LikiusInik
in May 2026 only 3 built-in skills appeared on Termux because every
github/productivity/mlops skill is tagged platforms:[linux,macos,windows]
and sys.platform=="android" did not start with "linux".
"""
def test_no_platforms_field_matches_everywhere(self):
    """Backward-compat default: frontmatter without a ``platforms`` tag
    matches on any OS, Termux included."""
    with patch("agent.skill_utils.sys.platform", "android"):
        with patch("agent.skill_utils.is_termux", return_value=True):
            assert skill_matches_platform({}) is True
            assert skill_matches_platform({"name": "foo"}) is True
def test_linux_skill_loads_on_termux_android_platform(self):
    """Python 3.13+ on Termux reports ``sys.platform == "android"``. A
    ``platforms: [linux]`` skill must still match there."""
    frontmatter = {"platforms": ["linux"]}
    with patch("agent.skill_utils.sys.platform", "android"):
        with patch("agent.skill_utils.is_termux", return_value=True):
            assert skill_matches_platform(frontmatter) is True
def test_linux_macos_windows_skill_loads_on_termux(self):
    """The common ``[linux, macos, windows]`` tag (used by github-*,
    productivity, mlops, etc.) matches on Termux."""
    frontmatter = {"platforms": ["linux", "macos", "windows"]}
    with patch("agent.skill_utils.sys.platform", "android"):
        with patch("agent.skill_utils.is_termux", return_value=True):
            assert skill_matches_platform(frontmatter) is True
def test_linux_skill_loads_on_termux_linux_platform(self):
# Pre-3.13 Termux reports sys.platform == "linux" already — this
# works without the Termux escape hatch but must still pass.
fm = {"platforms": ["linux"]}
with patch("agent.skill_utils.sys.platform", "linux"), patch(
"agent.skill_utils.is_termux", return_value=True
):
assert skill_matches_platform(fm) is True
def test_macos_only_skill_still_excluded_on_termux(self):
# macOS-only skills (apple-notes, imessage, ...) should NOT load
# on Termux. The Termux fallback only widens platforms:[linux,...].
fm = {"platforms": ["macos"]}
with patch("agent.skill_utils.sys.platform", "android"), patch(
"agent.skill_utils.is_termux", return_value=True
):
assert skill_matches_platform(fm) is False
def test_windows_only_skill_still_excluded_on_termux(self):
fm = {"platforms": ["windows"]}
with patch("agent.skill_utils.sys.platform", "android"), patch(
"agent.skill_utils.is_termux", return_value=True
):
assert skill_matches_platform(fm) is False
def test_explicit_termux_or_android_tag_matches(self):
# Skills can also opt in explicitly via platforms:[termux] or
# platforms:[android] — both should match a Termux session.
with patch("agent.skill_utils.sys.platform", "android"), patch(
"agent.skill_utils.is_termux", return_value=True
):
assert skill_matches_platform({"platforms": ["termux"]}) is True
assert skill_matches_platform({"platforms": ["android"]}) is True
def test_non_termux_android_does_not_widen(self):
# If we're somehow on a plain Android Python (not Termux), don't
# silently load Linux skills — Termux is the supported environment.
fm = {"platforms": ["linux"]}
with patch("agent.skill_utils.sys.platform", "android"), patch(
"agent.skill_utils.is_termux", return_value=False
):
assert skill_matches_platform(fm) is False
def test_linux_skill_on_real_linux_unaffected(self):
# The non-Termux Linux path must not change.
fm = {"platforms": ["linux"]}
with patch("agent.skill_utils.sys.platform", "linux"), patch(
"agent.skill_utils.is_termux", return_value=False
):
assert skill_matches_platform(fm) is True
def test_macos_skill_on_real_macos_unaffected(self):
fm = {"platforms": ["macos"]}
with patch("agent.skill_utils.sys.platform", "darwin"), patch(
"agent.skill_utils.is_termux", return_value=False
):
assert skill_matches_platform(fm) is True

View file

@ -0,0 +1,297 @@
"""Regression tests for issue #31179.

Before the fix:

  - ``auxiliary.vision.provider: openai`` silently failed to resolve because
    ``openai`` is not a first-class provider in PROVIDER_REGISTRY (only
    ``openai-codex`` for OAuth and ``custom`` for OPENAI_BASE_URL).
  - The vision branch of ``call_llm`` then silently fell back to ``auto``,
    which happily picked the user's main provider (e.g. DeepSeek), sending
    image content to a text-only endpoint and producing cryptic
    ``unknown variant 'image_url', expected 'text'`` errors.
  - ``check_vision_requirements`` used the explicit-only path, so
    ``vision_analyze`` disappeared from the tool list while ``browser_vision``
    stayed (its check_fn only validated the browser).

The three fixes covered here:

  1. ``provider: openai`` in auxiliary task config resolves to
     ``custom`` + ``https://api.openai.com/v1``.
  2. The vision auto-detect chain skips the user's main provider when it
     reports ``supports_vision=False`` instead of routing image content to
     a text-only endpoint.
  3. ``check_vision_requirements`` mirrors the runtime fallback chain so
     ``vision_analyze`` shows up whenever the auto chain can serve vision,
     and ``browser_vision`` gates on vision availability as well.
"""
from __future__ import annotations
import os
import shutil
import sys
import tempfile
import pytest
# ---------------------------------------------------------------------------
# Test infrastructure
# ---------------------------------------------------------------------------
@pytest.fixture
def isolated_home(monkeypatch):
    """Yield a throwaway ``HERMES_HOME`` with credential env vars removed.

    Creates a temp directory containing a ``.hermes`` subdirectory, points
    ``HERMES_HOME`` at it, and deletes every environment variable whose name
    ends in ``_API_KEY`` or ``_TOKEN``.  The temp tree is removed afterwards,
    including when setup itself raises before the test body runs.

    Yields:
        The path of the ``.hermes`` directory.
    """
    test_home = tempfile.mkdtemp(prefix="hermes_test_31179_")
    try:
        hermes_home = os.path.join(test_home, ".hermes")
        os.makedirs(hermes_home)
        monkeypatch.setenv("HERMES_HOME", hermes_home)
        # Strip all credential-shaped env vars so each scenario starts hermetic.
        # Snapshot the names first: delenv mutates os.environ while we iterate.
        for name in list(os.environ):
            if name.endswith(("_API_KEY", "_TOKEN")):
                monkeypatch.delenv(name, raising=False)
        yield hermes_home
    finally:
        # Runs on normal teardown and on a setup failure, so nothing leaks.
        shutil.rmtree(test_home, ignore_errors=True)
def _write_config(home: str, text: str) -> None:
    """Write *text* to ``config.yaml`` inside *home*, replacing any existing file.

    The file is written as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding.
    """
    with open(os.path.join(home, "config.yaml"), "w", encoding="utf-8") as fp:
        fp.write(text)
def _fresh_modules():
    """Evict cached hermes modules so the next import reloads against current env."""
    stale_prefixes = (
        "agent.auxiliary_client",
        "agent.image_routing",
        "tools.vision_tools",
        "tools.browser_tool",
        "hermes_cli.config",
    )
    # Snapshot the module names before deleting; sys.modules is mutated below.
    doomed = [name for name in tuple(sys.modules) if name.startswith(stale_prefixes)]
    for name in doomed:
        del sys.modules[name]
# ---------------------------------------------------------------------------
# Fix 1: provider=openai → custom + api.openai.com/v1
# ---------------------------------------------------------------------------
class TestOpenAiAliasForAuxiliary:
    """``auxiliary.<task>.provider: openai`` should produce a working client."""

    # Config selecting the ``openai`` alias for the vision task, no base_url.
    _OPENAI_VISION_YAML = """
auxiliary:
  vision:
    provider: openai
    model: gpt-4o-mini
"""

    def test_provider_openai_routes_to_openai_dot_com(self, isolated_home, monkeypatch):
        """The bare alias resolves to ``custom`` plus the api.openai.com URL."""
        _write_config(isolated_home, self._OPENAI_VISION_YAML)
        monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
        _fresh_modules()

        from agent.auxiliary_client import _resolve_task_provider_model

        resolved = _resolve_task_provider_model("vision")
        provider, model, base_url, _key, _mode = resolved
        assert provider == "custom"
        assert model == "gpt-4o-mini"
        assert base_url == "https://api.openai.com/v1"

    def test_provider_openai_with_explicit_base_url_preserves_user_endpoint(
        self, isolated_home, monkeypatch
    ):
        """User-supplied base_url wins; the alias still normalizes the provider
        name to ``custom`` so resolution doesn't hit the unknown-provider path."""
        _write_config(isolated_home, """
auxiliary:
  vision:
    provider: openai
    model: gpt-4o-mini
    base_url: https://my-proxy.example.com/v1
""")
        monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
        _fresh_modules()

        from agent.auxiliary_client import _resolve_task_provider_model

        resolved = _resolve_task_provider_model("vision")
        provider, _model, base_url, _key, _mode = resolved
        assert provider == "custom"
        assert base_url == "https://my-proxy.example.com/v1"

    def test_provider_openai_resolves_to_working_client(self, isolated_home, monkeypatch):
        """End-to-end: the resolved client points at api.openai.com."""
        _write_config(isolated_home, self._OPENAI_VISION_YAML)
        monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
        _fresh_modules()

        from agent.auxiliary_client import resolve_vision_provider_client
        from urllib.parse import urlparse

        provider, client, model = resolve_vision_provider_client()
        assert client is not None, "openai alias should produce a usable client"
        # Compare the exact hostname rather than a substring: this rejects
        # look-alikes such as ``api.openai.com.evil.example`` and keeps
        # CodeQL happy.
        raw_url = str(getattr(client, "base_url", ""))
        host = urlparse(raw_url).hostname or ""
        assert host == "api.openai.com", f"expected api.openai.com host, got {host!r}"
        assert model == "gpt-4o-mini"
# ---------------------------------------------------------------------------
# Fix 2: auto chain skips text-only main providers
# ---------------------------------------------------------------------------
class TestTextOnlyMainSkippedForVision:
    """Vision auto-detect must not return a text-only main-provider client."""

    def test_text_only_main_skipped_when_no_aggregator(self, isolated_home, monkeypatch):
        """DeepSeek main + no aggregator credentials → no client built.

        Pre-fix this silently returned the deepseek client with model
        substitution, producing ``unknown variant 'image_url'`` at call time.
        """
        _write_config(isolated_home, """
model:
  provider: deepseek
  default: deepseek-v4-pro
""")
        monkeypatch.setenv("DEEPSEEK_API_KEY", "sk-test")
        _fresh_modules()

        from agent.auxiliary_client import resolve_vision_provider_client

        resolved = resolve_vision_provider_client(provider="auto")
        provider, client, _model = resolved
        assert client is None, (
            f"Vision auto-detect must skip text-only main {provider!r} when "
            "no vision-capable aggregator is available, not return a client "
            "that will fail at API time"
        )

    def test_vision_capable_main_used(self, isolated_home, monkeypatch):
        """Vision-capable main provider should be returned by auto chain."""
        _write_config(isolated_home, """
model:
  provider: anthropic
  default: claude-sonnet-4-6
""")
        monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test")
        _fresh_modules()

        from agent.auxiliary_client import resolve_vision_provider_client

        resolved = resolve_vision_provider_client(provider="auto")
        provider, client, _model = resolved
        assert client is not None
        assert provider == "anthropic"

    def test_unknown_capability_does_not_block(self, isolated_home, monkeypatch):
        """When models.dev has no entry, fall back to permissive (attempt the call).

        This keeps new/custom providers working: only providers we have
        cataloged as text-only are skipped.
        """
        _fresh_modules()

        from agent.auxiliary_client import _main_model_supports_vision

        # Bogus provider/model: capability lookup returns None → permissive.
        verdict = _main_model_supports_vision("nonexistent-provider", "nonexistent-model")
        assert verdict is True
# ---------------------------------------------------------------------------
# Fix 3: check_vision_requirements + check_browser_vision_requirements parity
# ---------------------------------------------------------------------------
class TestVisionToolGating:
    """Tool visibility must match runtime capability."""

    # Main-provider configs reused by several scenarios below.
    _DEEPSEEK_MAIN_YAML = """
model:
  provider: deepseek
  default: deepseek-v4-pro
"""
    _OPENROUTER_MAIN_YAML = """
model:
  provider: openrouter
  default: anthropic/claude-sonnet-4
"""

    def test_check_vision_succeeds_for_aliased_openai(self, isolated_home, monkeypatch):
        """The user's exact reported scenario: provider=openai unhides
        vision_analyze instead of silently dropping it."""
        _write_config(isolated_home, """
auxiliary:
  vision:
    provider: openai
    model: gpt-4o-mini
""")
        monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
        _fresh_modules()

        from tools.vision_tools import check_vision_requirements

        assert check_vision_requirements() is True

    def test_check_vision_falls_back_to_auto(self, isolated_home, monkeypatch):
        """Bad explicit provider doesn't hide the tool when auto fallback works.

        Mirrors call_llm's runtime fallback chain.
        """
        _write_config(isolated_home, """
model:
  provider: openrouter
  default: anthropic/claude-sonnet-4
auxiliary:
  vision:
    provider: not-a-real-provider
""")
        monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-test")
        _fresh_modules()

        from tools.vision_tools import check_vision_requirements

        assert check_vision_requirements() is True

    def test_check_vision_false_with_text_only_main_and_no_aggregator(
        self, isolated_home, monkeypatch
    ):
        _write_config(isolated_home, self._DEEPSEEK_MAIN_YAML)
        monkeypatch.setenv("DEEPSEEK_API_KEY", "sk-test")
        _fresh_modules()

        from tools.vision_tools import check_vision_requirements

        assert check_vision_requirements() is False

    def test_browser_vision_requires_both_browser_and_vision(self, isolated_home, monkeypatch):
        """``browser_vision`` must not be advertised when vision is unavailable."""
        from unittest.mock import patch

        _write_config(isolated_home, self._DEEPSEEK_MAIN_YAML)
        monkeypatch.setenv("DEEPSEEK_API_KEY", "sk-test")
        _fresh_modules()

        import tools.browser_tool

        browser_tool = tools.browser_tool
        # Force the browser side to True so only the vision gate is exercised.
        with patch.object(browser_tool, "check_browser_requirements", return_value=True):
            assert browser_tool.check_browser_vision_requirements() is False

    def test_browser_vision_false_when_browser_missing(self, isolated_home, monkeypatch):
        from unittest.mock import patch

        _write_config(isolated_home, self._OPENROUTER_MAIN_YAML)
        monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-test")
        _fresh_modules()

        import tools.browser_tool

        browser_tool = tools.browser_tool
        with patch.object(browser_tool, "check_browser_requirements", return_value=False):
            # Vision available but browser missing → still False.
            assert browser_tool.check_browser_vision_requirements() is False

    def test_browser_vision_true_when_both_available(self, isolated_home, monkeypatch):
        from unittest.mock import patch

        _write_config(isolated_home, self._OPENROUTER_MAIN_YAML)
        monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-test")
        _fresh_modules()

        import tools.browser_tool

        browser_tool = tools.browser_tool
        with patch.object(browser_tool, "check_browser_requirements", return_value=True):
            assert browser_tool.check_browser_vision_requirements() is True

View file

@ -66,6 +66,38 @@ class TestChatCompletionsBasic:
# Original list untouched (deepcopy-on-demand)
assert msgs[2]["tool_name"] == "execute_code"
def test_convert_messages_strips_internal_scaffolding_markers(self, transport):
    """Hermes-internal ``_``-prefixed markers must never reach the wire.

    The empty-response recovery path appends synthetic messages tagged
    with ``_empty_recovery_synthetic``; permissive providers ignore the
    unknown key, but strict gateways (opencode-go, codex.nekos.me)
    reject the request, poisoning every later turn in the session.
    """
    history = [
        {"role": "user", "content": "run the task"},
        {"role": "assistant", "content": "(empty)", "_empty_recovery_synthetic": True},
        {"role": "user", "content": "continue", "_empty_recovery_synthetic": True},
        {
            "role": "assistant",
            "content": "done",
            "_thinking_prefill": True,
            "_empty_terminal_sentinel": True,
        },
    ]
    converted = transport.convert_messages(history)

    for message in converted:
        leaked = [key for key in message if key.startswith("_")]
        assert not leaked, message

    # Visible content preserved
    assert converted[1]["content"] == "(empty)"
    assert converted[2]["content"] == "continue"
    # Original list untouched (deepcopy-on-demand)
    assert history[1]["_empty_recovery_synthetic"] is True
def test_convert_messages_clean_list_is_identity(self, transport):
    """A list with no internal/codex keys is returned as-is (no copy)."""
    clean = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    converted = transport.convert_messages(clean)
    # Identity, not equality: the very same list object must come back.
    assert converted is clean
class TestChatCompletionsBuildKwargs:

View file

@ -20,6 +20,7 @@ from agent.transports.codex_app_server_session import (
TurnResult,
_ServerRequestRouting,
_approval_choice_to_codex_decision,
_coerce_turn_input_text,
)
@ -128,6 +129,15 @@ class TestApprovalChoiceMapping:
assert _approval_choice_to_codex_decision(choice) == expected
class TestTurnInputCoercion:
    """Coercion of list-shaped turn content into a single text string."""

    def test_list_content_keeps_text_and_marks_images(self):
        """Text parts are kept and an image part becomes ``[image attached]``."""
        parts = [
            {"type": "text", "text": "caption"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
        ]
        coerced = _coerce_turn_input_text(parts)
        assert coerced == "caption\n\n[image attached]"
# ---- lifecycle ----
class TestLifecycle:
@ -188,6 +198,35 @@ class TestRunTurn:
# turn_id propagated for downstream session-DB linkage
assert r.turn_id == "turn-fake-001"
def test_rich_content_turn_is_collapsed_to_text_payload(self):
    """A text+image content list is sent to ``turn/start`` as string text.

    The first input item of the ``turn/start`` request must carry a ``str``
    that contains both the caller's own path marker and the generic
    ``[image attached]`` marker.
    """
    client = FakeClient()
    client.queue_notification(
        "turn/completed",
        threadId="t",
        turn={"id": "tu1", "status": "completed", "error": None},
    )
    session = make_session(client)
    rich_content = [
        {
            "type": "text",
            "text": "look at this\n\n[Image attached at: /tmp/a.png]",
        },
        {
            "type": "image_url",
            "image_url": {"url": "data:image/png;base64,abc"},
        },
    ]
    result = session.run_turn(rich_content, turn_timeout=2.0)
    assert result.error is None

    start_requests = (req for req in client.requests if req[0] == "turn/start")
    method, params = next(start_requests)
    assert method == "turn/start"
    payload_text = params["input"][0]["text"]
    assert isinstance(payload_text, str)
    assert "[Image attached at: /tmp/a.png]" in payload_text
    assert "[image attached]" in payload_text
def test_tool_iteration_counter_ticks(self):
client = FakeClient()
# Two completed exec items + one final agent message

View file

@ -168,6 +168,25 @@ class TestBranchCommandCLI:
assert cli_instance._resumed is True
def test_branch_rotates_hermes_session_id_env_and_context(self, cli_instance, session_db):
    """Branching must update process-local session-id readers too."""
    from cli import HermesCLI
    from gateway.session_context import _UNSET, _VAR_MAP, get_session_env

    session_var = _VAR_MAP["HERMES_SESSION_ID"]
    previous_id = cli_instance.session_id
    # Seed both readers with the pre-branch id so a stale value is detectable.
    os.environ["HERMES_SESSION_ID"] = previous_id
    session_var.set(previous_id)
    try:
        HermesCLI._handle_branch_command(cli_instance, "/branch")

        assert cli_instance.session_id != previous_id
        assert os.environ["HERMES_SESSION_ID"] == cli_instance.session_id
        assert get_session_env("HERMES_SESSION_ID") == cli_instance.session_id
    finally:
        # Leave no session id behind for other tests.
        os.environ.pop("HERMES_SESSION_ID", None)
        _VAR_MAP["HERMES_SESSION_ID"].set(_UNSET)
def test_branch_fires_on_session_switch_hook(self, cli_instance, session_db):
"""The /branch command must notify memory providers of the rotation.

View file

@ -102,6 +102,20 @@ class TestVerboseAndToolProgress:
assert cli.tool_progress_mode in {"off", "new", "all", "verbose"}
class TestFallbackChainInit:
    """Construction of the CLI's fallback chain from config."""

    def test_merges_new_and_legacy_fallback_config(self):
        """``fallback_providers`` entries come first, then legacy ``fallback_model``."""
        primary = {"provider": "openrouter", "model": "anthropic/claude-sonnet-4.6"}
        legacy = {"provider": "nous", "model": "Hermes-4"}
        cli = _make_cli(config_overrides={
            "fallback_providers": [dict(primary)],
            "fallback_model": dict(legacy),
        })
        assert cli._fallback_model == [primary, legacy]
class TestBusyInputMode:
def test_default_busy_input_mode_is_interrupt(self):
cli = _make_cli()
@ -317,7 +331,63 @@ class TestHistoryDisplay:
assert "Recent sessions" in output
assert "Checking Running Hermes Agent" in output
assert "Use /resume <session id or title> to continue" in output
assert "Use /resume" in output
assert "session title" in output
def test_resume_updates_hermes_session_id_env_and_context(self, tmp_path):
    """``/resume <id>`` must repoint the env var and the context variable."""
    from gateway.session_context import _UNSET, _VAR_MAP, get_session_env
    from hermes_state import SessionDB

    cli = _make_cli()
    cli.session_id = "current_session"
    cli.conversation_history = []
    cli.agent = None

    database = SessionDB(db_path=tmp_path / "state.db")
    cli._session_db = database
    for session_name in ("current_session", "target_session"):
        database.create_session(session_name, "cli")
    database.append_message("target_session", "user", "hello from resumed session")

    os.environ["HERMES_SESSION_ID"] = "current_session"
    _VAR_MAP["HERMES_SESSION_ID"].set("current_session")
    try:
        cli._handle_resume_command("/resume target_session")

        assert cli.session_id == "target_session"
        assert os.environ["HERMES_SESSION_ID"] == "target_session"
        assert get_session_env("HERMES_SESSION_ID") == "target_session"
    finally:
        cli._session_db.close()
        os.environ.pop("HERMES_SESSION_ID", None)
        _VAR_MAP["HERMES_SESSION_ID"].set(_UNSET)
def test_resume_list_shows_full_long_titles(self, capsys):
    """Long session titles render in full in the /resume table, not
    truncated to 30 chars (fixes #14082)."""
    long_title = "Salvage BytePlus Volcengine PR With Fixes"
    current_row = {
        "id": "current",
        "title": "Current",
        "preview": "Current preview",
        "last_active": 0,
    }
    long_row = {
        "id": "20260401_201329_d85961",
        "title": long_title,
        "preview": "fix byteplus pr and resume",
        "last_active": 0,
    }

    cli = _make_cli()
    cli.session_id = "current"
    cli._session_db = MagicMock()
    cli._session_db.list_sessions_rich.return_value = [current_row, long_row]

    cli._handle_resume_command("/resume")
    rendered = capsys.readouterr().out

    assert long_title in rendered
    assert "20260401_201329_d85961" in rendered
def test_sessions_command_no_args_lists_recent_sessions(self, capsys):
"""/sessions with no args prints the recent-sessions table (TUI parity).
@ -429,8 +499,8 @@ class TestRootLevelProviderOverride:
assert cfg["model"]["provider"] == "openrouter"
def test_root_provider_ignored_when_default_model_provider_exists(self, tmp_path, monkeypatch):
"""Even when model.provider is the default 'auto', root-level provider is ignored."""
def test_root_provider_used_as_fallback_when_model_provider_missing(self, tmp_path, monkeypatch):
"""Legacy root-level provider still populates model.provider in the CLI loader."""
import yaml
hermes_home = tmp_path / ".hermes"
@ -450,8 +520,29 @@ class TestRootLevelProviderOverride:
monkeypatch.setattr(cli, "_hermes_home", hermes_home)
cfg = cli.load_cli_config()
# Root-level "opencode-go" must NOT leak through
assert cfg["model"]["provider"] != "opencode-go"
assert cfg["model"]["provider"] == "opencode-go"
def test_root_base_url_used_as_fallback_when_model_base_url_missing(self, tmp_path, monkeypatch):
    """Legacy root-level base_url still populates model.base_url in the CLI loader."""
    import yaml

    hermes_home = tmp_path / ".hermes"
    hermes_home.mkdir()
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    # base_url lives at the root; the model section deliberately omits it.
    legacy_config = {
        "base_url": "https://example.com/v1",
        "model": {
            "default": "google/gemini-3-flash-preview",
        },
    }
    (hermes_home / "config.yaml").write_text(yaml.safe_dump(legacy_config))

    import cli
    monkeypatch.setattr(cli, "_hermes_home", hermes_home)
    loaded = cli.load_cli_config()

    assert loaded["model"]["base_url"] == "https://example.com/v1"
def test_terminal_vercel_runtime_bridged_to_env(self, tmp_path, monkeypatch):
"""Classic CLI must expose terminal.vercel_runtime to terminal_tool.py."""

View file

@ -8,6 +8,8 @@ import sys
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
from hermes_state import SessionDB
from tools.todo_tool import TodoStore
@ -138,6 +140,15 @@ def _prepare_cli_with_active_session(tmp_path):
return cli
@pytest.fixture(autouse=True)
def _reset_session_id_context():
    """After each test, clear ``HERMES_SESSION_ID`` from the env and context var."""
    from gateway.session_context import _UNSET, _VAR_MAP

    yield

    # Teardown: drop whatever session id the test left behind.
    os.environ.pop("HERMES_SESSION_ID", None)
    session_var = _VAR_MAP["HERMES_SESSION_ID"]
    session_var.set(_UNSET)
def test_new_command_creates_real_fresh_session_and_resets_agent_state(tmp_path):
cli = _prepare_cli_with_active_session(tmp_path)
old_session_id = cli.session_id
@ -164,6 +175,21 @@ def test_new_command_creates_real_fresh_session_and_resets_agent_state(tmp_path)
cli.agent._invalidate_system_prompt.assert_called_once()
def test_new_command_rotates_hermes_session_id_env_and_context(tmp_path):
    """``/new`` must move the env var and context variable to the new session id."""
    from gateway.session_context import _VAR_MAP, get_session_env

    cli = _prepare_cli_with_active_session(tmp_path)
    previous_id = cli.session_id
    # Seed both readers with the old id so a stale value is detectable.
    os.environ["HERMES_SESSION_ID"] = previous_id
    _VAR_MAP["HERMES_SESSION_ID"].set(previous_id)

    cli.process_command("/new")

    rotated_id = cli.session_id
    assert rotated_id != previous_id
    assert os.environ["HERMES_SESSION_ID"] == cli.session_id
    assert get_session_env("HERMES_SESSION_ID") == cli.session_id
def test_reset_command_is_alias_for_new_session(tmp_path):
cli = _prepare_cli_with_active_session(tmp_path)
old_session_id = cli.session_id

View file

@ -0,0 +1,77 @@
from unittest.mock import MagicMock, patch
from cli import HermesCLI
def _make_cli():
    """Build a ``HermesCLI`` without running ``__init__``, with a mocked session DB."""
    stub = HermesCLI.__new__(HermesCLI)
    stub.session_id = "current_session"
    stub._resumed = False
    stub._pending_title = None
    stub.conversation_history = []
    stub.agent = None
    stub._session_db = MagicMock()
    # _handle_resume_command now triggers _display_resumed_history (#31695),
    # which reads self.resume_display.  "minimal" short-circuits the recap so
    # the tests only exercise session-switch behavior.
    stub.resume_display = "minimal"
    return stub
class TestCliResumeCommand:
    """Numbered-list behaviour of the ``/resume`` command."""

    def test_show_recent_sessions_includes_indexes_and_resume_hint(self, capsys):
        """The recent-sessions table shows row numbers, titles and resume hints."""
        recent = [
            {"id": "sess_002", "title": "Coding", "preview": "build feature", "last_active": None},
            {"id": "sess_001", "title": "Research", "preview": "read docs", "last_active": None},
        ]
        cli_obj = _make_cli()
        cli_obj._list_recent_sessions = MagicMock(return_value=recent)

        shown = cli_obj._show_recent_sessions(reason="resume")
        output = capsys.readouterr().out

        assert shown is True
        for expected in ("1", "2", "Coding", "Research", "/resume 2", "/resume <session title>"):
            assert expected in output

    def test_handle_resume_by_index_switches_to_numbered_session(self):
        """``/resume 2`` switches to the second entry of the recent list."""
        cli_obj = _make_cli()
        cli_obj._list_recent_sessions = MagicMock(return_value=[
            {"id": "sess_002", "title": "Coding"},
            {"id": "sess_001", "title": "Research"},
        ])
        session_db = cli_obj._session_db
        session_db.get_session.return_value = {"id": "sess_001", "title": "Research"}
        session_db.get_messages_as_conversation.return_value = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
        ]
        # resolve_resume_session_id passes the id through when no compression chain.
        session_db.resolve_resume_session_id.return_value = "sess_001"

        with (
            patch("hermes_cli.main._resolve_session_by_name_or_id", return_value=None),
            patch("cli._cprint") as mock_cprint,
        ):
            cli_obj._handle_resume_command("/resume 2")

        printed = " ".join(str(call) for call in mock_cprint.call_args_list)
        assert cli_obj.session_id == "sess_001"
        assert "Resumed session sess_001" in printed
        assert "Research" in printed

    def test_handle_resume_by_index_out_of_range(self):
        """An index past the end of the list reports an error and keeps the session."""
        cli_obj = _make_cli()
        cli_obj._list_recent_sessions = MagicMock(return_value=[
            {"id": "sess_002", "title": "Coding"},
        ])

        with patch("cli._cprint") as mock_cprint:
            cli_obj._handle_resume_command("/resume 9")

        printed = " ".join(str(call) for call in mock_cprint.call_args_list)
        assert "out of range" in printed.lower()
        assert "/resume" in printed
        assert cli_obj.session_id == "current_session"

View file

@ -209,3 +209,123 @@ def test_slash_confirm_display_fragments_include_choice_mapping():
assert "[2] Always Approve" in rendered
assert "[3] Cancel" in rendered
assert "Type 1/2/3" in rendered
# ---------------------------------------------------------------------------
# Inline-skip escape hatch (issue #30768)
#
# Users on platforms where the prompt_toolkit modal doesn't dispatch keys
# (currently native Windows PowerShell) need a way to bypass the confirmation
# without flipping the config gate. ``/reset now``, ``/new --yes``, ``/clear
# -y`` all skip the modal and return "once" immediately.
# ---------------------------------------------------------------------------
def test_split_destructive_skip_recognized_tokens():
    """``now``, ``--yes``, and ``-y`` are recognized as skip tokens."""
    from cli import HermesCLI

    split = HermesCLI._split_destructive_skip
    for command in ("/reset now", "/clear --yes", "/undo -y"):
        assert split(command) == ("", True)
def test_split_destructive_skip_strips_command_word():
    """Leading ``/cmd`` token is stripped; remaining args survive."""
    from cli import HermesCLI

    split = HermesCLI._split_destructive_skip
    without_flag = split("/new My title")
    with_flag = split("/new --yes My title")
    assert without_flag == ("My title", False)
    assert with_flag == ("My title", True)
def test_split_destructive_skip_case_insensitive():
    """Token matching is case-insensitive but not a substring match."""
    from cli import HermesCLI

    split = HermesCLI._split_destructive_skip
    assert split("/new NOW") == ("", True)
    # Substring match must NOT trigger: "Now-Title" is a literal title token.
    assert split("/new Now-Title") == ("Now-Title", False)
def test_split_destructive_skip_handles_empty_and_none():
    """Defensive against missing/empty input."""
    from cli import HermesCLI

    split = HermesCLI._split_destructive_skip
    for blank in (None, "", " "):
        assert split(blank) == ("", False)
def test_confirm_destructive_slash_now_skips_modal():
    """``/reset now`` skips the modal even when the gate is on."""
    from cli import HermesCLI

    def _modal_must_not_run(**_kw):
        # Reaching the modal at all means the inline-skip path was missed.
        raise AssertionError("modal must not be invoked when inline-skip present")

    stub = SimpleNamespace(_app=None, _prompt_text_input_modal=_modal_must_not_run)
    stub._normalize_slash_confirm_choice = _bound(
        HermesCLI._normalize_slash_confirm_choice, stub,
    )
    # Used as-is: the tests in this file call it directly on the class.
    stub._split_destructive_skip = HermesCLI._split_destructive_skip

    gate_on = {"approvals": {"destructive_slash_confirm": True}}
    with patch("cli.load_cli_config", return_value=gate_on):
        confirm = _bound(HermesCLI._confirm_destructive_slash, stub)
        outcome = confirm("new", "detail", cmd_original="/reset now")

    assert outcome == "once"
def test_confirm_destructive_slash_yes_flag_skips_modal():
    """``--yes`` flag is equivalent to ``now``."""
    from cli import HermesCLI

    def _modal_must_not_run(**_kw):
        # Reaching the modal at all means the --yes flag was not honoured.
        raise AssertionError("modal must not be invoked when --yes present")

    stub = SimpleNamespace(_app=None, _prompt_text_input_modal=_modal_must_not_run)
    stub._normalize_slash_confirm_choice = _bound(
        HermesCLI._normalize_slash_confirm_choice, stub,
    )
    stub._split_destructive_skip = HermesCLI._split_destructive_skip

    gate_on = {"approvals": {"destructive_slash_confirm": True}}
    with patch("cli.load_cli_config", return_value=gate_on):
        confirm = _bound(HermesCLI._confirm_destructive_slash, stub)
        outcome = confirm("new", "detail", cmd_original="/new --yes My Session")

    assert outcome == "once"
def test_confirm_destructive_slash_no_skip_token_still_prompts():
    """Without a skip token the gate-on path still consults the modal."""
    from cli import HermesCLI

    stub = _make_self(prompt_response="3")  # cancel
    stub._split_destructive_skip = HermesCLI._split_destructive_skip

    gate_on = {"approvals": {"destructive_slash_confirm": True}}
    with patch("cli.load_cli_config", return_value=gate_on):
        confirm = _bound(HermesCLI._confirm_destructive_slash, stub)
        outcome = confirm("new", "detail", cmd_original="/new My Session")

    # Prompt was reached and returned cancel → None.
    assert outcome is None

View file

@ -0,0 +1,129 @@
"""End-to-end integration test for the destructive-slash inline-skip path.
Drives ``HermesCLI.process_command("/reset now")`` against a minimal stand-in
and verifies:
1. ``new_session`` was invoked (the command actually ran)
2. ``_prompt_text_input_modal`` was NOT invoked (modal bypassed)
3. The skip token did not leak into the session title
This is the regression test for issue #30768 — the inline-skip escape hatch
must work without ever touching the modal, on every platform.
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import patch
def _make_cli_stub():
    """Build a minimal HermesCLI-shaped object that can run ``process_command``
    for the destructive-slash branches without spinning up a real TUI.

    Returns:
        ``(stub, new_session_calls)``: the stub object and the list that
        records the keyword arguments of every ``new_session`` call.
    """
    from cli import HermesCLI

    new_session_calls = []

    def _record_new_session(*, title=None, silent=False):
        new_session_calls.append({"title": title, "silent": silent})

    def _forbidden_modal(**_kw):
        # Any call here means the inline-skip path failed to bypass the modal.
        raise AssertionError("modal must not be invoked when inline-skip token present")

    stub = SimpleNamespace(
        _app=None,
        _prompt_text_input_modal=_forbidden_modal,
        new_session=_record_new_session,
        # Stub out side-effects the destructive-slash branches reach for.
        console=SimpleNamespace(clear=lambda: None),
        compact=False,
        model="stub-model",
        session_id="stub-session",
        enabled_toolsets=[],
        _pending_title=None,
        _session_db=None,
    )

    # Bind the real methods under test onto the stub.
    stub_type = type(stub)
    stub._split_destructive_skip = HermesCLI._split_destructive_skip
    stub._confirm_destructive_slash = HermesCLI._confirm_destructive_slash.__get__(
        stub, stub_type
    )
    stub.process_command = HermesCLI.process_command.__get__(stub, stub_type)
    return stub, new_session_calls
def test_reset_now_invokes_new_session_without_modal():
    """``/reset now`` runs ``new_session`` and never touches the modal."""
    stub, recorded = _make_cli_stub()
    gate_on = {"approvals": {"destructive_slash_confirm": True}}

    with patch("cli.load_cli_config", return_value=gate_on):
        stub.process_command("/reset now")

    assert recorded, "new_session was never invoked"
    # The /new branch passes title=None when there's no non-skip remainder.
    first_call = recorded[0]
    assert first_call["title"] is None
def test_new_yes_with_title_preserves_title():
    """``/new --yes My Session`` runs ``new_session(title='My Session')``."""
    stub, recorded = _make_cli_stub()
    gate_on = {"approvals": {"destructive_slash_confirm": True}}

    with patch("cli.load_cli_config", return_value=gate_on):
        stub.process_command("/new --yes My Session")

    assert recorded, "new_session was never invoked"
    first_call = recorded[0]
    assert first_call["title"] == "My Session"
def test_new_without_skip_token_still_consults_modal():
    """``/new My Session`` (no skip token) must reach the modal.

    Sanity check that we haven't accidentally short-circuited the normal path.
    """
    from cli import HermesCLI

    new_session_calls = []
    modal_calls = []

    def _record_new_session(*, title=None, silent=False):
        new_session_calls.append({"title": title, "silent": silent})

    def _record_modal(**kw):
        modal_calls.append(kw)
        # Simulate user cancelling so new_session is not called.
        return "3"

    stub = SimpleNamespace(
        _app=None,
        _prompt_text_input_modal=_record_modal,
        new_session=_record_new_session,
        console=SimpleNamespace(clear=lambda: None),
        compact=False,
        model="stub-model",
        session_id="stub-session",
        enabled_toolsets=[],
        _pending_title=None,
        _session_db=None,
    )

    # Bind the real methods under test onto the stub.
    stub_type = type(stub)
    stub._split_destructive_skip = HermesCLI._split_destructive_skip
    stub._normalize_slash_confirm_choice = HermesCLI._normalize_slash_confirm_choice.__get__(
        stub, stub_type
    )
    stub._confirm_destructive_slash = HermesCLI._confirm_destructive_slash.__get__(
        stub, stub_type
    )
    stub.process_command = HermesCLI.process_command.__get__(stub, stub_type)

    gate_on = {"approvals": {"destructive_slash_confirm": True}}
    with patch("cli.load_cli_config", return_value=gate_on):
        stub.process_command("/new My Session")

    assert modal_calls, "modal must be reached when no skip token is present"
    assert not new_session_calls, "user cancelled — new_session must not run"

View file

@ -155,14 +155,34 @@ class TestDisplayResumedHistory:
assert "Page content" not in output
def test_tool_calls_shown_as_summary(self):
    """The tool-call summary line is rendered when tool-only skipping is off.

    Only one CLI instance is built and only one capture is taken; the
    capture runs under the patched ``CLI_CONFIG`` so the override is in
    effect when the history is rendered.
    """
    import cli as _cli_mod

    # Disable tool-only skip so the summary line is rendered for this fixture.
    cli = _make_cli(config_overrides={"display": {"resume_skip_tool_only": False}})
    cli.conversation_history = _tool_call_history()

    # CLI_CONFIG is read at call-time inside _display_resumed_history, so
    # apply the override for the duration of the capture, not just at init.
    display_override = {
        "display": {"resume_skip_tool_only": False, "resume_display": "full"}
    }
    with patch.dict(_cli_mod.__dict__, {"CLI_CONFIG": display_override}):
        output = self._capture_display(cli)

    assert "2 tool calls" in output
    assert "web_search" in output
    assert "web_extract" in output
def test_tool_only_message_skipped_by_default(self):
    """Tool-only assistant messages are hidden under the default config.

    With ``resume_skip_tool_only=True`` (the default) an assistant entry
    that carries only tool_calls and no text produces no summary line,
    while the final text reply is still shown.
    """
    resumed_cli = _make_cli()
    resumed_cli.conversation_history = _tool_call_history()
    rendered = self._capture_display(resumed_cli)

    # The tool-only assistant entry should be skipped
    assert "2 tool calls" not in rendered
    # The final text reply should still appear
    assert "Here are some great Python tutorials" in rendered
def test_long_user_message_truncated(self):
cli = _make_cli()
long_text = "A" * 500
@ -611,6 +631,55 @@ class TestPreloadResumedSession:
assert "1 user messages" not in output
# ── Tests for _handle_resume_command recap display ───────────────────
class TestHandleResumeCommandRecap:
    """In-session /resume should show the same recap panel as startup resume."""

    def test_resume_command_displays_recap_when_messages_restored(self):
        """A resumed session with messages swaps state and renders the recap once."""
        cli = _make_cli()
        cli.session_id = "current_session"
        restored = _simple_history()

        session_db = MagicMock()
        session_db.get_session.return_value = {
            "id": "target_session",
            "title": "Test Session",
        }
        session_db.get_messages_as_conversation.return_value = restored
        # resolve_resume_session_id passes the id through when no compression chain.
        session_db.resolve_resume_session_id.return_value = "target_session"
        cli._session_db = session_db

        with (
            patch(
                "hermes_cli.main._resolve_session_by_name_or_id",
                return_value="target_session",
            ),
            patch.object(cli, "_display_resumed_history") as recap_mock,
        ):
            cli._handle_resume_command("/resume test session")

        # The CLI now points at the target session and carries its history.
        assert cli.session_id == "target_session"
        assert cli.conversation_history == restored
        # The previous session is closed out and the target re-opened.
        session_db.end_session.assert_called_once_with("current_session", "resumed_other")
        session_db.reopen_session.assert_called_once_with("target_session")
        recap_mock.assert_called_once_with()

    def test_resume_command_skips_recap_when_session_has_no_messages(self):
        """An empty restored conversation must not trigger the recap panel."""
        cli = _make_cli()
        cli.session_id = "current_session"

        session_db = MagicMock()
        session_db.get_session.return_value = {"id": "target_session", "title": None}
        session_db.get_messages_as_conversation.return_value = []
        session_db.resolve_resume_session_id.return_value = "target_session"
        cli._session_db = session_db

        with (
            patch(
                "hermes_cli.main._resolve_session_by_name_or_id",
                return_value="target_session",
            ),
            patch.object(cli, "_display_resumed_history") as recap_mock,
        ):
            cli._handle_resume_command("/resume target_session")

        recap_mock.assert_not_called()
# ── Integration: _init_agent skips when preloaded ────────────────────

View file

@ -14,9 +14,10 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
# Module-level reference to the cli module (set by _make_cli on first call)
_cli_mod = None
_UNSET = object()
def _make_cli(tool_progress="all"):
def _make_cli(tool_progress="all", verbose=_UNSET):
"""Create a HermesCLI instance with minimal mocking."""
global _cli_mod
_clean_config = {
@ -54,7 +55,9 @@ def _make_cli(tool_progress="all"):
_cli_mod = mod
with patch.object(mod, "get_tool_definitions", return_value=[]), \
patch.dict(mod.__dict__, {"CLI_CONFIG": _clean_config}):
return mod.HermesCLI()
if verbose is _UNSET:
return mod.HermesCLI()
return mod.HermesCLI(verbose=verbose)
class TestToolProgressScrollback:
@ -122,14 +125,21 @@ class TestToolProgressScrollback:
mock_print.assert_not_called()
def test_error_suffix_on_failed_tool(self):
    """When a failed tool's result is forwarded, the stacked line surfaces
    the specific error (e.g. ``[exit 1]`` or ``[File not found: x]``)
    instead of the legacy generic ``[error]`` suffix.

    The tool is started and completed exactly once, and only the specific
    ``[exit 1]`` suffix is asserted.
    """
    import json

    cli = _make_cli(tool_progress="all")
    cli._on_tool_progress("tool.started", "terminal", "false", {"command": "false"})
    with patch.object(_cli_mod, "_cprint") as mock_print:
        # Forward a JSON result carrying the non-zero exit code.
        cli._on_tool_progress(
            "tool.completed", "terminal", None, None,
            duration=0.5, is_error=True,
            result=json.dumps({"output": "", "exit_code": 1}),
        )
    # First positional argument of the last _cprint call is the stacked line.
    line = mock_print.call_args[0][0]
    assert "[exit 1]" in line
def test_spinner_still_updates_on_started(self):
"""tool.started still updates the spinner text for live display."""
@ -168,6 +178,35 @@ class TestToolProgressScrollback:
mock_print.assert_not_called()
def test_verbose_mode_config_does_not_enable_global_debug_logging(self):
    """display.tool_progress=verbose controls TOOL-CALL DISPLAY ONLY.

    It must NOT auto-flip self.verbose, which controls root-logger DEBUG
    level for the entire process (every module spews to console). PR
    #6a1aa420e had coupled them, causing all debug logs to flood the
    terminal whenever a user picked tool_progress: verbose for richer
    per-tool rendering.
    """
    verbose_display_cli = _make_cli(tool_progress="verbose")

    # The display mode follows the config; the logging flag stays off.
    assert verbose_display_cli.tool_progress_mode == "verbose"
    assert verbose_display_cli.verbose is False
def test_explicit_verbose_argument_wins_over_config(self):
    """Explicit verbose=True from the CLI flag still enables DEBUG logging
    regardless of tool_progress_mode."""
    flagged_cli = _make_cli(tool_progress="off", verbose=True)

    # Display mode and logging verbosity are set independently.
    assert flagged_cli.tool_progress_mode == "off"
    assert flagged_cli.verbose is True
def test_explicit_non_verbose_argument_keeps_debug_logging_off(self):
    """Explicit verbose=False overrides any default to enable DEBUG."""
    quiet_cli = _make_cli(tool_progress="verbose", verbose=False)

    # Verbose tool rendering is kept while DEBUG logging remains disabled.
    assert quiet_cli.tool_progress_mode == "verbose"
    assert quiet_cli.verbose is False
def test_pending_info_stores_on_started(self):
"""tool.started stores args for later use by tool.completed."""
cli = _make_cli(tool_progress="all")

View file

@ -358,6 +358,10 @@ def _hermetic_environment(tmp_path, monkeypatch):
monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true")
monkeypatch.setenv("AWS_METADATA_SERVICE_TIMEOUT", "1")
monkeypatch.setenv("AWS_METADATA_SERVICE_NUM_ATTEMPTS", "1")
# Tirith auto-installs from GitHub when enabled and missing. Unit tests
# should never perform that implicit network/bootstrap path; Tirith-specific
# tests opt back in by patching the security config directly.
monkeypatch.setenv("TIRITH_ENABLED", "false")
# 5. Reset plugin singleton so tests don't leak plugins from
# ~/.hermes/plugins/ (which, per step 3, is now empty — but the

View file

@ -490,6 +490,17 @@ class TestRoutingIntents:
class TestDeliverResultWrapping:
"""Verify that cron deliveries are wrapped with header/footer and no longer mirrored."""
def _safe_media_path(self, tmp_path, monkeypatch, name, data=b"media"):
    """Create a media file under a temporary root and allow-list that root.

    Writes ``data`` to ``tmp_path / "media-cache" / name`` (creating parent
    directories as needed), replaces
    ``gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS`` with a one-element
    tuple holding that root via ``monkeypatch``, and returns the resolved
    path of the file.
    """
    safe_root = tmp_path / "media-cache"
    target = safe_root / name
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(data)
    monkeypatch.setattr(
        "gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS",
        (safe_root,),
    )
    return target.resolve()
def test_delivery_wraps_content_with_header_and_footer(self):
"""Delivered content should include task name header and agent-invisible note."""
from gateway.config import Platform
@ -564,9 +575,10 @@ class TestDeliverResultWrapping:
assert "Cronjob Response" not in sent_content
assert "The agent cannot see" not in sent_content
def test_delivery_extracts_media_tags_before_send(self):
def test_delivery_extracts_media_tags_before_send(self, tmp_path, monkeypatch):
"""Cron delivery should pass MEDIA attachments separately to the send helper."""
from gateway.config import Platform
media_path = self._safe_media_path(tmp_path, monkeypatch, "test-voice.ogg")
pconfig = MagicMock()
pconfig.enabled = True
@ -581,7 +593,7 @@ class TestDeliverResultWrapping:
"deliver": "origin",
"origin": {"platform": "telegram", "chat_id": "123"},
}
_deliver_result(job, "Title\nMEDIA:/tmp/test-voice.ogg")
_deliver_result(job, f"Title\nMEDIA:{media_path}")
send_mock.assert_called_once()
args, kwargs = send_mock.call_args
@ -589,14 +601,15 @@ class TestDeliverResultWrapping:
assert "MEDIA:" not in args[3]
assert "Title" in args[3]
# Media files should be forwarded separately
assert kwargs["media_files"] == [("/tmp/test-voice.ogg", False)]
assert kwargs["media_files"] == [(str(media_path), False)]
def test_live_adapter_sends_media_as_attachments(self):
def test_live_adapter_sends_media_as_attachments(self, tmp_path, monkeypatch):
"""When a live adapter is available, MEDIA files should be sent as native
platform attachments (e.g., Discord voice, Telegram audio) rather than
as literal 'MEDIA:/path' text."""
from gateway.config import Platform
from concurrent.futures import Future
media_path = self._safe_media_path(tmp_path, monkeypatch, "cron-voice.mp3")
adapter = AsyncMock()
adapter.send.return_value = MagicMock(success=True)
@ -628,7 +641,7 @@ class TestDeliverResultWrapping:
patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
_deliver_result(
job,
"Here is TTS\nMEDIA:/tmp/cron-voice.mp3",
f"Here is TTS\nMEDIA:{media_path}",
adapters={Platform.DISCORD: adapter},
loop=loop,
)
@ -642,12 +655,13 @@ class TestDeliverResultWrapping:
# Audio file should be sent as a voice attachment
adapter.send_voice.assert_called_once()
voice_call = adapter.send_voice.call_args
assert voice_call[1]["audio_path"] == "/tmp/cron-voice.mp3"
assert voice_call[1]["audio_path"] == str(media_path)
def test_live_adapter_routes_image_to_send_image_file(self):
def test_live_adapter_routes_image_to_send_image_file(self, tmp_path, monkeypatch):
"""Image MEDIA files should be routed to send_image_file, not send_voice."""
from gateway.config import Platform
from concurrent.futures import Future
media_path = self._safe_media_path(tmp_path, monkeypatch, "chart.png")
adapter = AsyncMock()
adapter.send.return_value = MagicMock(success=True)
@ -678,19 +692,20 @@ class TestDeliverResultWrapping:
patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
_deliver_result(
job,
"Chart attached\nMEDIA:/tmp/chart.png",
f"Chart attached\nMEDIA:{media_path}",
adapters={Platform.DISCORD: adapter},
loop=loop,
)
adapter.send_image_file.assert_called_once()
assert adapter.send_image_file.call_args[1]["image_path"] == "/tmp/chart.png"
assert adapter.send_image_file.call_args[1]["image_path"] == str(media_path)
adapter.send_voice.assert_not_called()
def test_live_adapter_media_only_no_text(self):
def test_live_adapter_media_only_no_text(self, tmp_path, monkeypatch):
"""When content is ONLY a MEDIA tag with no text, media should still be sent."""
from gateway.config import Platform
from concurrent.futures import Future
media_path = self._safe_media_path(tmp_path, monkeypatch, "voice.ogg")
adapter = AsyncMock()
adapter.send_voice.return_value = MagicMock(success=True)
@ -720,7 +735,7 @@ class TestDeliverResultWrapping:
patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
_deliver_result(
job,
"[[audio_as_voice]]\nMEDIA:/tmp/voice.ogg",
f"[[audio_as_voice]]\nMEDIA:{media_path}",
adapters={Platform.TELEGRAM: adapter},
loop=loop,
)
@ -2164,43 +2179,56 @@ class TestBuildJobPromptBumpUse:
class TestSendMediaViaAdapter:
"""Unit tests for _send_media_via_adapter — routes files to typed adapter methods."""
def _safe_media_path(self, tmp_path, monkeypatch, name, data=b"media"):
    """Create a media file under a temporary root and allow-list that root.

    Writes ``data`` to ``tmp_path / "media-cache" / name`` (creating parent
    directories as needed), replaces
    ``gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS`` with a one-element
    tuple holding that root via ``monkeypatch``, and returns the resolved
    path of the file.
    """
    safe_root = tmp_path / "media-cache"
    target = safe_root / name
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(data)
    monkeypatch.setattr(
        "gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS",
        (safe_root,),
    )
    return target.resolve()
@staticmethod
def _run_with_loop(adapter, chat_id, media_files, metadata, job):
    """Helper: run _send_media_via_adapter with immediate scheduling.

    ``asyncio.run_coroutine_threadsafe`` is patched so every scheduled
    coroutine is closed without running and an already-completed
    ``Future`` (result: ``MagicMock(success=True)``) is returned. No real
    event loop or background thread is started; a ``MagicMock`` stands in
    for the loop argument.
    """
    from concurrent.futures import Future

    def fake_run_coro(coro, _loop):
        # Close the coroutine so it is never awaited and raises no
        # "coroutine was never awaited" warning.
        coro.close()
        completed = Future()
        completed.set_result(MagicMock(success=True))
        return completed

    with patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
        _send_media_via_adapter(adapter, chat_id, media_files, metadata, MagicMock(), job)
def test_video_dispatched_to_send_video(self, tmp_path, monkeypatch):
    """A ``.mp4`` media file is routed to ``adapter.send_video``.

    The file is created under an allow-listed temp root; the path handed to
    the adapter must equal that resolved path.
    """
    adapter = MagicMock()
    adapter.send_video = AsyncMock()
    media_path = self._safe_media_path(tmp_path, monkeypatch, "clip.mp4")
    media_files = [(str(media_path), False)]

    self._run_with_loop(adapter, "123", media_files, None, {"id": "j1"})

    adapter.send_video.assert_called_once()
    assert adapter.send_video.call_args[1]["video_path"] == str(media_path)
def test_unknown_ext_dispatched_to_send_document(self, tmp_path, monkeypatch):
    """A ``.pdf`` media file is routed to ``adapter.send_document``.

    The file is created under an allow-listed temp root; the path handed to
    the adapter must equal that resolved path.
    """
    adapter = MagicMock()
    adapter.send_document = AsyncMock()
    media_path = self._safe_media_path(tmp_path, monkeypatch, "report.pdf")
    media_files = [(str(media_path), False)]

    self._run_with_loop(adapter, "123", media_files, None, {"id": "j2"})

    adapter.send_document.assert_called_once()
    assert adapter.send_document.call_args[1]["file_path"] == str(media_path)
def test_multiple_media_files_all_delivered(self, tmp_path, monkeypatch):
    """Each file in a mixed batch reaches its own typed adapter method.

    An ``.mp3`` and a ``.jpg`` are created under the same allow-listed temp
    root; the voice sender and the image sender must each be called once.
    """
    adapter = MagicMock()
    adapter.send_voice = AsyncMock()
    adapter.send_image_file = AsyncMock()
    voice_path = self._safe_media_path(tmp_path, monkeypatch, "voice.mp3")
    photo_path = self._safe_media_path(tmp_path, monkeypatch, "photo.jpg")
    media_files = [(str(voice_path), False), (str(photo_path), False)]

    self._run_with_loop(adapter, "123", media_files, None, {"id": "j3"})

    adapter.send_voice.assert_called_once()
    adapter.send_image_file.assert_called_once()
@ -2462,7 +2490,7 @@ class TestSendMediaTimeoutCancelsFuture:
in-flight coroutine must be cancelled before the next file is tried.
"""
def test_media_send_timeout_cancels_future_and_continues(self):
def test_media_send_timeout_cancels_future_and_continues(self, tmp_path, monkeypatch):
"""End-to-end: _send_media_via_adapter with a future whose .result()
raises TimeoutError. Assert cancel() fires and the loop proceeds
to the next file rather than hanging or crashing."""
@ -2493,9 +2521,19 @@ class TestSendMediaTimeoutCancelsFuture:
coro.close()
return next(futures_iter)
root = tmp_path / "media-cache"
slow = root / "slow.png"
fast = root / "fast.mp4"
slow.parent.mkdir(parents=True)
slow.write_bytes(b"slow")
fast.write_bytes(b"fast")
monkeypatch.setattr(
"gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS",
(root,),
)
media_files = [
("/tmp/slow.png", False), # times out
("/tmp/fast.mp4", False), # succeeds
(str(slow), False), # times out
(str(fast), False), # succeeds
]
loop = MagicMock()
@ -2509,4 +2547,4 @@ class TestSendMediaTimeoutCancelsFuture:
assert timeout_cancel_calls == [True], "future.cancel() must fire on TimeoutError"
# 2. Second file still got dispatched — one timeout doesn't abort the batch
adapter.send_video.assert_called_once()
assert adapter.send_video.call_args[1]["video_path"] == "/tmp/fast.mp4"
assert adapter.send_video.call_args[1]["video_path"] == str(fast.resolve())

View file

@ -119,7 +119,7 @@ _ensure_slack_mock()
import discord # noqa: E402 — mocked above
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
from gateway.platforms.discord import DiscordAdapter # noqa: E402
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
import gateway.platforms.slack as _slack_mod # noqa: E402
_slack_mod.SLACK_AVAILABLE = True

View file

@ -1,20 +1,10 @@
"""Regression test for #4469.
"""Regression tests for active-session TEXT follow-up queueing.
When the agent is actively running (session present in
``adapter._active_sessions``) and the user fires off multiple TEXT
follow-ups in rapid succession, the previous behaviour was a single-slot
replacement at ``gateway/platforms/base.py``:
self._pending_messages[session_key] = event
So three rapid messages ``A``, ``B``, ``C`` arriving while the agent was
still working on the initial turn produced a pending slot containing only
``C``; ``A`` and ``B`` were silently dropped.
The fix routes the follow-up through ``merge_pending_message_event(...,
merge_text=True)`` so TEXT events accumulate into the existing pending
event's text instead of clobbering it. Photo / media bursts continue to
merge through the same helper (they always did).
When the agent is actively running, rapid text follow-ups should survive as
one next-turn pending message instead of clobbering each other. In
``busy_text_mode=queue`` those active follow-ups first pass through a short
debounce so bursty multi-message thoughts are merged before the active drain
hands off the next turn.
"""
from __future__ import annotations
@ -22,7 +12,7 @@ from __future__ import annotations
import asyncio
import sys
import types
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -44,16 +34,27 @@ from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
MessageType,
SendResult,
)
from gateway.session import SessionSource, build_session_key
def _make_event(text: str, chat_id: str = "12345") -> MessageEvent:
def _make_event(
text: str,
chat_id: str = "12345",
*,
chat_type: str = "dm",
user_id: str = "u1",
user_name: str | None = None,
thread_id: str | None = None,
) -> MessageEvent:
source = SessionSource(
platform=Platform.TELEGRAM,
chat_id=chat_id,
chat_type="dm",
user_id="u1",
chat_type=chat_type,
user_id=user_id,
user_name=user_name,
thread_id=thread_id,
)
return MessageEvent(
text=text,
@ -63,27 +64,26 @@ def _make_event(text: str, chat_id: str = "12345") -> MessageEvent:
)
class _DummyAdapter(BasePlatformAdapter):  # type: ignore[misc]
    """Minimal concrete adapter used by these tests; performs no platform I/O."""

    async def connect(self):
        """No-op connect."""
        return None

    async def disconnect(self):
        """No-op disconnect."""
        return None

    async def get_chat_info(self, chat_id):
        """Report no chat info for any ``chat_id``."""
        return None

    async def send(self, *args, **kwargs):
        """Pretend every send succeeded with a fixed message id."""
        return SendResult(success=True, message_id="x")
def _make_initialized_adapter() -> BasePlatformAdapter:
    """Build a ``_DummyAdapter`` through its real constructor.

    Unlike a bare ``object.__new__`` instance, this runs the adapter's
    ``__init__`` with an enabled Telegram config.
    """
    enabled_config = PlatformConfig(enabled=True, token="***")
    return _DummyAdapter(enabled_config, Platform.TELEGRAM)
def _make_adapter() -> BasePlatformAdapter:
"""Build a BasePlatformAdapter without running its heavy __init__.
We only need the bits ``handle_message`` touches on the active-session
path: ``_active_sessions``, ``_pending_messages``,
``_message_handler``, ``_busy_session_handler``, ``config``, ``platform``.
"""
class _DummyAdapter(BasePlatformAdapter): # type: ignore[misc]
async def connect(self):
pass
async def disconnect(self):
pass
async def get_chat_info(self, chat_id):
return None
async def send(self, *args, **kwargs):
return MagicMock(success=True, message_id="x", retryable=False)
"""Build a BasePlatformAdapter without running its heavy __init__."""
adapter = object.__new__(_DummyAdapter)
adapter.config = PlatformConfig(enabled=True, token="***")
adapter.platform = Platform.TELEGRAM
@ -100,6 +100,10 @@ def _make_adapter() -> BasePlatformAdapter:
adapter._fatal_error_retryable = True
adapter._fatal_error_handler = None
adapter._running = True
adapter._busy_text_mode = "queue"
adapter._busy_text_debounce_seconds = 0.1
adapter._busy_text_hard_cap_seconds = 1.0
adapter._text_debounce = {}
adapter._auto_tts_default = False
adapter._auto_tts_enabled_chats = set()
adapter._auto_tts_disabled_chats = set()
@ -107,39 +111,235 @@ def _make_adapter() -> BasePlatformAdapter:
return adapter
def _debounced_event(adapter: BasePlatformAdapter, session_key: str) -> MessageEvent:
    """Return the event currently held in the text-debounce slot for ``session_key``."""
    debounce_entry = adapter._text_debounce[session_key]
    return debounce_entry.event
@pytest.mark.asyncio
async def test_rapid_text_followups_accumulate_instead_of_replacing():
    """Rapid TEXT follow-ups must all survive in the pending event.

    Debouncing is turned off so follow-ups merge straight into
    ``_pending_messages``. Each follow-up is sent exactly once; the merged
    text is newline-joined and the session's event stays unset.
    """
    adapter = _make_adapter()
    adapter._busy_text_mode = ""  # direct-merge behavior, no debounce

    first = _make_event("part one")
    session_key = build_session_key(first.source)
    # Mark the session as active so subsequent messages take the
    # "already running" branch in handle_message.
    adapter._active_sessions[session_key] = asyncio.Event()

    await adapter.handle_message(_make_event("part two"))
    await adapter.handle_message(_make_event("part three"))

    # Both rapid follow-ups must be preserved, not just the last one.
    pending = adapter._pending_messages[session_key]
    assert pending.text == "part two\npart three"
    assert not adapter._active_sessions[session_key].is_set()
@pytest.mark.asyncio
async def test_debounce_buffers_rapid_text_then_flushes_to_pending():
    """Follow-ups are merged in the debounce slot, then moved to pending.

    While the debounce window is open the text lives only in
    ``_text_debounce``; once the window has elapsed the merged text appears
    in ``_pending_messages`` and the debounce slot is gone.
    """
    adapter = _make_adapter()
    adapter._busy_text_debounce_seconds = 0.05
    session_key = build_session_key(_make_event("part one").source)
    adapter._active_sessions[session_key] = asyncio.Event()

    # First follow-up: buffered, not yet pending.
    await adapter.handle_message(_make_event("part two"))
    assert session_key in adapter._text_debounce
    assert _debounced_event(adapter, session_key).text == "part two"
    assert session_key not in adapter._pending_messages

    # Second follow-up: merged into the same buffered event.
    await adapter.handle_message(_make_event("part three"))
    assert _debounced_event(adapter, session_key).text == "part two\npart three"

    # Sleep past the debounce window so the timer can flush.
    await asyncio.sleep(0.15)
    assert session_key not in adapter._text_debounce
    assert adapter._pending_messages[session_key].text == "part two\npart three"
@pytest.mark.asyncio
async def test_debounce_resets_timer_on_new_arrival():
    """Every new follow-up replaces the debounce timer task with a fresh one.

    After three arrivals and a wait longer than the window, all three texts
    are merged into a single pending message.
    """
    adapter = _make_adapter()
    adapter._busy_text_debounce_seconds = 0.1
    opening = _make_event("one")
    session_key = build_session_key(opening.source)
    adapter._active_sessions[session_key] = asyncio.Event()

    await adapter.handle_message(opening)
    timer_one = adapter._text_debounce[session_key].task
    assert timer_one is not None
    assert not timer_one.done()

    await adapter.handle_message(_make_event("two"))
    timer_two = adapter._text_debounce[session_key].task
    assert timer_two is not None
    assert timer_two is not timer_one

    # Yield once so the superseded timer can finish cancelling.
    await asyncio.sleep(0)
    assert timer_one.cancelled() or timer_one.done()
    assert adapter._text_debounce[session_key].task is timer_two

    await adapter.handle_message(_make_event("three"))
    timer_three = adapter._text_debounce[session_key].task
    assert timer_three is not None
    assert timer_three is not timer_two

    # Wait past the window so the final timer flushes.
    await asyncio.sleep(0.2)
    assert session_key not in adapter._text_debounce
    assert adapter._pending_messages[session_key].text == "one\ntwo\nthree"
@pytest.mark.asyncio
async def test_active_drain_force_flushes_debounce_before_release():
    """A follow-up buffered during a turn is processed without waiting out the window.

    The handler injects a follow-up while handling ``"current"``. With a
    1.0 s debounce window, the follow-up must still be handled promptly and
    leave no debounce, pending, or active-session state behind.
    """
    adapter = _make_adapter()
    adapter._busy_text_debounce_seconds = 1.0
    handled: list[str] = []

    async def _handler(event):
        handled.append(event.text)
        if event.text == "current":
            # Arrives while the session is still being processed.
            await adapter.handle_message(_make_event("follow up"))
        return None

    adapter._message_handler = _handler
    current = _make_event("current")
    session_key = build_session_key(current.source)

    background = asyncio.create_task(
        adapter._process_message_background(current, session_key)
    )
    adapter._session_tasks[session_key] = background
    await asyncio.wait_for(background, timeout=1.0)

    # Poll for up to ~1 s for the drain to finish.
    expected_order = ["current", "follow up"]
    for _ in range(20):
        if handled == expected_order and session_key not in adapter._active_sessions:
            break
        await asyncio.sleep(0.05)

    assert handled == ["current", "follow up"]
    assert session_key not in adapter._text_debounce
    assert session_key not in adapter._pending_messages
    assert session_key not in adapter._active_sessions
@pytest.mark.asyncio
async def test_force_flush_cancels_timer_without_duplicate_processing():
    """A forced flush moves the text to pending once and retires the timer.

    After waiting longer than the original window, the pending text is
    unchanged, so the cancelled timer did not flush a second time.
    """
    adapter = _make_adapter()
    adapter._busy_text_debounce_seconds = 0.2
    queued = _make_event("queued once")
    session_key = build_session_key(queued.source)
    adapter._active_sessions[session_key] = asyncio.Event()

    await adapter.handle_message(queued)
    pending_timer = adapter._text_debounce[session_key].task

    did_flush = await adapter._flush_text_debounce_now(session_key)
    assert did_flush is True
    assert session_key not in adapter._text_debounce
    assert adapter._pending_messages[session_key].text == "queued once"

    # Outlast the original window; the timer must not fire again.
    await asyncio.sleep(0.3)
    assert pending_timer is not None
    assert pending_timer.cancelled() or pending_timer.done()
    assert adapter._pending_messages[session_key].text == "queued once"
@pytest.mark.asyncio
async def test_text_debounce_does_not_merge_different_senders():
    """Buffered text from one sender is not merged with another sender's text.

    Alice and Bob post in the same group thread and therefore share a
    session key. When Bob's message arrives, Alice's buffered text is
    expected in ``_pending_messages`` and Bob's in the debounce slot.
    """
    adapter = _make_adapter()
    adapter._busy_text_debounce_seconds = 1.0
    first = _make_event(
        "from alice",
        chat_type="group",
        user_id="alice",
        user_name="Alice",
        thread_id="topic-1",
    )
    second = _make_event(
        "from bob",
        chat_type="group",
        user_id="bob",
        user_name="Bob",
        thread_id="topic-1",
    )
    session_key = build_session_key(first.source)
    # Both senders must map to the same session for the test to be meaningful.
    assert session_key == build_session_key(second.source)
    adapter._active_sessions[session_key] = asyncio.Event()

    await adapter.handle_message(first)
    await adapter.handle_message(second)

    assert adapter._pending_messages[session_key].text == "from alice"
    assert _debounced_event(adapter, session_key).text == "from bob"
@pytest.mark.asyncio
async def test_control_and_clarify_messages_bypass_text_debounce():
    """Slash commands and clarify answers are never placed in the debounce slot.

    ``/status`` goes straight to session start. A clarify answer, sent while
    ``get_pending_for_session`` reports a pending clarify, is handed to the
    message handler directly and is neither buffered nor queued.
    """
    adapter = _make_adapter()
    launched: list[str] = []

    def _fake_start(event, session_key, *, interrupt_event=None):
        launched.append(event.text)
        return True

    adapter._start_session_processing = _fake_start  # type: ignore[method-assign]

    await adapter.handle_message(_make_event("/status"))
    assert launched == ["/status"]
    assert adapter._text_debounce == {}

    answer = _make_event("clarify answer")
    session_key = build_session_key(answer.source)
    adapter._active_sessions[session_key] = asyncio.Event()
    adapter._message_handler = AsyncMock(return_value=None)

    # Any truthy object stands in for a pending clarify request.
    with patch("tools.clarify_gateway.get_pending_for_session", return_value=object()):
        await adapter.handle_message(answer)

    adapter._message_handler.assert_awaited_once_with(answer)
    assert session_key not in adapter._text_debounce
    assert session_key not in adapter._pending_messages
@pytest.mark.asyncio
async def test_debounce_skipped_when_busy_text_mode_not_queue():
    """With an empty ``_busy_text_mode`` a follow-up lands in pending directly."""
    adapter = _make_adapter()
    adapter._busy_text_mode = ""
    follow_up = _make_event("direct merge")
    session_key = build_session_key(follow_up.source)
    adapter._active_sessions[session_key] = asyncio.Event()

    await adapter.handle_message(follow_up)

    # Stored as pending immediately; the debounce slot is never used.
    assert adapter._pending_messages[session_key].text == "direct merge"
    assert session_key not in adapter._text_debounce
def test_debounce_respects_env_var_override(monkeypatch):
    """``HERMES_GATEWAY_BUSY_TEXT_DEBOUNCE_SECONDS`` sets the debounce window."""
    monkeypatch.setenv("HERMES_GATEWAY_BUSY_TEXT_DEBOUNCE_SECONDS", "2.5")
    # A fully initialised adapter is needed so ``__init__`` reads the env var.
    configured = _make_initialized_adapter()
    assert configured._busy_text_debounce_seconds == 2.5
@pytest.mark.asyncio
async def test_debounce_cleanup_in_cancel_background_tasks():
    """``cancel_background_tasks`` removes any outstanding debounce entry."""
    adapter = _make_adapter()
    adapter._busy_text_debounce_seconds = 1.0
    buffered = _make_event("cleanup test")
    session_key = build_session_key(buffered.source)
    adapter._active_sessions[session_key] = asyncio.Event()

    await adapter.handle_message(buffered)
    assert session_key in adapter._text_debounce

    await adapter.cancel_background_tasks()
    assert session_key not in adapter._text_debounce
@pytest.mark.asyncio
async def test_single_followup_is_stored_as_is():
"""One TEXT follow-up still lands as the event object itself
(no spurious wrapping / mutation) guards against the merge path
breaking the simple case."""
adapter = _make_adapter()
adapter._busy_text_mode = ""
first = _make_event("only one")
session_key = build_session_key(first.source)
@ -149,4 +349,29 @@ async def test_single_followup_is_stored_as_is():
pending = adapter._pending_messages[session_key]
assert pending is first
assert pending.text == "only one"
assert adapter._active_sessions[session_key].is_set()
assert not adapter._active_sessions[session_key].is_set()
def test_adapter_defaults_to_queue_mode(monkeypatch):
    """With no env override, an initialised adapter uses ``queue`` mode."""
    monkeypatch.delenv("HERMES_GATEWAY_BUSY_TEXT_MODE", raising=False)
    default_adapter = _make_initialized_adapter()

    assert default_adapter._busy_text_mode == "queue"
    # Plain text is a debounce candidate in queue mode.
    plain_text = _make_event("hello")
    assert default_adapter._is_queue_text_debounce_candidate(plain_text)
def test_adapter_is_queue_text_debounce_candidate_by_default():
    """The hand-built test adapter treats plain text as a debounce candidate."""
    adapter = _make_adapter()
    plain_text = _make_event("hello world")
    assert adapter._is_queue_text_debounce_candidate(plain_text)
def test_command_messages_bypass_debounce_even_in_queue_mode():
    """Empty text and slash commands are not debounce candidates."""
    adapter = _make_adapter()
    for non_candidate_text in ("", "/stop"):
        event = _make_event(non_candidate_text)
        assert not adapter._is_queue_text_debounce_candidate(event)
def test_busy_text_mode_respects_env_var_override(monkeypatch):
    """``HERMES_GATEWAY_BUSY_TEXT_MODE=interrupt`` turns text debouncing off."""
    monkeypatch.setenv("HERMES_GATEWAY_BUSY_TEXT_MODE", "interrupt")
    interrupting = _make_initialized_adapter()

    assert interrupting._busy_text_mode == "interrupt"
    plain_text = _make_event("test")
    assert not interrupting._is_queue_text_debounce_candidate(plain_text)

View file

@ -14,6 +14,8 @@ Tests cover:
import asyncio
import json
import os
import stat
import time
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
@ -128,6 +130,37 @@ class TestResponseStore:
# resp_2 mapping should still be intact
assert store.get_conversation("chat-b") == "resp_2"
@pytest.mark.skipif(os.name == "nt", reason="POSIX mode bits are platform-specific")
def test_file_store_created_owner_only_under_permissive_umask(self, tmp_path):
    """response_store.db must be 0o600 on creation even under umask 022.

    The store is created and written while a permissive umask is in force;
    the previous umask is always restored and the store closed before the
    mode bits are checked.
    """
    db_path = tmp_path / "response_store.db"
    store = None
    previous_umask = os.umask(0o022)
    try:
        store = ResponseStore(max_size=10, db_path=str(db_path))
        payload = {
            "response": {"id": "resp_secret"},
            "conversation_history": [{"role": "tool", "content": "dummy-marker"}],
        }
        store.put("resp_secret", payload)
    finally:
        os.umask(previous_umask)
        if store is not None:
            store.close()

    assert stat.S_IMODE(db_path.stat().st_mode) == 0o600

    # WAL/SHM sidecars are owner-only too when present. WAL mode may be
    # unavailable on some filesystems (NFS/SMB) — only assert when the
    # sidecar files actually exist.
    for suffix in ("-wal", "-shm"):
        sidecar = db_path.with_name(db_path.name + suffix)
        if not sidecar.exists():
            continue
        assert stat.S_IMODE(sidecar.stat().st_mode) == 0o600
# ---------------------------------------------------------------------------
# _IdempotencyCache

View file

@ -27,8 +27,11 @@ class TestResolveRuntimeAgentKwargsAuthFallback:
def _mock_resolve(**kwargs):
call_count["n"] += 1
requested = kwargs.get("requested", "")
if requested and "codex" in str(requested).lower():
# First call = primary path (gateway reads model.provider from
# config.yaml internally; we simulate the auth failure here).
# Second call = fallback path with explicit_api_key + explicit_base_url
# supplied by gateway from fallback_model config.
if call_count["n"] == 1:
raise AuthError("Codex token refresh failed with status 401")
return {
"api_key": "fallback-key",
@ -40,8 +43,6 @@ class TestResolveRuntimeAgentKwargsAuthFallback:
"credential_pool": None,
}
monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "openai-codex")
with patch(
"hermes_cli.runtime_provider.resolve_runtime_provider",
side_effect=_mock_resolve,
@ -62,7 +63,6 @@ class TestResolveRuntimeAgentKwargsAuthFallback:
config_path.write_text("model:\n provider: openai-codex\n")
monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "openai-codex")
with patch(
"hermes_cli.runtime_provider.resolve_runtime_provider",
@ -71,3 +71,46 @@ class TestResolveRuntimeAgentKwargsAuthFallback:
from gateway.run import _resolve_runtime_agent_kwargs
with pytest.raises(RuntimeError):
_resolve_runtime_agent_kwargs()
def test_legacy_fallback_is_appended_after_fallback_providers(self, tmp_path, monkeypatch):
    """With both ``fallback_providers`` and ``fallback_model`` configured, the
    legacy entry is tried after the list entries and can win resolution.
    """
    yaml_text = (
        "fallback_providers:\n"
        " - provider: openrouter\n"
        " model: anthropic/claude-sonnet-4.6\n"
        "fallback_model:\n"
        " provider: nous\n"
        " model: Hermes-4\n"
    )
    (tmp_path / "config.yaml").write_text(yaml_text)
    monkeypatch.setattr("gateway.run._hermes_home", tmp_path)

    requested_order = []
    nous_runtime = {
        "api_key": "nous-key",
        "base_url": "https://portal.nousresearch.com/v1",
        "provider": "nous",
        "api_mode": "chat_completions",
        "command": None,
        "args": None,
        "credential_pool": None,
    }

    def _fake_resolve(**kwargs):
        provider = kwargs.get("requested")
        requested_order.append(provider)
        # The first list entry fails so resolution has to move on.
        if provider == "openrouter":
            raise RuntimeError("openrouter unavailable")
        # Hand back a fresh dict on every successful call.
        return dict(nous_runtime)

    with patch(
        "hermes_cli.runtime_provider.resolve_runtime_provider",
        side_effect=_fake_resolve,
    ):
        from gateway.run import _try_resolve_fallback_provider

        resolved = _try_resolve_fallback_provider()

    assert requested_order == ["openrouter", "nous"]
    assert resolved["provider"] == "nous"
    assert resolved["model"] == "Hermes-4"

View file

@ -15,6 +15,7 @@ from gateway.session import SessionSource, build_session_key
class DummyTelegramAdapter(BasePlatformAdapter):
def __init__(self):
super().__init__(PlatformConfig(enabled=True, token="fake-token"), Platform.TELEGRAM)
self._busy_text_mode = ""
self.sent = []
self.typing = []
self.processing_hooks = []

View file

@ -452,6 +452,14 @@ class TestBlueBubblesWebhookUrl:
adapter = _make_adapter(monkeypatch, password="W9fTC&L5JL*@")
assert "password=W9fTC%26L5JL%2A%40" in adapter._webhook_register_url
def test_register_url_for_log_masks_password(self, monkeypatch):
    """The log-safe register URL ends in a masked password and contains neither
    the raw password prefix nor its percent-encoded form.
    """
    redacted = _make_adapter(monkeypatch, password="W9fTC&L5JL*@")._webhook_register_url_for_log
    assert redacted.endswith("?password=***")
    # Raw prefix first, then the percent-encoded "&" from the real password.
    for leaked_fragment in ("W9fTC", "%26"):
        assert leaked_fragment not in redacted
def test_register_url_omits_query_when_no_password(self, monkeypatch):
"""If no password is configured, the register URL should be the bare URL."""
monkeypatch.delenv("BLUEBUBBLES_PASSWORD", raising=False)

View file

@ -65,6 +65,7 @@ def _make_runner():
runner._pending_messages = {}
runner._busy_ack_ts = {}
runner._draining = False
runner._busy_text_mode = "interrupt"
runner.adapters = {}
runner.config = MagicMock()
runner.session_store = None
@ -84,6 +85,8 @@ def _make_adapter(platform_val="telegram"):
adapter.config = MagicMock()
adapter.config.extra = {}
adapter.platform = MagicMock(value=platform_val)
adapter._text_debounce = {}
adapter._busy_text_debounce_seconds = 0.6
return adapter
@ -186,6 +189,32 @@ class TestBusySessionAck:
assert "respond once the current task finishes" in content
assert "Interrupting" not in content
@pytest.mark.asyncio
async def test_busy_text_mode_queue_delegates_to_adapter_handle_message(self):
    """In ``busy_text_mode=queue`` the busy handler returns False for text
    messages without interrupting the agent, queueing, or sending an ack.
    """
    runner, _sentinel = _make_runner()
    runner._busy_input_mode = "interrupt"
    runner._busy_text_mode = "queue"
    adapter = _make_adapter()

    events = [_make_event(text="part one"), _make_event(text="part two")]
    session_key = build_session_key(events[0].source)
    running_agent = MagicMock()
    runner._running_agents[session_key] = running_agent
    for event in events:
        runner.adapters[event.source.platform] = adapter

    outcomes = []
    for event in events:
        outcomes.append(
            await runner._handle_active_session_busy_message(event, session_key)
        )

    assert outcomes[0] is False
    assert outcomes[1] is False
    assert session_key not in adapter._pending_messages
    running_agent.interrupt.assert_not_called()
    adapter._send_with_retry.assert_not_called()
@pytest.mark.asyncio
async def test_steer_mode_calls_agent_steer_no_interrupt_no_queue(self):
"""busy_input_mode='steer' injects via agent.steer() and skips queueing."""

View file

@ -47,6 +47,7 @@ def _make_adapter():
"""Create a minimal adapter for testing the active-session guard."""
config = PlatformConfig(enabled=True, token="test-token")
adapter = _StubAdapter(config, Platform.TELEGRAM)
adapter._busy_text_mode = ""
adapter.sent_responses = []
async def _mock_handler(event):

View file

@ -45,6 +45,7 @@ def _run_gateway_import(hermes_home: Path, initial_env: dict[str, str]) -> dict[
"HERMES_AGENT_TIMEOUT",
"HERMES_AGENT_TIMEOUT_WARNING",
"HERMES_GATEWAY_BUSY_INPUT_MODE",
"HERMES_GATEWAY_BUSY_TEXT_MODE",
"HERMES_TIMEZONE",
):
v = os.environ.get(k)
@ -143,6 +144,15 @@ def test_config_display_busy_input_mode_wins_over_stale_env(hermes_home: Path) -
assert env.get("HERMES_GATEWAY_BUSY_INPUT_MODE") == "interrupt"
def test_config_display_busy_text_mode_wins_over_stale_env(hermes_home: Path) -> None:
    """A ``display.busy_text_mode`` config value overrides a conflicting .env entry."""
    config_value = "queue"
    stale_env = {"HERMES_GATEWAY_BUSY_TEXT_MODE": "interrupt"}
    _write_config(hermes_home, display_cfg={"busy_text_mode": config_value})
    _write_env(hermes_home, stale_env)

    resulting_env = _run_gateway_import(hermes_home, initial_env={})

    assert resulting_env.get("HERMES_GATEWAY_BUSY_TEXT_MODE") == "queue"
def test_config_timezone_wins_over_stale_env(hermes_home: Path) -> None:
_write_config(hermes_home, timezone="America/Los_Angeles")
_write_env(hermes_home, {"HERMES_TIMEZONE": "UTC"})

View file

@ -407,6 +407,36 @@ class TestConnect:
assert len(adapter._dedup._seen) == 0
assert adapter._http_client is None
@pytest.mark.asyncio
async def test_disconnect_finalizes_open_streaming_cards(self):
    """disconnect() finalizes every open streaming card while the HTTP client
    is still set, and only afterwards clears the client.
    """
    from unittest.mock import AsyncMock, patch
    from gateway.platforms.dingtalk import DingTalkAdapter

    adapter = DingTalkAdapter(PlatformConfig(enabled=True))
    adapter._http_client = AsyncMock()
    adapter._stream_task = None
    adapter._streaming_cards = {
        "chat-1": {"track-a": "last content"},
        "chat-2": {"track-b": "other"},
    }
    finalized_chats = []

    async def _record_finalization(chat_id):
        # The HTTP client must still be alive while cards are being finalized.
        assert adapter._http_client is not None, (
            "HTTP client was already closed before card finalization"
        )
        finalized_chats.append(chat_id)
        adapter._streaming_cards.pop(chat_id, None)

    sibling_patch = patch.object(
        adapter, "_close_streaming_siblings", side_effect=_record_finalization
    )
    with sibling_patch:
        await adapter.disconnect()

    assert set(finalized_chats) == {"chat-1", "chat-2"}
    assert adapter._streaming_cards == {}
    assert adapter._http_client is None
# ---------------------------------------------------------------------------
# Platform enum

View file

@ -81,7 +81,7 @@ def _ensure_discord_mock():
_ensure_discord_mock()
from gateway.platforms.discord import _build_allowed_mentions # noqa: E402
from plugins.platforms.discord.adapter import _build_allowed_mentions # noqa: E402
# The four DISCORD_ALLOW_MENTION_* env vars that _build_allowed_mentions reads.

View file

@ -58,7 +58,7 @@ def _ensure_discord_mock():
_ensure_discord_mock()
from gateway.platforms.discord import DiscordAdapter # noqa: E402
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
from gateway.platforms.base import MessageType # noqa: E402
@ -146,10 +146,10 @@ class TestCacheDiscordImage:
att = _make_attachment_with_read(_PNG_BYTES)
with patch(
"gateway.platforms.discord.cache_image_from_bytes",
"plugins.platforms.discord.adapter.cache_image_from_bytes",
return_value="/tmp/cached.png",
) as mock_bytes, patch(
"gateway.platforms.discord.cache_image_from_url",
"plugins.platforms.discord.adapter.cache_image_from_url",
new_callable=AsyncMock,
) as mock_url:
result = await adapter._cache_discord_image(att, ".png")
@ -165,9 +165,9 @@ class TestCacheDiscordImage:
att = _make_attachment_without_read()
with patch(
"gateway.platforms.discord.cache_image_from_bytes",
"plugins.platforms.discord.adapter.cache_image_from_bytes",
) as mock_bytes, patch(
"gateway.platforms.discord.cache_image_from_url",
"plugins.platforms.discord.adapter.cache_image_from_url",
new_callable=AsyncMock,
return_value="/tmp/from_url.png",
) as mock_url:
@ -186,10 +186,10 @@ class TestCacheDiscordImage:
att = _make_attachment_with_read(b"<html>forbidden</html>")
with patch(
"gateway.platforms.discord.cache_image_from_bytes",
"plugins.platforms.discord.adapter.cache_image_from_bytes",
side_effect=ValueError("not a valid image"),
), patch(
"gateway.platforms.discord.cache_image_from_url",
"plugins.platforms.discord.adapter.cache_image_from_url",
new_callable=AsyncMock,
return_value="/tmp/fallback.png",
) as mock_url:
@ -210,10 +210,10 @@ class TestCacheDiscordAudio:
att = _make_attachment_with_read(_OGG_BYTES)
with patch(
"gateway.platforms.discord.cache_audio_from_bytes",
"plugins.platforms.discord.adapter.cache_audio_from_bytes",
return_value="/tmp/voice.ogg",
) as mock_bytes, patch(
"gateway.platforms.discord.cache_audio_from_url",
"plugins.platforms.discord.adapter.cache_audio_from_url",
new_callable=AsyncMock,
) as mock_url:
result = await adapter._cache_discord_audio(att, ".ogg")
@ -228,7 +228,7 @@ class TestCacheDiscordAudio:
att = _make_attachment_without_read()
with patch(
"gateway.platforms.discord.cache_audio_from_url",
"plugins.platforms.discord.adapter.cache_audio_from_url",
new_callable=AsyncMock,
return_value="/tmp/from_url.ogg",
) as mock_url:
@ -267,7 +267,7 @@ class TestCacheDiscordDocument:
att = _make_attachment_without_read() # no .read → forces fallback
with patch(
"gateway.platforms.discord.is_safe_url", return_value=False
"plugins.platforms.discord.adapter.is_safe_url", return_value=False
) as mock_safe, patch("aiohttp.ClientSession") as mock_session:
with pytest.raises(ValueError, match="SSRF"):
await adapter._cache_discord_document(att, ".pdf")
@ -295,7 +295,7 @@ class TestCacheDiscordDocument:
session.__aexit__ = AsyncMock(return_value=False)
with patch(
"gateway.platforms.discord.is_safe_url", return_value=True
"plugins.platforms.discord.adapter.is_safe_url", return_value=True
), patch("aiohttp.ClientSession", return_value=session):
result = await adapter._cache_discord_document(att, ".pdf")
@ -320,10 +320,10 @@ class TestHandleMessageUsesAuthenticatedRead:
adapter.handle_message = AsyncMock()
with patch(
"gateway.platforms.discord.cache_image_from_bytes",
"plugins.platforms.discord.adapter.cache_image_from_bytes",
return_value="/tmp/img_from_read.png",
), patch(
"gateway.platforms.discord.cache_image_from_url",
"plugins.platforms.discord.adapter.cache_image_from_url",
new_callable=AsyncMock,
) as mock_url_download:
att = SimpleNamespace(
@ -342,7 +342,7 @@ class TestHandleMessageUsesAuthenticatedRead:
# Patch the DMChannel isinstance check so our fake counts as DM.
monkeypatch.setattr(
"gateway.platforms.discord.discord.DMChannel",
"plugins.platforms.discord.adapter.discord.DMChannel",
_FakeDMChannel,
)
chan = _FakeDMChannel()
@ -368,7 +368,7 @@ class TestHandleMessageUsesAuthenticatedRead:
adapter.handle_message = AsyncMock()
with patch(
"gateway.platforms.discord.cache_audio_from_bytes",
"plugins.platforms.discord.adapter.cache_audio_from_bytes",
return_value="/tmp/voice_from_read.ogg",
):
att = SimpleNamespace(
@ -386,7 +386,7 @@ class TestHandleMessageUsesAuthenticatedRead:
name = "dm"
monkeypatch.setattr(
"gateway.platforms.discord.discord.DMChannel",
"plugins.platforms.discord.adapter.discord.DMChannel",
_FakeDMChannel,
)
chan = _FakeDMChannel()
@ -412,7 +412,7 @@ class TestHandleMessageUsesAuthenticatedRead:
adapter.handle_message = AsyncMock()
with patch(
"gateway.platforms.discord.cache_audio_from_bytes",
"plugins.platforms.discord.adapter.cache_audio_from_bytes",
return_value="/tmp/audio_from_read.ogg",
):
att = SimpleNamespace(
@ -430,7 +430,7 @@ class TestHandleMessageUsesAuthenticatedRead:
name = "dm"
monkeypatch.setattr(
"gateway.platforms.discord.discord.DMChannel",
"plugins.platforms.discord.adapter.discord.DMChannel",
_FakeDMChannel,
)
chan = _FakeDMChannel()

View file

@ -172,42 +172,49 @@ def test_bot_bypass_does_not_leak_to_other_platforms(monkeypatch):
# -----------------------------------------------------------------------------
# DISCORD_ALLOWED_ROLES gateway-layer bypass (#7871)
# DISCORD_ALLOWED_ROLES no longer bypasses the gateway allowlist (#30742)
#
# Prior behavior: setting DISCORD_ALLOWED_ROLES caused _is_user_authorized
# to return True for ANY Discord event, on the assumption that the adapter
# pre-filter had already validated role membership. That allowed slash
# commands and synthetic voice events to bypass role checks. PR #30742
# removed the shortcut — Discord auth now flows through the same allowlist
# / pairing / allow-all path as every other platform.
# -----------------------------------------------------------------------------
def test_discord_role_config_bypasses_gateway_allowlist(monkeypatch):
"""When DISCORD_ALLOWED_ROLES is set, _is_user_authorized must trust
the adapter's pre-filter and authorize. Without this, role-only setups
(DISCORD_ALLOWED_ROLES populated, DISCORD_ALLOWED_USERS empty) would
hit the 'no allowlists configured' branch and get rejected.
def test_discord_role_config_does_not_bypass_gateway_allowlist(monkeypatch):
"""DISCORD_ALLOWED_ROLES alone must NOT authorize at the gateway layer
(regression guard for #30742). Role-based access is enforced by the
adapter pre-filter on real message events; the gateway layer requires
an explicit allowlist hit or pairing approval.
"""
runner = _make_bare_runner()
monkeypatch.setenv("DISCORD_ALLOWED_ROLES", "1493705176387948674")
# Note: DISCORD_ALLOWED_USERS is NOT set — the entire point.
# DISCORD_ALLOWED_USERS deliberately NOT set — verifies the role
# config alone no longer grants authorization.
source = _make_discord_human_source(user_id="999888777")
assert runner._is_user_authorized(source) is True
assert runner._is_user_authorized(source) is False
def test_discord_role_config_still_authorizes_alongside_users(monkeypatch):
"""Sanity: setting both DISCORD_ALLOWED_ROLES and DISCORD_ALLOWED_USERS
doesn't break the user-id path. Users in the allowlist should still be
authorized even if they don't have a role. (OR semantics.)
def test_discord_user_allowlist_still_authorizes_when_role_is_also_configured(monkeypatch):
"""Sanity: DISCORD_ALLOWED_USERS still authorizes users on the list,
independent of DISCORD_ALLOWED_ROLES. This guards against a future
regression that ties the user-allowlist check to the (now-removed)
role bypass.
"""
runner = _make_bare_runner()
monkeypatch.setenv("DISCORD_ALLOWED_ROLES", "1493705176387948674")
monkeypatch.setenv("DISCORD_ALLOWED_USERS", "100200300")
# User on the user allowlist, no role → still authorized at gateway
# level via the role bypass (adapter already approved them).
source = _make_discord_human_source(user_id="100200300")
assert runner._is_user_authorized(source) is True
def test_discord_role_bypass_does_not_leak_to_other_platforms(monkeypatch):
def test_discord_role_config_does_not_leak_to_other_platforms(monkeypatch):
"""DISCORD_ALLOWED_ROLES must only affect Discord. Setting it should
not suddenly start authorizing Telegram users whose platform has its
own empty allowlist.

View file

@ -45,8 +45,8 @@ def _ensure_discord_mock():
_ensure_discord_mock()
import gateway.platforms.discord as discord_platform # noqa: E402
from gateway.platforms.discord import DiscordAdapter # noqa: E402
import plugins.platforms.discord.adapter as discord_platform # noqa: E402
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
class FakeDMChannel:

View file

@ -58,7 +58,7 @@ def _install_fake_agent(monkeypatch):
def _make_adapter():
_ensure_discord_mock()
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
adapter = object.__new__(DiscordAdapter)
adapter.config = MagicMock()

View file

@ -5,7 +5,7 @@ import pytest
def _make_adapter():
"""Create a minimal DiscordAdapter with mocked config."""
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
adapter = object.__new__(DiscordAdapter)
adapter.config = MagicMock()
adapter.config.extra = {}

View file

@ -26,7 +26,7 @@ if _repo not in sys.path:
# Triggers the shared discord mock from tests/gateway/conftest.py before
# importing the production module.
from gateway.platforms.discord import ( # noqa: E402
from plugins.platforms.discord.adapter import ( # noqa: E402
ClarifyChoiceView,
DiscordAdapter,
)

View file

@ -18,7 +18,7 @@ import pytest
# Trigger the shared discord mock from tests/gateway/conftest.py before
# importing the production module.
from gateway.platforms.discord import ( # noqa: E402
from plugins.platforms.discord.adapter import ( # noqa: E402
ExecApprovalView,
ModelPickerView,
SlashConfirmView,

View file

@ -67,8 +67,8 @@ def _ensure_discord_mock():
_ensure_discord_mock()
import gateway.platforms.discord as discord_platform # noqa: E402
from gateway.platforms.discord import DiscordAdapter # noqa: E402
import plugins.platforms.discord.adapter as discord_platform # noqa: E402
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
@pytest.fixture(autouse=True)

View file

@ -57,8 +57,8 @@ def _ensure_discord_mock():
_ensure_discord_mock()
import gateway.platforms.discord as discord_platform # noqa: E402
from gateway.platforms.discord import DiscordAdapter # noqa: E402
import plugins.platforms.discord.adapter as discord_platform # noqa: E402
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
# ---------------------------------------------------------------------------
@ -371,7 +371,7 @@ class TestIncomingDocumentHandling:
async def test_image_attachment_unaffected(self, adapter):
"""Image attachments should still go through the image path, not the document path."""
with patch(
"gateway.platforms.discord.cache_image_from_url",
"plugins.platforms.discord.adapter.cache_image_from_url",
new_callable=AsyncMock,
return_value="/tmp/cached_image.png",
):

View file

@ -45,8 +45,8 @@ def _ensure_discord_mock():
_ensure_discord_mock()
import gateway.platforms.discord as discord_platform # noqa: E402
from gateway.platforms.discord import DiscordAdapter # noqa: E402
import plugins.platforms.discord.adapter as discord_platform # noqa: E402
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
class FakeDMChannel:

View file

@ -14,10 +14,13 @@ class TestDiscordImportSafety:
raise ImportError("discord unavailable for test")
return original_import(name, globals, locals, fromlist, level)
monkeypatch.delitem(sys.modules, "gateway.platforms.discord", raising=False)
# Purge the cached module so the import below actually re-runs the
# module body with discord.py simulated-missing.
monkeypatch.delitem(sys.modules, "plugins.platforms.discord.adapter", raising=False)
monkeypatch.delitem(sys.modules, "plugins.platforms.discord", raising=False)
monkeypatch.setattr(builtins, "__import__", fake_import)
module = importlib.import_module("gateway.platforms.discord")
module = importlib.import_module("plugins.platforms.discord.adapter")
assert module.DISCORD_AVAILABLE is False
assert module.discord is None

View file

@ -34,7 +34,7 @@ class TestDefineDiscordViewClasses:
def test_registers_all_five_view_classes(self, monkeypatch):
"""Calling _define_discord_view_classes() must (re)define all 5 view classes."""
dp = importlib.import_module("gateway.platforms.discord")
dp = importlib.import_module("plugins.platforms.discord.adapter")
# Remove the classes to simulate the state where the module was loaded
# with DISCORD_AVAILABLE=False (the lazy-install scenario).
@ -54,7 +54,7 @@ class TestDefineDiscordViewClasses:
def test_check_discord_requirements_calls_define_on_lazy_install(self, monkeypatch):
"""check_discord_requirements() must call _define_discord_view_classes() on
a successful lazy install so view classes exist when DISCORD_AVAILABLE=True."""
dp = importlib.import_module("gateway.platforms.discord")
dp = importlib.import_module("plugins.platforms.discord.adapter")
# Simulate discord not yet available at module load.
monkeypatch.setattr(dp, "DISCORD_AVAILABLE", False)

View file

@ -1,6 +1,6 @@
import inspect
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
def test_discord_media_methods_accept_metadata_kwarg():

View file

@ -11,7 +11,7 @@ from unittest.mock import AsyncMock
import pytest
from gateway.platforms.discord import ModelPickerView
from plugins.platforms.discord.adapter import ModelPickerView
@pytest.mark.asyncio

View file

@ -8,14 +8,14 @@ class TestOpusFindLibrary:
def test_uses_find_library_first(self):
"""find_library must be the primary lookup strategy."""
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
source = inspect.getsource(DiscordAdapter.connect)
assert "find_library" in source, \
"Opus loading must use ctypes.util.find_library"
def test_homebrew_fallback_is_conditional(self):
"""Homebrew paths must only be tried when find_library returns None."""
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
source = inspect.getsource(DiscordAdapter.connect)
# Homebrew fallback must exist
assert "/opt/homebrew" in source or "homebrew" in source, \
@ -31,7 +31,7 @@ class TestOpusFindLibrary:
def test_opus_decode_error_logged(self):
"""Opus decode failure must log the error, not silently return."""
from gateway.platforms.discord import VoiceReceiver
from plugins.platforms.discord.adapter import VoiceReceiver
source = inspect.getsource(VoiceReceiver._on_packet)
assert "logger" in source, \
"_on_packet must log Opus decode errors"

View file

@ -10,7 +10,7 @@ from gateway.config import Platform, PlatformConfig
def _make_adapter():
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
adapter = object.__new__(DiscordAdapter)
adapter._platform = Platform.DISCORD
@ -60,7 +60,7 @@ async def test_concurrent_joins_do_not_double_connect():
channel.guild.id = 42
channel.connect = lambda: slow_connect(channel)
from gateway.platforms import discord as discord_mod
from plugins.platforms.discord import adapter as discord_mod
with patch.object(discord_mod, "VoiceReceiver",
MagicMock(return_value=MagicMock(start=lambda: None))):
with patch.object(discord_mod.asyncio, "ensure_future",

View file

@ -40,7 +40,7 @@ def _ensure_discord_mock():
_ensure_discord_mock()
from gateway.platforms.discord import DiscordAdapter # noqa: E402
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
class FakeTree:

View file

@ -53,7 +53,7 @@ def _ensure_discord_mock():
_ensure_discord_mock()
from gateway.platforms.discord import DiscordAdapter # noqa: E402
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
@pytest.fixture()

View file

@ -20,7 +20,7 @@ from unittest.mock import MagicMock
import pytest
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
def _set_dm_role_auth_guild(monkeypatch, guild_id=None):

View file

@ -42,7 +42,7 @@ def _ensure_discord_mock():
_ensure_discord_mock()
from gateway.platforms.discord import DiscordAdapter # noqa: E402
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
@pytest.mark.asyncio

View file

@ -85,7 +85,7 @@ def _ensure_discord_mock():
_ensure_discord_mock()
from gateway.platforms.discord import DiscordAdapter # noqa: E402
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
@pytest.fixture(autouse=True)

View file

@ -75,7 +75,7 @@ def _ensure_discord_mock():
_ensure_discord_mock()
from gateway.platforms.discord import DiscordAdapter # noqa: E402
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
class FakeTree:

View file

@ -17,7 +17,7 @@ class TestDiscordThreadPersistence:
def _make_adapter(self, tmp_path):
"""Build a minimal DiscordAdapter with HERMES_HOME pointed at tmp_path."""
from gateway.config import PlatformConfig
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
config = PlatformConfig(enabled=True, token="test-token")
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):

View file

@ -148,6 +148,15 @@ async def test_run_agent_passes_priority_processing_to_gateway_agent(monkeypatch
monkeypatch.setattr(gateway_run, "_env_path", tmp_path / ".env")
monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {})
# ``_load_service_tier`` was refactored to call ``_load_gateway_runtime_config``
# (which wraps ``_load_gateway_config`` plus env-expansion). Since the test
# stubs ``_load_gateway_config`` to ``{}``, also stub the runtime wrapper
# directly so the priority routing assertions still exercise the live tier.
monkeypatch.setattr(
gateway_run,
"_load_gateway_runtime_config",
lambda: {"agent": {"service_tier": "fast"}},
)
monkeypatch.setattr(gateway_run, "_resolve_gateway_model", lambda config=None: "gpt-5.4")
monkeypatch.setattr(
gateway_run,

View file

@ -167,6 +167,7 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
"FEISHU_WEBHOOK_HOST": "127.0.0.1",
"FEISHU_WEBHOOK_PORT": "9001",
"FEISHU_WEBHOOK_PATH": "/hook",
"FEISHU_VERIFICATION_TOKEN": "vtok",
}, clear=True)
def test_connect_webhook_mode_starts_local_server(self):
from gateway.config import PlatformConfig
@ -1538,6 +1539,34 @@ class TestAdapterBehavior(unittest.TestCase):
self.assertEqual(response.status, 200)
adapter._on_message_event.assert_called_once()
@patch.dict(os.environ, {"FEISHU_VERIFICATION_TOKEN": "expected-token"}, clear=True)
def test_url_verification_requires_configured_verification_token(self):
    """A url_verification request carrying the wrong token gets a 401.

    Regression guard: the challenge used to be echoed back before the token
    was checked, so an unauthenticated remote could prove endpoint control
    with an attacker-chosen challenge string.
    """
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    adapter = FeishuAdapter(PlatformConfig())
    forged_payload = {
        "type": "url_verification",
        "token": "wrong-token",
        "challenge": "attacker-controlled-challenge",
    }
    raw_body = json.dumps(forged_payload).encode("utf-8")
    request = SimpleNamespace(
        remote="203.0.113.10",
        content_length=None,
        headers={},
        read=AsyncMock(return_value=raw_body),
    )

    response = asyncio.run(adapter._handle_webhook_request(request))

    self.assertEqual(response.status, 401)
@patch.dict(os.environ, {}, clear=True)
def test_process_inbound_message_uses_event_sender_identity_only(self):
from gateway.config import PlatformConfig
@ -3191,6 +3220,39 @@ class TestWebhookSecurity(unittest.TestCase):
response = asyncio.run(adapter._handle_webhook_request(request))
self.assertEqual(response.status, 401)
@patch.dict(os.environ, {}, clear=True)
def test_webhook_connect_requires_inbound_auth_secret(self):
    """connect() reports failure in webhook mode when the environment is empty
    and ``extra`` carries no verification_token or encrypt_key.
    """
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    webhook_extra = {
        "app_id": "cli_app",
        "app_secret": "secret_app",
        "connection_mode": "webhook",
    }
    adapter = FeishuAdapter(PlatformConfig(enabled=True, extra=webhook_extra))

    connected = asyncio.run(adapter.connect())

    self.assertFalse(connected)
@patch.dict(os.environ, {}, clear=True)
def test_webhook_loads_auth_secrets_from_platform_extra(self):
    """With an empty environment, the adapter takes its verification token and
    encrypt key from ``PlatformConfig.extra``.
    """
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    webhook_extra = {
        "app_id": "cli_app",
        "app_secret": "secret_app",
        "connection_mode": "webhook",
        "verification_token": "token_from_extra",
        "encrypt_key": "encrypt_from_extra",
    }
    platform_config = PlatformConfig(enabled=True, extra=webhook_extra)
    adapter = FeishuAdapter(platform_config)

    self.assertEqual(adapter._verification_token, "token_from_extra")
    self.assertEqual(adapter._encrypt_key, "encrypt_from_extra")
@patch.dict(os.environ, {}, clear=True)
def test_webhook_url_verification_challenge_passes_without_signature(self):
"""Challenge requests must succeed even when no encrypt_key is set."""

View file

@ -320,7 +320,7 @@ class TestResolveApproval:
}
with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
await adapter._resolve_approval(1, "once", "Norbert")
await adapter._resolve_approval(1, "once", "Norbert", open_id="ou_user1", chat_id="oc_12345")
mock_resolve.assert_called_once_with("agent:main:feishu:group:oc_12345", "once")
assert 1 not in adapter._approval_state
@ -335,7 +335,7 @@ class TestResolveApproval:
}
with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
await adapter._resolve_approval(2, "deny", "Alice")
await adapter._resolve_approval(2, "deny", "Alice", open_id="ou_user1", chat_id="oc_12345")
mock_resolve.assert_called_once_with("some-session", "deny")
@ -349,7 +349,7 @@ class TestResolveApproval:
}
with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
await adapter._resolve_approval(3, "session", "Bob")
await adapter._resolve_approval(3, "session", "Bob", open_id="ou_user1", chat_id="oc_99")
mock_resolve.assert_called_once_with("sess-3", "session")
@ -363,7 +363,7 @@ class TestResolveApproval:
}
with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
await adapter._resolve_approval(4, "always", "Carol")
await adapter._resolve_approval(4, "always", "Carol", open_id="ou_user1", chat_id="oc_55")
mock_resolve.assert_called_once_with("sess-4", "always")
@ -372,10 +372,41 @@ class TestResolveApproval:
adapter = _make_adapter()
with patch("tools.approval.resolve_gateway_approval") as mock_resolve:
await adapter._resolve_approval(99, "once", "Nobody")
await adapter._resolve_approval(99, "once", "Nobody", open_id="ou_user1", chat_id="oc_12345")
mock_resolve.assert_not_called()
@pytest.mark.asyncio
async def test_unauthorized_click_does_not_resolve(self):
    """A click from an open_id outside the configured admins is not resolved
    and the approval entry stays pending.
    """
    approval_id = 5
    adapter = _make_adapter()
    adapter._admins = {"ou_admin"}
    adapter._approval_state[approval_id] = {
        "session_key": "sess-5",
        "message_id": "msg_005",
        "chat_id": "oc_12345",
    }

    with patch("tools.approval.resolve_gateway_approval") as mock_resolve:
        await adapter._resolve_approval(
            approval_id,
            "once",
            "Mallory",
            open_id="ou_intruder",
            chat_id="oc_12345",
        )

    mock_resolve.assert_not_called()
    assert approval_id in adapter._approval_state
@pytest.mark.asyncio
async def test_chat_mismatch_does_not_resolve(self):
    """A click whose chat_id differs from the stored approval's chat_id is not
    resolved and the approval entry stays pending.
    """
    approval_id = 6
    adapter = _make_adapter()
    adapter._approval_state[approval_id] = {
        "session_key": "sess-6",
        "message_id": "msg_006",
        "chat_id": "oc_expected",
    }

    with patch("tools.approval.resolve_gateway_approval") as mock_resolve:
        await adapter._resolve_approval(
            approval_id,
            "session",
            "Norbert",
            open_id="ou_user1",
            chat_id="oc_wrong",
        )

    mock_resolve.assert_not_called()
    assert approval_id in adapter._approval_state
# ===========================================================================
# _handle_card_action_event — non-approval card actions
# ===========================================================================
@ -448,6 +479,12 @@ class TestCardActionCallbackResponse:
adapter = _make_adapter()
adapter._loop = MagicMock()
adapter._loop.is_closed = MagicMock(return_value=False)
adapter._allowed_group_users = {"ou_bob"}
adapter._approval_state[1] = {
"session_key": "sess-1",
"message_id": "msg-1",
"chat_id": "oc_12345",
}
data = _make_card_action_data(
{"hermes_action": "approve_once", "approval_id": 1},
open_id="ou_bob",
@ -469,6 +506,12 @@ class TestCardActionCallbackResponse:
adapter = _make_adapter()
adapter._loop = MagicMock()
adapter._loop.is_closed = MagicMock(return_value=False)
adapter._allowed_group_users = {"ou_user1"}
adapter._approval_state[2] = {
"session_key": "sess-2",
"message_id": "msg-2",
"chat_id": "oc_12345",
}
data = _make_card_action_data(
{"hermes_action": "deny", "approval_id": 2},
)
@ -510,6 +553,12 @@ class TestCardActionCallbackResponse:
adapter = _make_adapter()
adapter._loop = MagicMock()
adapter._loop.is_closed = MagicMock(return_value=False)
adapter._allowed_group_users = {"ou_unknown"}
adapter._approval_state[3] = {
"session_key": "sess-3",
"message_id": "msg-3",
"chat_id": "oc_12345",
}
data = _make_card_action_data(
{"hermes_action": "approve_session", "approval_id": 3},
open_id="ou_unknown",
@ -525,6 +574,12 @@ class TestCardActionCallbackResponse:
adapter = _make_adapter()
adapter._loop = MagicMock()
adapter._loop.is_closed = MagicMock(return_value=False)
adapter._allowed_group_users = {"ou_expired"}
adapter._approval_state[4] = {
"session_key": "sess-4",
"message_id": "msg-4",
"chat_id": "oc_12345",
}
data = _make_card_action_data(
{"hermes_action": "approve_once", "approval_id": 4},
open_id="ou_expired",
@ -538,6 +593,51 @@ class TestCardActionCallbackResponse:
assert "Old Name" not in card["elements"][0]["content"]
assert "ou_expired" in card["elements"][0]["content"]
def test_rejects_approval_click_from_unauthorized_user(self, _patch_callback_card_types):
    """The card callback rejects a clicker who is not in the allowed group users.

    A response object is still returned, but it carries no card and no
    coroutine is submitted to the adapter loop.
    """
    fake_loop = MagicMock()
    fake_loop.is_closed = MagicMock(return_value=False)

    adapter = _make_adapter()
    adapter._loop = fake_loop
    adapter._allowed_group_users = {"ou_allowed"}
    adapter._approval_state[5] = dict(
        session_key="sess-5",
        message_id="msg-5",
        chat_id="oc_12345",
    )

    action_value = {"hermes_action": "approve_once", "approval_id": 5}
    data = _make_card_action_data(action_value, open_id="ou_attacker")

    with patch("asyncio.run_coroutine_threadsafe") as submit:
        response = adapter._on_card_action_trigger(data)

    assert response is not None
    assert response.card is None
    submit.assert_not_called()
def test_rejects_approval_click_when_callback_chat_mismatches(self, _patch_callback_card_types):
    """The card callback rejects a click whose chat differs from the stored chat.

    The clicker is in the allowed group users, but the callback reports
    ``oc_mismatch`` while the approval was stored for ``oc_expected``.
    """
    fake_loop = MagicMock()
    fake_loop.is_closed = MagicMock(return_value=False)

    adapter = _make_adapter()
    adapter._loop = fake_loop
    adapter._allowed_group_users = {"ou_bob"}
    adapter._approval_state[6] = dict(
        session_key="sess-6",
        message_id="msg-6",
        chat_id="oc_expected",
    )

    action_value = {"hermes_action": "approve_once", "approval_id": 6}
    data = _make_card_action_data(
        action_value,
        chat_id="oc_mismatch",
        open_id="ou_bob",
    )

    with patch("asyncio.run_coroutine_threadsafe") as submit:
        response = adapter._on_card_action_trigger(data)

    assert response is not None
    assert response.card is None
    submit.assert_not_called()
def test_returns_card_for_update_prompt_yes(self, _patch_callback_card_types):
adapter = _make_adapter()
adapter._loop = MagicMock()

View file

@ -103,6 +103,7 @@ class TestInterruptKeyConsistency:
async def test_handle_message_stores_under_session_key(self):
"""handle_message stores pending messages under session_key, not chat_id."""
adapter = StubAdapter()
adapter._busy_text_mode = ""
adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))
source = _source("-1001234", "group")
@ -120,8 +121,8 @@ class TestInterruptKeyConsistency:
# NOT stored under chat_id
assert source.chat_id not in adapter._pending_messages
# Interrupt event was set
assert adapter._active_sessions[session_key].is_set()
# Text follow-ups queue silently and do not interrupt the active turn.
assert adapter._active_sessions[session_key].is_set() is False
@pytest.mark.asyncio
async def test_photo_followup_is_queued_without_interrupt(self):

View file

@ -0,0 +1,210 @@
"""Tests for the gateway loop-level transient-network-error safety net.
Issues #31066 / #31110: unhandled ``telegram.error.TimedOut`` (or peer
``NetworkError`` / ``httpx`` connection error) propagating to the
asyncio event loop killed the gateway process, taking down every
profile attached to the same runner. The safety net installed in
:func:`gateway.run.start_gateway` catches the transient crash class
and logs+swallows it; non-transient errors still surface.
These tests pin the classifier and the loop handler so the safety net
can't silently regress to swallowing every exception.
"""
from __future__ import annotations
import asyncio
import logging
import pytest
from gateway.run import (
_gateway_loop_exception_handler,
_is_transient_network_error,
)
# ----- Fake exception classes that mimic the real wire types ----------
# We avoid importing telegram / httpx here so the test runs in environments
# without those packages installed (the classifier matches on class name).
class TimedOut(Exception):
    """Local double for ``telegram.error.TimedOut`` (only the class name matters)."""
class NetworkError(Exception):
    """Local double for ``telegram.error.NetworkError`` (only the class name matters)."""
class ConnectError(Exception):
    """Local double for ``httpx.ConnectError`` (only the class name matters)."""
class ReadTimeout(Exception):
    """Local double for ``httpx.ReadTimeout`` (only the class name matters)."""
class PoolTimeout(Exception):
    """Local double for ``httpx.PoolTimeout`` (only the class name matters)."""
class ClientConnectorError(Exception):
    """Local double for ``aiohttp.ClientConnectorError`` (only the class name matters)."""
class SomeUnrelatedBug(Exception):
    """Represents a genuine, non-transient failure that must NOT be swallowed."""
# ---------------------------------------------------------------------
# Classifier
# ---------------------------------------------------------------------
@pytest.mark.parametrize(
    "exc_cls",
    [
        TimedOut,
        NetworkError,
        ConnectError,
        ReadTimeout,
        PoolTimeout,
        ClientConnectorError,
    ],
)
def test_transient_classifier_matches_known_network_errors(exc_cls):
    """Each stand-in for a well-known network exception is classified as transient."""
    error = exc_cls("boom")
    verdict = _is_transient_network_error(error)
    assert verdict is True
def test_transient_classifier_rejects_unrelated_errors():
    """Ordinary bugs (ValueError, KeyError, custom app errors) are not transient."""
    real_bugs = [ValueError("bad"), KeyError("missing"), SomeUnrelatedBug("x")]
    for bug in real_bugs:
        verdict = _is_transient_network_error(bug)
        assert verdict is False
def test_transient_classifier_unwraps_cause_chain():
    """An error whose explicit ``__cause__`` is a ConnectError is still transient."""
    wrapper = NetworkError("upstream failed")
    wrapper.__cause__ = ConnectError("connection refused")

    assert _is_transient_network_error(wrapper) is True
def test_transient_classifier_unwraps_context_chain():
    """An implicitly chained ``__context__`` is followed by the classifier."""
    with pytest.raises(SomeUnrelatedBug) as caught:
        try:
            raise TimedOut("upstream timeout")
        except TimedOut:
            # Raising inside the handler records the TimedOut as implicit context.
            raise SomeUnrelatedBug("wrapper")

    # The outer class name is not transient; only its chained context is.
    assert _is_transient_network_error(caught.value) is True
def test_transient_classifier_does_not_infinite_loop_on_cyclic_cause():
    """A self-referential ``__cause__`` chain must not hang the classifier."""
    looped = SomeUnrelatedBug("loop")
    looped.__cause__ = looped  # the exception is its own cause

    # Returning at all proves termination; the verdict must be non-transient.
    verdict = _is_transient_network_error(looped)
    assert verdict is False
# ---------------------------------------------------------------------
# Loop handler
# ---------------------------------------------------------------------
def test_handler_swallows_transient_error_and_logs_warning(caplog):
    """A transient error passed to the handler is logged and does not raise."""
    loop = asyncio.new_event_loop()
    context = {
        "message": "Task exception was never retrieved",
        "exception": TimedOut("Timed out"),
    }
    try:
        with caplog.at_level(logging.WARNING, logger="gateway.run"):
            _gateway_loop_exception_handler(loop, context)
        # The exception class name must show up in at least one captured record.
        assert any("TimedOut" in record.message for record in caplog.records)
    finally:
        loop.close()
def test_handler_delegates_unknown_errors_to_default(monkeypatch):
    """A non-transient error is handed to ``loop.default_exception_handler``."""
    loop = asyncio.new_event_loop()
    try:
        forwarded: list[dict] = []
        # Capture whatever the handler passes on to the loop's default handler.
        monkeypatch.setattr(loop, "default_exception_handler", forwarded.append)

        context = {
            "message": "Something else broke",
            "exception": SomeUnrelatedBug("real bug"),
        }
        _gateway_loop_exception_handler(loop, context)

        assert forwarded == [context]
    finally:
        loop.close()
def test_handler_tolerates_missing_exception_key(monkeypatch):
    """A context with no ``exception`` entry is passed to the default handler."""
    loop = asyncio.new_event_loop()
    try:
        forwarded: list[dict] = []

        def capture(context):
            forwarded.append(context)

        monkeypatch.setattr(loop, "default_exception_handler", capture)

        message_only = {"message": "warning without exception"}
        _gateway_loop_exception_handler(loop, message_only)

        assert forwarded == [message_only]
    finally:
        loop.close()
# ---------------------------------------------------------------------
# End-to-end: task-level
# ---------------------------------------------------------------------
def test_unhandled_transient_error_in_task_does_not_propagate_to_loop():
    """Smoke test the wiring as a loop would actually use it.

    Schedules a task that raises ``TimedOut`` and is never awaited, with the
    gateway handler installed on the running loop. When the finished task is
    garbage-collected, asyncio reports the never-retrieved exception to the
    loop's exception handler.

    Unlike a bare ``asyncio.run`` (which would pass even without the safety
    net, because asyncio's default reaction is only a log message), this test
    asserts that the report reached the gateway handler and that nothing was
    forwarded to the loop's default handler, which is where non-transient
    errors are delegated.
    """
    handled: list[dict] = []
    forwarded: list[dict] = []

    def recording_handler(loop, context):
        # Record the report, then defer to the real handler under test.
        handled.append(context)
        _gateway_loop_exception_handler(loop, context)

    async def raiser():
        raise TimedOut("upstream timeout")

    async def main():
        loop = asyncio.get_running_loop()
        # Shadow the default handler on this loop instance so any delegation
        # by the gateway handler is observable.
        loop.default_exception_handler = forwarded.append
        loop.set_exception_handler(recording_handler)
        task = loop.create_task(raiser())
        # Give the task a tick to run and raise.
        await asyncio.sleep(0)
        # Don't await ``task`` — let it become an unhandled-exception task.
        del task
        import gc

        gc.collect()
        await asyncio.sleep(0)

    # No exception may escape the ``run`` boundary.
    asyncio.run(main())

    assert any(
        isinstance(context.get("exception"), TimedOut) for context in handled
    ), "the never-retrieved TimedOut was not routed to the gateway handler"
    assert forwarded == [], "transient errors must be swallowed, not delegated"

View file

@ -797,6 +797,79 @@ class TestMatrixRequirements:
with patch("tools.lazy_deps.ensure", side_effect=ImportError("mautrix unavailable")):
assert matrix_mod.check_matrix_requirements() is False
def test_check_e2ee_deps_requires_asyncpg(self, monkeypatch):
    """The E2EE dependency check fails when asyncpg cannot be imported.

    Regression for #31116: the ``mautrix[encryption]`` extra installs
    python-olm but NOT asyncpg/aiosqlite, which mautrix's crypto store needs
    at connect time. ``_check_e2ee_deps`` used to test only the
    ``OlmMachine`` import and return True, so the failure surfaced later as
    a confusing ``No module named 'asyncpg'`` inside
    ``MatrixAdapter.connect()``.
    """
    import builtins

    from gateway.platforms.matrix import _check_e2ee_deps

    original_import = builtins.__import__

    def _import_without_asyncpg(name, *args, **kwargs):
        # Block the package itself and any of its submodules.
        top_level = name.partition(".")[0]
        if top_level == "asyncpg":
            raise ImportError("blocked for test")
        return original_import(name, *args, **kwargs)

    with patch.object(builtins, "__import__", _import_without_asyncpg):
        result = _check_e2ee_deps()
        assert result is False
def test_check_e2ee_deps_requires_aiosqlite(self):
    """The E2EE dependency check fails when aiosqlite cannot be imported.

    Mautrix's ``Database.create("sqlite:///...")`` driver lookup imports
    aiosqlite lazily; without it, connect fails at ``crypto_db.start()``.
    """
    import builtins

    from gateway.platforms.matrix import _check_e2ee_deps

    original_import = builtins.__import__

    def _import_without_aiosqlite(name, *args, **kwargs):
        # Block the package itself and any of its submodules.
        top_level = name.partition(".")[0]
        if top_level == "aiosqlite":
            raise ImportError("blocked for test")
        return original_import(name, *args, **kwargs)

    with patch.object(builtins, "__import__", _import_without_aiosqlite):
        result = _check_e2ee_deps()
        assert result is False
def test_check_requirements_runs_lazy_install_when_partial(self, monkeypatch):
    """A partial install (mautrix present, asyncpg/aiosqlite missing) still
    triggers the lazy installer.

    Regression for #31116: the earlier ``try: import mautrix`` gate skipped
    installing the OTHER 4 platform.matrix packages, so having only mautrix
    was treated as a complete install.
    """
    for var, value in (
        ("MATRIX_ACCESS_TOKEN", "syt_test"),
        ("MATRIX_HOMESERVER", "https://matrix.example.org"),
    ):
        monkeypatch.setenv(var, value)
    monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False)

    from gateway.platforms import matrix as matrix_mod

    # ``feature_missing`` reports one missing pin, so the installer must run.
    requested_features = []

    def _fake_ensure_and_bind(feature, importer, target_globals, **kwargs):
        requested_features.append(feature)
        assert feature == "platform.matrix"
        return True  # Pretend install succeeded.

    missing = ("asyncpg==0.31.0",)
    with patch("tools.lazy_deps.feature_missing", return_value=missing), \
            patch("tools.lazy_deps.ensure_and_bind", side_effect=_fake_ensure_and_bind):
        matrix_mod.check_matrix_requirements()

    assert requested_features, (
        "check_matrix_requirements must call ensure_and_bind whenever ANY "
        "platform.matrix dep is missing, not just when mautrix itself is "
        "missing (#31116)"
    )
# ---------------------------------------------------------------------------
# Access-token auth / E2EE bootstrap

View file

@ -6,7 +6,7 @@ import json
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig, _apply_env_overrides
from gateway.platforms.msgraph_webhook import MSGraphWebhookAdapter
from gateway.platforms.msgraph_webhook import AIOHTTP_AVAILABLE, MSGraphWebhookAdapter
def _make_adapter(**extra_overrides) -> MSGraphWebhookAdapter:
@ -70,6 +70,16 @@ class TestMSGraphWebhookConfig:
class TestMSGraphValidationHandshake:
@pytest.mark.anyio
async def test_connect_requires_client_state(self):
    """``connect()`` reports failure when the config carries no extras at all."""
    if not AIOHTTP_AVAILABLE:
        pytest.skip("aiohttp not installed")

    bare_config = PlatformConfig(enabled=True, extra={})
    adapter = MSGraphWebhookAdapter(bare_config)

    outcome = await adapter.connect()

    assert outcome is False
    # is_connected is a @property on the base adapter, not a method.
    assert adapter.is_connected is False
@pytest.mark.anyio
async def test_validation_token_echo_on_get(self):
adapter = _make_adapter()
@ -99,6 +109,22 @@ class TestMSGraphValidationHandshake:
class TestMSGraphNotifications:
@pytest.mark.anyio
async def test_missing_client_state_is_auth_rejected(self):
    """A notification is answered with 403 when the adapter has no client state.

    The adapter is built with ``client_state=None`` and the notification
    itself carries no ``clientState`` field.
    """
    adapter = _make_adapter(client_state=None)
    notification = {
        "id": "notif-no-client-state",
        "subscriptionId": "sub-1",
        "changeType": "updated",
        "resource": "communications/onlineMeetings/meeting-1",
    }
    request = _FakeRequest(json_payload={"value": [notification]})

    resp = await adapter._handle_notification(request)

    assert resp.status == 403
@pytest.mark.anyio
async def test_valid_notification_accepted_and_scheduled(self):
adapter = _make_adapter()

View file

@ -0,0 +1,943 @@
"""Tests for the ntfy platform-plugin adapter.
Loaded via the ``_plugin_adapter_loader`` helper so this lives under
``plugin_adapter_ntfy`` in ``sys.modules`` and cannot collide with
sibling platform-plugin tests on the same xdist worker.
Most tests target the adapter class directly. The plugin-shape tests
(``register()``, ``_env_enablement``, ``_standalone_send``, registry
presence) replace the core-file grep tests from the original PR — the
ntfy adapter no longer modifies ``gateway/config.py``, ``gateway/run.py``,
``cron/scheduler.py``, ``toolsets.py``, etc. Everything routes through
the ``platform_registry``.
"""
from __future__ import annotations
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import PlatformConfig
from tests.gateway._plugin_adapter_loader import load_plugin_adapter
# Load the ntfy plugin adapter module once through the shared loader helper.
_ntfy = load_plugin_adapter("ntfy")
# Module-level aliases so the tests below can use the plugin's names directly.
NtfyAdapter = _ntfy.NtfyAdapter
check_requirements = _ntfy.check_requirements
validate_config = _ntfy.validate_config
is_connected = _ntfy.is_connected
register = _ntfy.register
_env_enablement = _ntfy._env_enablement
_standalone_send = _ntfy._standalone_send
# Constants re-exported from the plugin module; tests compare against these
# instead of hard-coding the values.
DEFAULT_SERVER = _ntfy.DEFAULT_SERVER
DEDUP_WINDOW_SECONDS = _ntfy.DEDUP_WINDOW_SECONDS
DEDUP_MAX_SIZE = _ntfy.DEDUP_MAX_SIZE
MAX_MESSAGE_LENGTH = _ntfy.MAX_MESSAGE_LENGTH
def _run(coro):
    """Run an awaitable to completion on the thread's event loop and return its result.

    The current event loop is reused when one is set and still open, so
    tasks created on it by one call can be awaited by a later call. When no
    loop is set (for example after another test used ``asyncio.run()``, which
    unsets the loop on exit) or the current loop is closed, a fresh loop is
    created and installed as the current loop instead of failing.
    """
    import warnings

    try:
        with warnings.catch_warnings():
            # Python 3.12/3.13 emit a DeprecationWarning when get_event_loop()
            # has to create a loop implicitly; that is not an error here.
            warnings.simplefilter("ignore", DeprecationWarning)
            loop = asyncio.get_event_loop()
    except RuntimeError:
        # No current loop for this thread.
        loop = None

    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    return loop.run_until_complete(coro)
# ---------------------------------------------------------------------------
# 1. Platform enum (plugin-discovered, not bundled)
# ---------------------------------------------------------------------------
def test_platform_enum_resolves_via_plugin_scan():
    """``Platform("ntfy")`` resolves, and always to the same object."""
    from gateway.config import Platform

    first_lookup = Platform("ntfy")
    assert first_lookup.value == "ntfy"

    # Identity stability — repeated lookups return the same pseudo-member
    second_lookup = Platform("ntfy")
    assert second_lookup is first_lookup
# ---------------------------------------------------------------------------
# 2. check_requirements / validate_config / is_connected
# ---------------------------------------------------------------------------
class TestNtfyRequirements:
    """``check_requirements`` / ``validate_config`` / ``is_connected`` gating."""

    def test_returns_false_when_httpx_unavailable(self, monkeypatch):
        """A configured topic is not enough when httpx is flagged unavailable."""
        monkeypatch.setenv("NTFY_TOPIC", "hermes-test")
        monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", False)

        assert check_requirements() is False

    def test_returns_false_when_topic_not_set(self, monkeypatch):
        """httpx alone is not enough when ``NTFY_TOPIC`` is unset."""
        monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", True)
        monkeypatch.delenv("NTFY_TOPIC", raising=False)

        assert check_requirements() is False

    def test_returns_true_when_topic_set_via_env(self, monkeypatch):
        """httpx plus ``NTFY_TOPIC`` satisfies the requirements."""
        monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", True)
        monkeypatch.setenv("NTFY_TOPIC", "hermes-test")

        assert check_requirements() is True

    def test_validate_config_requires_topic(self, monkeypatch):
        """Validation fails without a topic and passes with one in ``extra``."""
        monkeypatch.delenv("NTFY_TOPIC", raising=False)

        without_topic = PlatformConfig(enabled=True, extra={})
        assert validate_config(without_topic) is False

        with_topic = PlatformConfig(enabled=True, extra={"topic": "t"})
        assert validate_config(with_topic) is True

    def test_is_connected_from_extra(self, monkeypatch):
        """With no env topic, ``is_connected`` follows the ``extra`` topic."""
        monkeypatch.delenv("NTFY_TOPIC", raising=False)

        with_topic = PlatformConfig(enabled=True, extra={"topic": "t"})
        assert is_connected(with_topic) is True

        without_topic = PlatformConfig(enabled=True, extra={})
        assert is_connected(without_topic) is False

    def test_is_connected_from_env(self, monkeypatch):
        """``NTFY_TOPIC`` alone makes ``is_connected`` true."""
        monkeypatch.setenv("NTFY_TOPIC", "env-topic")

        empty_config = PlatformConfig(enabled=True, extra={})
        assert is_connected(empty_config) is True
# ---------------------------------------------------------------------------
# 3. Adapter init
# ---------------------------------------------------------------------------
class TestNtfyAdapterInit:
    """Attributes set up by ``NtfyAdapter.__init__`` from ``extra`` and env vars."""

    def test_default_server_url(self, monkeypatch):
        """Without ``NTFY_SERVER_URL`` the adapter uses the default server."""
        monkeypatch.delenv("NTFY_SERVER_URL", raising=False)
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "hermes-in"}))

        expected = DEFAULT_SERVER.rstrip("/")
        assert adapter._server == expected

    def test_topic_read_from_extra(self):
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "my-topic"}))

        assert adapter._topic == "my-topic"

    def test_topic_read_from_env(self, monkeypatch):
        monkeypatch.setenv("NTFY_TOPIC", "env-topic")
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={}))

        assert adapter._topic == "env-topic"

    def test_publish_topic_falls_back_to_topic(self, monkeypatch):
        """With no publish topic configured, the subscribe topic is used."""
        monkeypatch.delenv("NTFY_PUBLISH_TOPIC", raising=False)
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "hermes-in"}))

        assert adapter._publish_topic == "hermes-in"

    def test_publish_topic_uses_extra_value(self):
        extra = {"topic": "hermes-in", "publish_topic": "hermes-out"}
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra=extra))

        assert adapter._publish_topic == "hermes-out"

    def test_token_read_from_extra(self):
        extra = {"topic": "t", "token": "tok-123"}
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra=extra))

        assert adapter._token == "tok-123"

    def test_token_read_from_env(self, monkeypatch):
        monkeypatch.setenv("NTFY_TOKEN", "env-token")
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"}))

        assert adapter._token == "env-token"

    def test_server_trailing_slash_stripped(self):
        extra = {"topic": "t", "server": "https://ntfy.example.com/"}
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra=extra))

        has_trailing_slash = adapter._server.endswith("/")
        assert not has_trailing_slash

    def test_initial_state(self):
        """A fresh adapter has no stream task, no HTTP client and an empty dedup map."""
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"}))

        assert adapter._stream_task is None
        assert adapter._http_client is None
        assert adapter._seen_messages == {}
# ---------------------------------------------------------------------------
# 4. Auth headers
# ---------------------------------------------------------------------------
class TestAuthHeaders:
    """Header selection in ``NtfyAdapter._auth_headers`` based on the token shape."""

    def _make_adapter(self, token=""):
        extra = {"topic": "t", "token": token}
        return NtfyAdapter(PlatformConfig(enabled=True, extra=extra))

    def test_no_token_returns_empty_dict(self):
        headers = self._make_adapter(token="")._auth_headers()

        assert headers == {}

    def test_bearer_token_for_plain_token(self):
        headers = self._make_adapter(token="myapitoken")._auth_headers()

        assert headers["Authorization"] == "Bearer myapitoken"

    def test_basic_auth_for_user_colon_password(self):
        """A ``user:pass`` token is sent as base64-encoded Basic credentials."""
        import base64

        headers = self._make_adapter(token="user:pass")._auth_headers()
        authorization = headers["Authorization"]

        assert authorization.startswith("Basic ")
        encoded = base64.b64encode(b"user:pass").decode()
        assert authorization == "Basic " + encoded

    def test_bearer_token_used_when_no_colon(self):
        headers = self._make_adapter(token="noColonHere")._auth_headers()

        assert headers["Authorization"] == "Bearer noColonHere"

    def test_auth_header_key_is_authorization(self):
        """``Authorization`` is the only header produced."""
        headers = self._make_adapter(token="tok")._auth_headers()

        assert list(headers) == ["Authorization"]
# ---------------------------------------------------------------------------
# 5. Deduplication
# ---------------------------------------------------------------------------
class TestDeduplication:
    """Behaviour of ``NtfyAdapter._is_duplicate`` and its ``_seen_messages`` cache."""

    def _make_adapter(self):
        config = PlatformConfig(enabled=True, extra={"topic": "t"})
        return NtfyAdapter(config)

    def test_first_message_not_duplicate(self):
        adapter = self._make_adapter()

        assert adapter._is_duplicate("msg-1") is False

    def test_second_occurrence_is_duplicate(self):
        adapter = self._make_adapter()
        adapter._is_duplicate("msg-1")  # first sighting records the id

        assert adapter._is_duplicate("msg-1") is True

    def test_different_ids_not_duplicate(self):
        adapter = self._make_adapter()
        adapter._is_duplicate("msg-1")

        assert adapter._is_duplicate("msg-2") is False

    def test_many_messages_recorded(self):
        adapter = self._make_adapter()
        for index in range(50):
            adapter._is_duplicate(f"msg-{index}")

        assert len(adapter._seen_messages) == 50

    def test_cache_pruned_on_overflow(self):
        adapter = self._make_adapter()
        inserted = DEDUP_MAX_SIZE + 20
        for index in range(inserted):
            adapter._is_duplicate(f"msg-{index}")

        # NOTE(review): the bound equals the number of ids inserted, so this
        # can only fail if the cache grows beyond its inputs — confirm the
        # intended post-prune size and tighten if possible.
        assert len(adapter._seen_messages) <= inserted

    def test_expired_id_can_be_seen_again(self):
        """An id older than the dedup window is accepted again after an overflow."""
        import time

        adapter = self._make_adapter()
        stale_stamp = time.time() - DEDUP_WINDOW_SECONDS - 1
        adapter._seen_messages["old-msg"] = stale_stamp

        # Push the cache past its maximum size with fresh ids.
        for index in range(DEDUP_MAX_SIZE + 1):
            adapter._is_duplicate(f"fill-{index}")

        assert adapter._is_duplicate("old-msg") is False
# ---------------------------------------------------------------------------
# 6. connect() / disconnect()
# ---------------------------------------------------------------------------
class TestConnect:
    """``connect()`` preconditions and ``disconnect()`` cleanup."""

    def test_connect_fails_when_httpx_unavailable(self, monkeypatch):
        monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", False)
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"}))

        outcome = _run(adapter.connect())

        assert outcome is False

    def test_connect_fails_when_no_topic(self, monkeypatch):
        monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", True)
        monkeypatch.delenv("NTFY_TOPIC", raising=False)
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={}))

        outcome = _run(adapter.connect())

        assert outcome is False

    def test_connect_starts_stream_task(self, monkeypatch):
        """A successful connect leaves a stream task on the adapter."""
        monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", True)
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "hermes-test"}))

        # Stub the stream coroutine and the httpx module so nothing touches the network.
        with patch.object(adapter, "_run_stream", new_callable=AsyncMock), \
                patch.object(_ntfy, "httpx") as fake_httpx:
            fake_httpx.AsyncClient.return_value = MagicMock()
            outcome = _run(adapter.connect())

        assert outcome is True
        assert adapter._stream_task is not None

        # Cancel and drain the task so it does not outlive the test.
        adapter._stream_task.cancel()
        try:
            _run(adapter._stream_task)
        except (asyncio.CancelledError, Exception):
            pass

    def test_disconnect_clears_state(self):
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"}))
        adapter._seen_messages["x"] = 1.0
        adapter._http_client = AsyncMock()
        adapter._stream_task = None
        adapter._running = True

        _run(adapter.disconnect())

        assert adapter._seen_messages == {}
        assert adapter._http_client is None
        assert adapter._running is False

    def test_disconnect_cancels_stream_task(self):
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"}))

        async def _never_finishes():
            await asyncio.sleep(9999)

        # The task is created on the same loop that ``_run`` drives.
        current_loop = asyncio.get_event_loop()
        adapter._stream_task = current_loop.create_task(_never_finishes())
        adapter._http_client = AsyncMock()
        adapter._running = True

        _run(adapter.disconnect())

        assert adapter._stream_task is None
# ---------------------------------------------------------------------------
# 7. send()
# ---------------------------------------------------------------------------
class TestSend:
    """Outbound publishing through ``NtfyAdapter.send`` with a mocked HTTP client."""

    def _make_adapter(self, topic="hermes-in", publish_topic="", token="", markdown=False):
        # Optional keys are only added when requested.
        extra: dict = {"topic": topic, "token": token}
        if publish_topic:
            extra["publish_topic"] = publish_topic
        if markdown:
            extra["markdown"] = True
        return NtfyAdapter(PlatformConfig(enabled=True, extra=extra))

    @staticmethod
    def _ok_response(payload):
        """Build a mocked HTTP 200 response whose ``json()`` returns *payload*."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = payload
        return response

    @staticmethod
    def _attach_client(adapter, response):
        """Install a mocked client on *adapter* whose ``post`` yields *response*."""
        client = AsyncMock()
        client.post = AsyncMock(return_value=response)
        adapter._http_client = client
        return client

    def test_send_fails_without_http_client(self):
        adapter = self._make_adapter()

        result = _run(adapter.send("hermes-in", "hello"))

        assert result.success is False
        assert "not initialized" in result.error.lower()

    def test_send_posts_to_publish_topic(self):
        adapter = self._make_adapter(topic="hermes-in", publish_topic="hermes-out")
        client = self._attach_client(adapter, self._ok_response({"id": "abc123"}))

        result = _run(adapter.send("hermes-in", "Hello ntfy!"))

        assert result.success is True
        assert result.message_id == "abc123"
        url = client.post.call_args.args[0]
        assert url.endswith("/hermes-out")

    def test_send_falls_back_to_subscribe_topic(self):
        adapter = self._make_adapter(topic="hermes-in")
        client = self._attach_client(adapter, self._ok_response({}))

        result = _run(adapter.send("hermes-in", "Hello!"))

        assert result.success is True
        url = client.post.call_args.args[0]
        assert url.endswith("/hermes-in")

    def test_send_uses_metadata_publish_topic(self):
        adapter = self._make_adapter(topic="hermes-in")
        client = self._attach_client(adapter, self._ok_response({}))

        overrides = {"publish_topic": "override-out"}
        result = _run(adapter.send("hermes-in", "Hi!", metadata=overrides))

        assert result.success is True
        url = client.post.call_args.args[0]
        assert url.endswith("/override-out")

    def test_send_handles_http_error_status(self):
        adapter = self._make_adapter(topic="hermes-in")
        forbidden = MagicMock()
        forbidden.status_code = 403
        forbidden.text = "Forbidden"
        self._attach_client(adapter, forbidden)

        result = _run(adapter.send("hermes-in", "Hello!"))

        assert result.success is False
        assert "403" in result.error

    def test_send_handles_timeout(self):
        adapter = self._make_adapter(topic="hermes-in")

        class _FakeTimeout(Exception):
            pass

        # Stand-in httpx module exposing only the exception type used here.
        httpx_stub = MagicMock()
        httpx_stub.TimeoutException = _FakeTimeout

        client = AsyncMock()
        client.post = AsyncMock(side_effect=_FakeTimeout("timed out"))
        adapter._http_client = client

        with patch.object(_ntfy, "httpx", httpx_stub):
            result = _run(adapter.send("hermes-in", "Hello!"))

        assert result.success is False
        assert "timeout" in result.error.lower()

    def test_send_truncates_to_max_length(self):
        adapter = self._make_adapter(topic="t")
        client = self._attach_client(adapter, self._ok_response({}))

        oversized = "x" * (MAX_MESSAGE_LENGTH + 500)
        _run(adapter.send("t", oversized))

        body = client.post.call_args.kwargs["content"]
        assert len(body.decode()) <= MAX_MESSAGE_LENGTH

    def test_send_typing_is_noop(self):
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"}))

        _run(adapter.send_typing("t"))  # must not raise

    def test_get_chat_info_returns_dict(self):
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"}))

        info = _run(adapter.get_chat_info("hermes-in"))

        assert info["name"] == "hermes-in"
        assert info["type"] == "dm"

    def test_send_includes_bearer_auth_header(self):
        adapter = self._make_adapter(topic="hermes-in", token="mytoken")
        client = self._attach_client(adapter, self._ok_response({}))

        _run(adapter.send("hermes-in", "secure message"))

        sent_headers = client.post.call_args.kwargs["headers"]
        assert sent_headers.get("Authorization") == "Bearer mytoken"

    def test_send_emits_markdown_header_when_enabled(self):
        adapter = self._make_adapter(topic="hermes-in", markdown=True)
        client = self._attach_client(adapter, self._ok_response({}))

        _run(adapter.send("hermes-in", "**bold**"))

        sent_headers = client.post.call_args.kwargs["headers"]
        assert sent_headers.get("X-Markdown") == "true"

    def test_send_omits_markdown_header_when_disabled(self):
        adapter = self._make_adapter(topic="hermes-in", markdown=False)
        client = self._attach_client(adapter, self._ok_response({}))

        _run(adapter.send("hermes-in", "plain"))

        sent_headers = client.post.call_args.kwargs["headers"]
        assert "X-Markdown" not in sent_headers
# ---------------------------------------------------------------------------
# 8. Inbound message processing (identity invariant — security-critical)
# ---------------------------------------------------------------------------
class TestOnMessage:
def _make_adapter(self):
    """Build an adapter subscribed to the ``hermes-in`` topic."""
    config = PlatformConfig(enabled=True, extra={"topic": "hermes-in"})
    return NtfyAdapter(config)
def test_message_dispatched_to_handler(self):
    """A regular message event reaches the registered handler with its text."""
    received = []

    async def handler(event):
        received.append(event)

    adapter = self._make_adapter()
    adapter.set_message_handler(handler)

    incoming = dict(
        id="evt-001",
        event="message",
        topic="hermes-in",
        message="Hello from ntfy",
        time=1700000000,
    )
    _run(adapter._on_message(incoming))

    assert len(received) == 1
    assert received[0].text == "Hello from ntfy"
def test_empty_message_skipped(self):
    """An event with an empty ``message`` never reaches the handler."""
    received = []

    async def handler(event):
        received.append(event)

    adapter = self._make_adapter()
    adapter.set_message_handler(handler)

    blank = {"id": "x", "event": "message", "topic": "t", "message": "", "time": None}
    _run(adapter._on_message(blank))

    assert received == []
def test_duplicate_message_skipped(self):
    """Delivering the same event twice invokes the handler only once."""
    received = []

    async def handler(event):
        received.append(event)

    adapter = self._make_adapter()
    adapter.set_message_handler(handler)

    repeated = {
        "id": "dup-1",
        "event": "message",
        "topic": "hermes-in",
        "message": "hi",
        "time": None,
    }
    for _ in range(2):
        _run(adapter._on_message(repeated))

    assert len(received) == 1
def test_timestamp_parsed_from_event(self):
from datetime import timezone
adapter = self._make_adapter()
captured = []
async def handler(event):
captured.append(event)
adapter.set_message_handler(handler)
_run(adapter._on_message({
"id": "ts-1",
"event": "message",
"topic": "hermes-in",
"message": "ping",
"time": 1700000000,
}))
ts = captured[0].timestamp
assert ts.tzinfo == timezone.utc
def test_message_id_set_from_event(self):
adapter = self._make_adapter()
captured = []
async def handler(event):
captured.append(event)
adapter.set_message_handler(handler)
_run(adapter._on_message({
"id": "ntfy-id-42",
"event": "message",
"topic": "hermes-in",
"message": "test",
"time": None,
}))
assert captured[0].message_id == "ntfy-id-42"
def test_title_not_used_as_user_id(self):
"""title field must not be used for identity — it is publisher-controlled."""
adapter = self._make_adapter()
captured = []
async def handler(event):
captured.append(event)
adapter.set_message_handler(handler)
_run(adapter._on_message({
"id": "u-1",
"event": "message",
"topic": "hermes-in",
"message": "hello",
"title": "Alice",
"time": None,
}))
assert captured[0].source.user_id == "hermes-in"
assert captured[0].source.user_name == "hermes-in"
def test_unknown_publisher_cannot_impersonate_allowed_user(self):
"""An unknown publisher setting title=admin must not gain admin identity."""
adapter = self._make_adapter()
captured = []
async def handler(event):
captured.append(event)
adapter.set_message_handler(handler)
_run(adapter._on_message({
"id": "u-2",
"event": "message",
"topic": "hermes-in",
"message": "sensitive command",
"title": "admin",
"time": None,
}))
assert captured[0].source.user_id == "hermes-in"
assert captured[0].source.user_id != "admin"
def test_source_chat_id_is_topic(self):
adapter = self._make_adapter()
captured = []
async def handler(event):
captured.append(event)
adapter.set_message_handler(handler)
_run(adapter._on_message({
"id": "s-1",
"event": "message",
"topic": "hermes-in",
"message": "hello",
"time": None,
}))
assert captured[0].source.chat_id == "hermes-in"
# ---------------------------------------------------------------------------
# 9. _env_enablement() — env-only auto-config
# ---------------------------------------------------------------------------
class TestEnvEnablement:
    """``_env_enablement()`` builds the seed config purely from ``NTFY_*`` env vars."""

    def test_returns_none_without_topic(self, monkeypatch):
        monkeypatch.delenv("NTFY_TOPIC", raising=False)
        assert _env_enablement() is None

    def test_seeds_topic_and_server(self, monkeypatch):
        monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
        monkeypatch.delenv("NTFY_SERVER_URL", raising=False)
        seed = _env_enablement()
        assert seed is not None
        assert seed["topic"] == "hermes-in"
        assert seed["server"] == DEFAULT_SERVER

    def test_custom_server_url(self, monkeypatch):
        monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
        monkeypatch.setenv("NTFY_SERVER_URL", "https://ntfy.example.com/")
        seed = _env_enablement()
        assert seed["server"] == "https://ntfy.example.com"  # trailing slash stripped

    def test_publish_topic_seeded(self, monkeypatch):
        monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
        monkeypatch.setenv("NTFY_PUBLISH_TOPIC", "hermes-out")
        seed = _env_enablement()
        assert seed["publish_topic"] == "hermes-out"

    def test_token_seeded(self, monkeypatch):
        monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
        monkeypatch.setenv("NTFY_TOKEN", "tk_abc")
        seed = _env_enablement()
        assert seed["token"] == "tk_abc"

    # Parametrized rather than looped so every value is reported as its own
    # test case and one failing value does not hide the remaining ones.
    @pytest.mark.parametrize("raw", ["true", "1", "yes", "TRUE"])
    def test_markdown_truthy_values(self, monkeypatch, raw):
        monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
        monkeypatch.setenv("NTFY_MARKDOWN", raw)
        assert _env_enablement()["markdown"] is True

    @pytest.mark.parametrize("raw", ["false", "0", "no", "anything"])
    def test_markdown_falsy_values(self, monkeypatch, raw):
        monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
        monkeypatch.setenv("NTFY_MARKDOWN", raw)
        assert _env_enablement()["markdown"] is False

    def test_home_channel_defaults_to_topic(self, monkeypatch):
        monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
        monkeypatch.delenv("NTFY_HOME_CHANNEL", raising=False)
        seed = _env_enablement()
        assert seed["home_channel"]["chat_id"] == "hermes-in"
        assert seed["home_channel"]["name"] == "hermes-in"

    def test_home_channel_override(self, monkeypatch):
        monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
        monkeypatch.setenv("NTFY_HOME_CHANNEL", "alerts")
        monkeypatch.setenv("NTFY_HOME_CHANNEL_NAME", "Alerts Channel")
        seed = _env_enablement()
        assert seed["home_channel"]["chat_id"] == "alerts"
        assert seed["home_channel"]["name"] == "Alerts Channel"
# ---------------------------------------------------------------------------
# 10. _standalone_send() — out-of-process cron delivery
# ---------------------------------------------------------------------------
class TestStandaloneSend:
    """``_standalone_send()`` posts through a patched ``httpx`` module.

    The HTTP-level tests share one fixture shape: a stub async client whose
    ``post`` returns a canned response, installed as ``httpx.AsyncClient()``.
    """

    @staticmethod
    def _ok_response(payload):
        """Return a stub 200 response whose ``json()`` yields *payload*."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = payload
        return response

    @staticmethod
    def _send_via_mock_httpx(extra, response, text):
        """Run ``_standalone_send`` to ``hermes-in`` against a stubbed httpx.

        *extra* becomes ``pconfig.extra``; *response* is what the stub
        client's ``post`` returns. Returns ``(result, client)`` so callers
        can inspect both the send result and the recorded POST call.
        """
        pconfig = MagicMock()
        pconfig.extra = extra
        client = AsyncMock()
        client.post = AsyncMock(return_value=response)
        # The stub stands in for ``async with httpx.AsyncClient(...) as c``.
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=None)
        with patch.object(_ntfy, "httpx") as mock_httpx:
            mock_httpx.AsyncClient.return_value = client
            result = _run(_standalone_send(pconfig, "hermes-in", text))
        return result, client

    def test_errors_without_topic(self, monkeypatch):
        monkeypatch.delenv("NTFY_TOPIC", raising=False)
        monkeypatch.delenv("NTFY_PUBLISH_TOPIC", raising=False)
        pconfig = MagicMock()
        pconfig.extra = {}
        result = _run(_standalone_send(pconfig, "", "hello"))
        assert "error" in result
        assert "NTFY_TOPIC" in result["error"]

    def test_posts_to_server(self, monkeypatch):
        monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
        result, client = self._send_via_mock_httpx(
            {"server": "https://ntfy.example.com", "topic": "hermes-in"},
            self._ok_response({"id": "id-42"}),
            "hello",
        )
        assert result.get("success") is True
        assert result["platform"] == "ntfy"
        assert result["message_id"] == "id-42"
        posted_url = client.post.call_args[0][0]
        assert posted_url == "https://ntfy.example.com/hermes-in"

    def test_emits_bearer_token_when_configured(self, monkeypatch):
        monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
        _, client = self._send_via_mock_httpx(
            {"topic": "hermes-in", "token": "tk_xyz"},
            self._ok_response({}),
            "hi",
        )
        headers = client.post.call_args[1]["headers"]
        assert headers["Authorization"] == "Bearer tk_xyz"

    def test_basic_auth_when_token_has_colon(self, monkeypatch):
        monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
        _, client = self._send_via_mock_httpx(
            {"topic": "hermes-in", "token": "user:pass"},
            self._ok_response({}),
            "hi",
        )
        headers = client.post.call_args[1]["headers"]
        assert headers["Authorization"].startswith("Basic ")

    def test_returns_error_on_http_failure(self, monkeypatch):
        monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
        forbidden = MagicMock()
        forbidden.status_code = 403
        forbidden.text = "Forbidden"
        result, _ = self._send_via_mock_httpx({"topic": "hermes-in"}, forbidden, "hi")
        assert "error" in result
        assert "403" in result["error"]
# ---------------------------------------------------------------------------
# 11. register() — plugin-side metadata
# ---------------------------------------------------------------------------
def test_register_calls_register_platform():
    """``register`` must call ``ctx.register_platform`` exactly once with ntfy's metadata."""
    ctx = MagicMock()
    register(ctx)
    ctx.register_platform.assert_called_once()
    meta = ctx.register_platform.call_args.kwargs

    # Plain-value metadata.
    expected_values = {
        "name": "ntfy",
        "label": "ntfy",
        "required_env": ["NTFY_TOPIC"],
        "allowed_users_env": "NTFY_ALLOWED_USERS",
        "allow_all_env": "NTFY_ALLOW_ALL_USERS",
        "cron_deliver_env_var": "NTFY_HOME_CHANNEL",
        "max_message_length": MAX_MESSAGE_LENGTH,
    }
    for key, expected in expected_values.items():
        assert meta[key] == expected

    # Hook entries only need to be callables.
    for hook in (
        "check_fn",
        "validate_config",
        "is_connected",
        "env_enablement_fn",
        "standalone_sender_fn",
        "adapter_factory",
    ):
        assert callable(meta[hook])

    # ntfy has no user-identifying PII (only topic names)
    assert meta["pii_safe"] is True
    assert "ntfy" in meta["platform_hint"].lower()
def test_adapter_factory_returns_ntfy_adapter():
    """The registered ``adapter_factory`` must build an ``NtfyAdapter`` from a config."""
    ctx = MagicMock()
    register(ctx)
    build_adapter = ctx.register_platform.call_args.kwargs["adapter_factory"]
    built = build_adapter(PlatformConfig(enabled=True, extra={"topic": "t"}))
    assert isinstance(built, NtfyAdapter)
# ---------------------------------------------------------------------------
# 12. Robustness — token hygiene + fatal-state propagation
# ---------------------------------------------------------------------------
class TestTokenHygiene:
    """``_build_auth_header`` must strip whitespace around the token.

    Pasted tokens often carry trailing newlines that would break the
    Authorization header line.
    """

    def test_trailing_whitespace_stripped(self):
        header = _ntfy._build_auth_header("  tok123  ")
        assert header == {"Authorization": "Bearer tok123"}

    def test_trailing_newline_stripped(self):
        header = _ntfy._build_auth_header("tok123\n")
        assert header == {"Authorization": "Bearer tok123"}

    def test_whitespace_only_returns_empty(self):
        header = _ntfy._build_auth_header("   \n  ")
        assert header == {}

    def test_basic_auth_token_also_stripped(self):
        import base64

        header = _ntfy._build_auth_header("  user:pass  ")
        value = header["Authorization"]
        assert value.startswith("Basic ")
        assert value == "Basic " + base64.b64encode(b"user:pass").decode()

    def test_adapter_strips_token_via_helper(self):
        """Token whitespace supplied through ``config.extra`` is stripped by
        the adapter's ``_auth_headers()`` as well."""
        adapter = NtfyAdapter(
            PlatformConfig(enabled=True, extra={"topic": "t", "token": "  tok\n"})
        )
        assert adapter._auth_headers() == {"Authorization": "Bearer tok"}
class TestFatalErrorPropagation:
    """When the stream hits 401/404, the adapter must transition to the
    ``fatal`` state via ``_set_fatal_error`` so the gateway's runtime
    status reflects reality instead of staying 'connected'."""

    @staticmethod
    def _consume_with_status(topic, status_code):
        """Run ``_consume_stream`` against a stubbed stream answering *status_code*.

        Asserts that ``_FatalStreamError`` is raised and returns the adapter
        so the caller can inspect the recorded fatal-error state.
        """
        adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": topic}))
        adapter._http_client = MagicMock()

        # Mock the streaming response
        response = MagicMock()
        response.status_code = status_code

        # async-context-manager flavor for httpx.stream
        stream_cm = AsyncMock()
        stream_cm.__aenter__ = AsyncMock(return_value=response)
        stream_cm.__aexit__ = AsyncMock(return_value=None)
        adapter._http_client.stream = MagicMock(return_value=stream_cm)

        fake_httpx = MagicMock()
        fake_httpx.Timeout = MagicMock()
        with patch.object(_ntfy, "httpx", fake_httpx):
            with pytest.raises(_ntfy._FatalStreamError):
                _run(adapter._consume_stream(f"https://ntfy.example/{topic}/json", {}))
        return adapter

    def test_401_sets_fatal_unauthorized(self):
        adapter = self._consume_with_status("t", 401)
        assert adapter.has_fatal_error is True
        assert adapter._fatal_error_code == "ntfy_unauthorized"
        assert adapter._fatal_error_retryable is False

    def test_404_sets_fatal_topic_not_found(self):
        adapter = self._consume_with_status("missing-topic", 404)
        assert adapter.has_fatal_error is True
        assert adapter._fatal_error_code == "ntfy_topic_not_found"
        assert "missing-topic" in adapter._fatal_error_message
        assert adapter._fatal_error_retryable is False
class TestTruncateHelper:
    """``_truncate_body`` must cap its output at ``MAX_MESSAGE_LENGTH`` and
    return bytes (per the original note it is shared between adapter.send()
    and ``_standalone_send``)."""

    def test_short_message_passes_through(self):
        body = _ntfy._truncate_body("hi", context="test")
        assert body == b"hi"

    def test_long_message_truncated(self):
        oversized = "x" * (MAX_MESSAGE_LENGTH + 50)
        body = _ntfy._truncate_body(oversized, context="test")
        assert isinstance(body, bytes)
        assert len(body) == MAX_MESSAGE_LENGTH

    def test_unicode_message_encoded(self):
        text = "héllo 🔔"
        assert _ntfy._truncate_body(text, context="test") == text.encode("utf-8")

View file

@ -2,10 +2,13 @@
import json
import os
import sys
import time
from pathlib import Path
from unittest.mock import patch
import pytest
from gateway.pairing import (
PairingStore,
ALPHABET,
@ -37,6 +40,10 @@ class TestSecureWrite:
assert target.exists()
assert json.loads(target.read_text()) == {"hello": "world"}
@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="POSIX file modes are not enforced on Windows",
)
def test_sets_file_permissions(self, tmp_path):
target = tmp_path / "secret.json"
_secure_write(target, "data")
@ -75,9 +82,197 @@ class TestCodeGeneration:
code = store.generate_code("telegram", "user1", "Alice")
pending = store.list_pending("telegram")
assert len(pending) == 1
assert pending[0]["code"] == code
# list_pending no longer returns the original code — it returns a
# truncated hash prefix. Verify the metadata is correct instead.
assert pending[0]["user_id"] == "user1"
assert pending[0]["user_name"] == "Alice"
# The code field is now a hash prefix, not the original plaintext code
assert pending[0]["code"] != code
# ---------------------------------------------------------------------------
# Hashed storage
# ---------------------------------------------------------------------------
class TestHashedStorage:
    """Pending pairing codes are persisted as salted hashes, never as plaintext."""

    def test_pending_file_contains_hash_and_salt(self, tmp_path):
        """Stored entries must have 'hash' and 'salt', never the plaintext code."""
        lowercase_hex = set("0123456789abcdef")
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            code = store.generate_code("telegram", "user1", "Alice")
            pending_file = tmp_path / "telegram-pending.json"
            stored = json.loads(pending_file.read_text(encoding="utf-8"))

            assert len(stored) == 1
            (entry,) = stored.values()

            # Must have hash and salt fields
            assert "hash" in entry
            assert "salt" in entry
            # Hash must be a valid hex SHA-256 digest (64 hex chars)
            assert len(entry["hash"]) == 64
            assert set(entry["hash"]) <= lowercase_hex
            # Salt must be a valid hex string (32 hex chars for 16 bytes)
            assert len(entry["salt"]) == 32
            assert set(entry["salt"]) <= lowercase_hex

            # The plaintext code must NOT appear as a key or value anywhere
            assert code not in stored  # not a key
            string_values = [
                value
                for fields in stored.values()
                for value in fields.values()
                if isinstance(value, str)
            ]
            assert all(value != code for value in string_values)

    def test_plaintext_code_not_stored(self, tmp_path):
        """The raw JSON file must not contain the plaintext code anywhere."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            code = store.generate_code("telegram", "user1")
            file_text = (tmp_path / "telegram-pending.json").read_text(encoding="utf-8")
            assert code not in file_text

    def test_valid_code_verifies_against_hash(self, tmp_path):
        """approve_code with the correct code should succeed."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            code = store.generate_code("telegram", "user1", "Bob")
            approved = store.approve_code("telegram", code)
            assert approved is not None
            assert approved["user_id"] == "user1"
            assert approved["user_name"] == "Bob"

    def test_invalid_code_rejected(self, tmp_path):
        """approve_code with a wrong code should fail."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            store.generate_code("telegram", "user1")
            assert store.approve_code("telegram", "ZZZZZZZZ") is None

    def test_different_salts_per_entry(self, tmp_path):
        """Each pending entry should have a unique salt."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            for user_id in ("user0", "user1", "user2"):
                store.generate_code("telegram", user_id)
            stored = json.loads(
                (tmp_path / "telegram-pending.json").read_text(encoding="utf-8")
            )
            unique_salts = {entry["salt"] for entry in stored.values()}
            assert len(unique_salts) == 3  # all unique

    def test_hash_code_static_method(self, tmp_path):
        """_hash_code should be deterministic for the same code+salt."""
        first_salt = os.urandom(16)
        digest_a = PairingStore._hash_code("ABCD1234", first_salt)
        digest_b = PairingStore._hash_code("ABCD1234", first_salt)
        assert digest_a == digest_b
        # Different salt should produce a different hash
        other_salt = os.urandom(16)
        digest_c = PairingStore._hash_code("ABCD1234", other_salt)
        assert digest_c != digest_a
class TestLegacyPendingFileCompat:
    """Defensive coverage for pre-hash pending.json on upgraded installs.

    Existing user installs may have a pending.json written by the old
    code (plaintext code as key, no hash/salt fields). The new
    approve_code / list_pending / _cleanup_expired must not crash on
    those entries — they should be ignored and aged out at TTL.
    """

    @staticmethod
    def _write_legacy(tmp_path, code="ABCD1234", created_at=None):
        """Write a pre-hash pending.json with plaintext code as the key."""
        import time as _time

        stamp = _time.time() if created_at is None else created_at
        payload = {
            code: {
                "user_id": "legacy-user",
                "user_name": "Legacy",
                "created_at": stamp,
            }
        }
        target = tmp_path / "telegram-pending.json"
        target.write_text(json.dumps(payload), encoding="utf-8")

    def test_approve_code_ignores_legacy_entries(self, tmp_path):
        """A valid old-format code must NOT silently approve under the new schema."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            self._write_legacy(tmp_path, code="LEGACY01")
            store = PairingStore()
            # Under the old schema the plaintext code was the dict key. The
            # legacy entry carries no hash/salt to verify against, so the
            # expected outcome is a plain ``None``; the entry itself is left
            # for _cleanup_expired to prune at TTL.
            outcome = store.approve_code("telegram", "LEGACY01")
            assert outcome is None
            # Approved list must be empty
            assert store.is_approved("telegram", "legacy-user") is False

    def test_list_pending_handles_legacy_entries(self, tmp_path):
        """list_pending must not KeyError on a missing 'hash' field."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            self._write_legacy(tmp_path)
            store = PairingStore()
            listed = store.list_pending("telegram")
            assert len(listed) == 1
            only = listed[0]
            assert only["user_id"] == "legacy-user"
            assert only["code"] == "legacy"  # placeholder

    def test_cleanup_expired_removes_legacy_at_ttl(self, tmp_path):
        """Legacy entries past CODE_TTL must still get pruned."""
        import time as _time

        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            expired_at = _time.time() - CODE_TTL_SECONDS - 1
            self._write_legacy(tmp_path, code="LEGACY99", created_at=expired_at)
            store = PairingStore()
            store._cleanup_expired("telegram")
            remaining = json.loads(
                (tmp_path / "telegram-pending.json").read_text(encoding="utf-8")
            )
            assert remaining == {}

    def test_cleanup_expired_handles_malformed_entries(self, tmp_path):
        """Non-dict / missing-created_at entries get evicted, not crashed on."""
        malformed = {
            "broken1": "not a dict",
            "broken2": {"user_id": "x"},  # no created_at
            "broken3": {"created_at": "not a number"},
        }
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            pending_file = tmp_path / "telegram-pending.json"
            pending_file.write_text(json.dumps(malformed), encoding="utf-8")
            store = PairingStore()
            store._cleanup_expired("telegram")
            remaining = json.loads(pending_file.read_text(encoding="utf-8"))
            assert remaining == {}

    def test_approve_code_skips_malformed_entries(self, tmp_path):
        """Malformed entries must not crash approve_code's hash loop."""
        import time as _time

        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            bad_entry = {
                "user_id": "x",
                "created_at": _time.time(),
                "salt": "not-hex",
                "hash": "doesntmatter",
            }
            (tmp_path / "telegram-pending.json").write_text(
                json.dumps({"broken": bad_entry}),
                encoding="utf-8",
            )
            store = PairingStore()
            # Approving with any code must just return None, not crash.
            assert store.approve_code("telegram", "ABCD1234") is None
# ---------------------------------------------------------------------------
@ -117,6 +312,23 @@ class TestRateLimiting:
assert isinstance(code2, str) and len(code2) == CODE_LENGTH
assert code2 != code1
def test_whatsapp_alias_flip_hits_same_rate_limit(self, tmp_path, monkeypatch):
    """A second code request under the mapped ``@lid`` alias of the same
    WhatsApp user must be refused (``generate_code`` returns ``None``)."""
    # LID -> phone-JID mapping file under <HERMES_HOME>/whatsapp/session.
    session_dir = tmp_path / "whatsapp" / "session"
    session_dir.mkdir(parents=True, exist_ok=True)
    mapping_file = session_dir / "lid-mapping-999999999999999.json"
    mapping_file.write_text(
        json.dumps("15551234567@s.whatsapp.net"),
        encoding="utf-8",
    )
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    with patch("gateway.pairing.PAIRING_DIR", tmp_path):
        store = PairingStore()
        first = store.generate_code("whatsapp", "15551234567@s.whatsapp.net")
        second = store.generate_code("whatsapp", "999999999999999@lid")

        assert isinstance(first, str) and len(first) == CODE_LENGTH
        assert second is None
# ---------------------------------------------------------------------------
# Max pending limit
@ -209,6 +421,55 @@ class TestApprovalFlow:
result = store.approve_code("telegram", "INVALIDCODE")
assert result is None
def test_whatsapp_approved_user_survives_alias_flip(self, tmp_path, monkeypatch):
    """A user approved under the phone JID must also count as approved under
    the mapped ``@lid`` alias, with a single approved entry listed."""
    session_dir = tmp_path / "whatsapp" / "session"
    session_dir.mkdir(parents=True, exist_ok=True)
    mapping_file = session_dir / "lid-mapping-999999999999999.json"
    mapping_file.write_text(
        json.dumps("15551234567@s.whatsapp.net"),
        encoding="utf-8",
    )
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    with patch("gateway.pairing.PAIRING_DIR", tmp_path):
        store = PairingStore()
        pairing_code = store.generate_code(
            "whatsapp", "15551234567@s.whatsapp.net", "Alice"
        )
        store.approve_code("whatsapp", pairing_code)

        for alias in ("15551234567@s.whatsapp.net", "999999999999999@lid"):
            assert store.is_approved("whatsapp", alias) is True

        approved_users = store.list_approved("whatsapp")
        assert len(approved_users) == 1
        assert approved_users[0]["user_id"] == "15551234567"
def test_whatsapp_legacy_raw_jid_approval_survives_alias_flip(self, tmp_path, monkeypatch):
    """An approved-file entry keyed by the raw phone JID must still approve
    the same user when looked up via the mapped ``@lid`` alias."""
    session_dir = tmp_path / "whatsapp" / "session"
    session_dir.mkdir(parents=True, exist_ok=True)
    (session_dir / "lid-mapping-999999999999999.json").write_text(
        json.dumps("15551234567@s.whatsapp.net"),
        encoding="utf-8",
    )
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    # Pre-seed the approved file with an entry keyed by the full raw JID.
    legacy_approved = {
        "15551234567@s.whatsapp.net": {
            "user_name": "Legacy Alice",
            "approved_at": time.time(),
        }
    }
    (tmp_path / "whatsapp-approved.json").write_text(
        json.dumps(legacy_approved, indent=2),
        encoding="utf-8",
    )

    with patch("gateway.pairing.PAIRING_DIR", tmp_path):
        store = PairingStore()
        assert store.is_approved("whatsapp", "999999999999999@lid") is True
# ---------------------------------------------------------------------------
# Lockout after failed attempts
@ -300,9 +561,10 @@ class TestCodeExpiry:
store = PairingStore()
code = store.generate_code("telegram", "user1")
# Manually expire the code
# Manually expire all pending entries
pending = store._load_json(store._pending_path("telegram"))
pending[code]["created_at"] = time.time() - CODE_TTL_SECONDS - 1
for entry_id in pending:
pending[entry_id]["created_at"] = time.time() - CODE_TTL_SECONDS - 1
store._save_json(store._pending_path("telegram"), pending)
# Cleanup happens on next operation
@ -314,9 +576,10 @@ class TestCodeExpiry:
store = PairingStore()
code = store.generate_code("telegram", "user1")
# Expire it
# Expire all entries
pending = store._load_json(store._pending_path("telegram"))
pending[code]["created_at"] = time.time() - CODE_TTL_SECONDS - 1
for entry_id in pending:
pending[entry_id]["created_at"] = time.time() - CODE_TTL_SECONDS - 1
store._save_json(store._pending_path("telegram"), pending)
result = store.approve_code("telegram", code)

View file

@ -361,6 +361,72 @@ class TestExtractMedia:
assert "[[as_document]]" not in cleaned
class TestMediaDeliveryPathValidation:
    """Media paths are deliverable only when they resolve inside a safe root."""

    def _patch_roots(self, monkeypatch, *roots):
        """Replace ``MEDIA_DELIVERY_SAFE_ROOTS`` with exactly *roots*."""
        safe_roots = tuple(roots)
        monkeypatch.setattr(
            "gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS",
            safe_roots,
        )

    def test_allows_existing_file_inside_safe_root(self, tmp_path, monkeypatch):
        root = tmp_path / "media-cache"
        root.mkdir(parents=True)
        media_file = root / "voice.ogg"
        media_file.write_bytes(b"OggS")
        self._patch_roots(monkeypatch, root)

        validated = BasePlatformAdapter.validate_media_delivery_path(str(media_file))
        assert validated == str(media_file.resolve())

    def test_rejects_existing_file_outside_safe_root(self, tmp_path, monkeypatch):
        root = tmp_path / "media-cache"
        root.mkdir()
        outsider = tmp_path / "secrets.txt"
        outsider.write_text("not for upload")
        self._patch_roots(monkeypatch, root)

        assert BasePlatformAdapter.validate_media_delivery_path(str(outsider)) is None

    def test_rejects_symlink_escape_from_safe_root(self, tmp_path, monkeypatch):
        root = tmp_path / "media-cache"
        root.mkdir()
        target = tmp_path / "outside.png"
        target.write_bytes(b"secret")
        # The link lives inside the root but points at a file outside it.
        link = root / "safe-looking.png"
        try:
            link.symlink_to(target)
        except OSError:
            pytest.skip("symlink creation is unavailable")
        self._patch_roots(monkeypatch, root)

        assert BasePlatformAdapter.validate_media_delivery_path(str(link)) is None

    def test_filter_keeps_safe_media_and_drops_unsafe(self, tmp_path, monkeypatch):
        root = tmp_path / "media-cache"
        root.mkdir(parents=True)
        inside = root / "speech.ogg"
        outside = tmp_path / "outside.ogg"
        for path in (inside, outside):
            path.write_bytes(b"OggS")
        self._patch_roots(monkeypatch, root)

        candidates = [
            (str(outside), False),
            (str(inside), True),
        ]
        kept = BasePlatformAdapter.filter_media_delivery_paths(candidates)
        assert kept == [(str(inside.resolve()), True)]

    def test_allows_operator_configured_extra_root(self, tmp_path, monkeypatch):
        extra_root = tmp_path / "operator-media"
        extra_root.mkdir(parents=True)
        media_file = extra_root / "report.pdf"
        media_file.write_bytes(b"%PDF-1.4")
        # No built-in safe roots: only the env-configured directory can allow it.
        self._patch_roots(monkeypatch)
        monkeypatch.setenv("HERMES_MEDIA_ALLOW_DIRS", str(extra_root))

        validated = BasePlatformAdapter.validate_media_delivery_path(str(media_file))
        assert validated == str(media_file.resolve())
# ---------------------------------------------------------------------------
# should_send_media_as_audio
# ---------------------------------------------------------------------------
@ -728,4 +794,3 @@ class TestProxyKwargsForAiohttp:
sess_kw, req_kw = proxy_kwargs_for_aiohttp("http://proxy:8080")
assert sess_kw == {}
assert req_kw == {"proxy": "http://proxy:8080"}

View file

@ -79,10 +79,11 @@ def test_checker_returns_true_when_configured(platform, checker, monkeypatch):
elif platform in {
Platform.API_SERVER,
Platform.WEBHOOK,
Platform.MSGRAPH_WEBHOOK,
Platform.WHATSAPP,
}:
mock_config.extra = {}
elif platform == Platform.MSGRAPH_WEBHOOK:
mock_config.extra = {"client_state": "expected-client-state"}
elif platform == Platform.FEISHU:
mock_config.extra = {"app_id": "app"}
elif platform == Platform.WECOM:

View file

@ -708,3 +708,279 @@ class TestPluginPlatformSharedKeyBridge:
assert extra.get("allow_from") == ["alice", "bob"]
finally:
_reg.unregister("mysharedplat")
class TestPluginEnablementGate:
"""Plugin platforms must NOT auto-enable on check_fn alone (#31116).
When a plugin registers ``is_connected`` (the "did the user actually
configure credentials" probe), ``load_gateway_config`` must consult it
before flipping ``enabled = True``. Without this gate, ``check_fn``
semantics ("the SDK is importable") get conflated with "the user wants
this platform on", and the gateway tries to connect to e.g. Discord
with no token — emitting noisy retry-forever errors on every fresh
install that has the plugin loaded.
"""
def _write_config(self, tmp_path, content: str = ""):
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir()
(hermes_home / "config.yaml").write_text(content, encoding="utf-8")
return hermes_home
def test_plugin_with_is_connected_false_is_NOT_enabled(
    self, tmp_path, monkeypatch
):
    """check_fn=True + is_connected=False must NOT enable the platform.

    Models #31116: the plugin's check_fn reports the SDK as available
    while the credential probe says nothing is configured. The platform
    must then stay out of the config, or be present but disabled.
    """
    from gateway.platform_registry import platform_registry as _reg

    entry = PlatformEntry(
        name="myunconfiguredplat",
        label="MyUnconfigured",
        adapter_factory=lambda cfg: None,
        check_fn=lambda: True,  # SDK available
        is_connected=lambda cfg: False,  # but user hasn't set credentials
        source="plugin",
    )
    _reg.register(entry)
    try:
        hermes_home = self._write_config(tmp_path)
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        from gateway.config import load_gateway_config, Platform

        loaded = load_gateway_config()
        platform = Platform("myunconfiguredplat")
        # Either absent entirely, or present but explicitly disabled.
        assert (
            platform not in loaded.platforms
            or loaded.platforms[platform].enabled is False
        ), "Plugin with is_connected=False must NOT be auto-enabled"
    finally:
        _reg.unregister("myunconfiguredplat")
def test_plugin_with_is_connected_true_is_enabled(
    self, tmp_path, monkeypatch
):
    """check_fn=True + is_connected=True still enables the platform."""
    from gateway.platform_registry import platform_registry as _reg

    entry = PlatformEntry(
        name="myconfiguredplat",
        label="MyConfigured",
        adapter_factory=lambda cfg: None,
        check_fn=lambda: True,
        is_connected=lambda cfg: True,
        source="plugin",
    )
    _reg.register(entry)
    try:
        hermes_home = self._write_config(tmp_path)
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        from gateway.config import load_gateway_config, Platform

        loaded = load_gateway_config()
        platform = Platform("myconfiguredplat")
        assert platform in loaded.platforms
        assert loaded.platforms[platform].enabled is True
    finally:
        _reg.unregister("myconfiguredplat")
def test_plugin_without_is_connected_falls_back_to_check_fn(
    self, tmp_path, monkeypatch
):
    """Legacy plugins that don't register is_connected keep working.

    For plugins where ``is_connected is None``, gating on ``check_fn``
    alone remains the contract — that's what callers without a
    credential probe have always done.
    """
    from gateway.platform_registry import platform_registry as _reg

    entry = PlatformEntry(
        name="mylegacyplat",
        label="MyLegacy",
        adapter_factory=lambda cfg: None,
        check_fn=lambda: True,
        # is_connected intentionally omitted (None)
        source="plugin",
    )
    _reg.register(entry)
    try:
        hermes_home = self._write_config(tmp_path)
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        from gateway.config import load_gateway_config, Platform

        loaded = load_gateway_config()
        platform = Platform("mylegacyplat")
        assert platform in loaded.platforms
        assert loaded.platforms[platform].enabled is True
    finally:
        _reg.unregister("mylegacyplat")
def test_is_connected_raises_does_not_enable(self, tmp_path, monkeypatch):
    """A buggy is_connected must not silently enable the platform.

    A raising is_connected is treated as "configuration unknown" — the
    platform must not end up enabled. Anything else would re-introduce
    the #31116 bug for plugins whose probe has a transient failure.
    """
    from gateway.platform_registry import platform_registry as _reg

    def _raising_probe(cfg):
        raise RuntimeError("plugin bug")

    entry = PlatformEntry(
        name="mybadprobeplat",
        label="MyBadProbe",
        adapter_factory=lambda cfg: None,
        check_fn=lambda: True,
        is_connected=_raising_probe,
        source="plugin",
    )
    _reg.register(entry)
    try:
        hermes_home = self._write_config(tmp_path)
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        from gateway.config import load_gateway_config, Platform

        loaded = load_gateway_config()
        platform = Platform("mybadprobeplat")
        # Absent is fine; present must mean disabled.
        assert (
            platform not in loaded.platforms
            or loaded.platforms[platform].enabled is False
        )
    finally:
        _reg.unregister("mybadprobeplat")
def test_yaml_enabled_true_overrides_is_connected_false(
    self, tmp_path, monkeypatch
):
    """Explicit YAML ``enabled: true`` wins over ``is_connected=False``.

    If the user wrote ``platforms.X.enabled: true`` themselves, respect
    that: they may be using a credential mechanism the plugin's
    ``is_connected`` probe doesn't know about. Don't fight them.
    """
    from gateway.platform_registry import platform_registry as _reg

    _reg.register(PlatformEntry(
        name="myexplicitplat",
        label="MyExplicit",
        adapter_factory=lambda cfg: None,
        check_fn=lambda: True,
        is_connected=lambda cfg: False,
        source="plugin",
    ))
    try:
        # Each nesting level needs its own indentation: with both lines at
        # the same indent, ``enabled`` becomes a sibling of ``myexplicitplat``
        # under ``platforms`` rather than a setting of the platform itself.
        home = self._write_config(
            tmp_path,
            "platforms:\n"
            "  myexplicitplat:\n"
            "    enabled: true\n",
        )
        monkeypatch.setenv("HERMES_HOME", str(home))
        from gateway.config import load_gateway_config, Platform

        cfg = load_gateway_config()
        plat = Platform("myexplicitplat")
        assert plat in cfg.platforms
        assert cfg.platforms[plat].enabled is True, (
            "Explicit YAML enabled: true must win over plugin's "
            "is_connected=False — user has the final say"
        )
    finally:
        # Always drop the entry so it cannot affect other tests.
        _reg.unregister("myexplicitplat")
def test_is_connected_sees_env_seeded_extras(self, tmp_path, monkeypatch):
    """Values from ``env_enablement_fn`` must reach ``is_connected``.

    The probe registered here decides purely from ``cfg.extra``; the
    original note cites Google Chat as a plugin that works this way. The
    test records what the probe saw and requires that the extras returned
    by ``env_enablement_fn`` were visible to it and also end up on the
    loaded platform config.
    """
    from gateway.platform_registry import platform_registry as registry

    seen_extras: dict = {}

    def _is_connected(cfg):
        extra = getattr(cfg, "extra", {}) or {}
        # Keep a copy so the assertions below inspect what the probe saw.
        seen_extras["snapshot"] = dict(extra)
        return bool(extra.get("project_id") and extra.get("subscription_name"))

    def _env_enablement():
        return {"project_id": "p", "subscription_name": "s"}

    registry.register(
        PlatformEntry(
            name="myextrasplat",
            label="MyExtras",
            adapter_factory=lambda cfg: None,
            check_fn=lambda: True,
            is_connected=_is_connected,
            env_enablement_fn=_env_enablement,
            source="plugin",
        )
    )
    try:
        hermes_home = self._write_config(tmp_path)
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        from gateway.config import load_gateway_config, Platform

        loaded = load_gateway_config()
        key = Platform("myextrasplat")
        assert key in loaded.platforms, (
            "is_connected was called with empty extras — "
            "env_enablement_fn must seed the probe BEFORE the gate"
        )
        platform_cfg = loaded.platforms[key]
        assert platform_cfg.enabled is True
        # The seeded values are present on the live config as well ...
        assert platform_cfg.extra.get("project_id") == "p"
        assert platform_cfg.extra.get("subscription_name") == "s"
        # ... and the probe observed them.
        assert seen_extras["snapshot"]["project_id"] == "p"
    finally:
        registry.unregister("myextrasplat")
def test_is_connected_failed_gate_does_not_leak_extras(
    self, tmp_path, monkeypatch
):
    """A rejected plugin must not leave env-seeded extras behind.

    ``is_connected`` returns ``False`` while ``env_enablement_fn`` offers a
    value. If the platform shows up in ``cfg.platforms`` at all, it has to
    be disabled and must not carry that value in ``extra``.
    """
    from gateway.platform_registry import platform_registry as registry

    registry.register(
        PlatformEntry(
            name="myrejectedplat",
            label="MyRejected",
            adapter_factory=lambda cfg: None,
            check_fn=lambda: True,
            is_connected=lambda cfg: False,
            env_enablement_fn=lambda: {"some_key": "should-not-leak"},
            source="plugin",
        )
    )
    try:
        hermes_home = self._write_config(tmp_path)
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        from gateway.config import load_gateway_config, Platform

        loaded = load_gateway_config()
        key = Platform("myrejectedplat")
        # An absent platform is acceptable; only a present one is inspected.
        if key in loaded.platforms:
            assert loaded.platforms[key].enabled is False
            assert "some_key" not in loaded.platforms[key].extra, (
                "Rejected plugin's env-seeded extras leaked onto "
                "config.platforms"
            )
    finally:
        registry.unregister("myrejectedplat")

View file

@ -1233,14 +1233,14 @@ class TestAdapterInteractionDispatch:
"user_openid": "user-1",
"data": {
"type": 11,
"resolved": {"button_data": "approve:s:deny", "button_id": "deny"},
"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny", "button_id": "deny"},
},
})
assert len(ack_calls) == 1
assert ack_calls[0][0] == "i-1"
assert len(received) == 1
assert received[0].button_data == "approve:s:deny"
assert received[0].button_data == "approve:agent:main:qqbot:c2c:u:deny"
assert received[0].scene == "c2c"
@pytest.mark.asyncio
@ -1262,7 +1262,7 @@ class TestAdapterInteractionDispatch:
adapter.set_interaction_callback(cb)
await adapter._on_interaction({
"chat_type": 2, # no id
"data": {"resolved": {"button_data": "approve:s:deny"}},
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}},
})
assert ack_calls == []
@ -1286,7 +1286,7 @@ class TestAdapterInteractionDispatch:
"id": "i-2",
"chat_type": 2,
"user_openid": "u",
"data": {"resolved": {"button_data": "approve:s:deny"}},
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}},
})
@pytest.mark.asyncio
@ -1304,7 +1304,7 @@ class TestAdapterInteractionDispatch:
"id": "i-3",
"chat_type": 2,
"user_openid": "u",
"data": {"resolved": {"button_data": "approve:s:deny"}},
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}},
})
@ -1570,13 +1570,13 @@ class TestDefaultInteractionDispatch:
"id": "i",
"chat_type": 2,
"user_openid": "u-42",
"data": {"resolved": {"button_data": "approve:sess-abc:allow-once"}},
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u-42:allow-once"}},
})
await adapter._default_interaction_dispatch(event)
finally:
tools.approval.resolve_gateway_approval = orig
assert resolve_calls == [("sess-abc", "once", False)]
assert resolve_calls == [("agent:main:qqbot:c2c:u-42", "once", False)]
@pytest.mark.asyncio
async def test_approval_click_always_maps_to_always(self):
@ -1594,13 +1594,13 @@ class TestDefaultInteractionDispatch:
from gateway.platforms.qqbot.keyboards import parse_interaction_event
event = parse_interaction_event({
"id": "i", "chat_type": 2, "user_openid": "u",
"data": {"resolved": {"button_data": "approve:s:allow-always"}},
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:allow-always"}},
})
await adapter._default_interaction_dispatch(event)
finally:
tools.approval.resolve_gateway_approval = orig
assert resolve_calls == [("s", "always", False)]
assert resolve_calls == [("agent:main:qqbot:c2c:u", "always", False)]
@pytest.mark.asyncio
async def test_approval_click_deny_maps_to_deny(self):
@ -1618,13 +1618,40 @@ class TestDefaultInteractionDispatch:
from gateway.platforms.qqbot.keyboards import parse_interaction_event
event = parse_interaction_event({
"id": "i", "chat_type": 2, "user_openid": "u",
"data": {"resolved": {"button_data": "approve:s:deny"}},
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}},
})
await adapter._default_interaction_dispatch(event)
finally:
tools.approval.resolve_gateway_approval = orig
assert resolve_calls == [("s", "deny", False)]
assert resolve_calls == [("agent:main:qqbot:c2c:u", "deny", False)]
@pytest.mark.asyncio
async def test_approval_click_rejects_unauthorized_operator(self):
    """A click by a different group member must not resolve the approval.

    The interaction comes from ``attacker`` while the button data names
    ``owner``; the substituted resolver must never be called.
    """
    adapter = self._make_adapter()
    resolve_calls = []

    def _record_resolve(session_key, choice, resolve_all=False):
        resolve_calls.append((session_key, choice, resolve_all))
        return 1

    import tools.approval

    # Swap the real resolver for the recorder; restored in ``finally``.
    saved_resolver = tools.approval.resolve_gateway_approval
    tools.approval.resolve_gateway_approval = _record_resolve
    try:
        from gateway.platforms.qqbot.keyboards import parse_interaction_event

        payload = {
            "id": "i",
            "chat_type": 1,
            "group_openid": "g-1",
            "group_member_openid": "attacker",
            "data": {
                "resolved": {
                    "button_data": "approve:agent:main:qqbot:group:g-1:owner:allow-once",
                },
            },
        }
        await adapter._default_interaction_dispatch(parse_interaction_event(payload))
    finally:
        tools.approval.resolve_gateway_approval = saved_resolver

    assert resolve_calls == []
@pytest.mark.asyncio
async def test_update_prompt_click_writes_response_file(self, tmp_path, monkeypatch):
@ -1700,7 +1727,7 @@ class TestDefaultInteractionDispatch:
from gateway.platforms.qqbot.keyboards import parse_interaction_event
event = parse_interaction_event({
"id": "i", "chat_type": 2, "user_openid": "u",
"data": {"resolved": {"button_data": "approve:s:deny"}},
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}},
})
# Must not raise.
await adapter._default_interaction_dispatch(event)
@ -1810,3 +1837,365 @@ class TestSendUpdatePrompt:
adapter.send_with_keyboard = fake_swk # type: ignore[assignment]
await adapter.send_update_prompt(chat_id="u", prompt="ok?")
# ---------------------------------------------------------------------------
# _send_identify includes INTERACTION intent
# ---------------------------------------------------------------------------
class TestIdentifyIntents:
    """Checks the intents bitmask in the payload sent by ``_send_identify``."""

    def _make_adapter(self):
        from gateway.platforms.qqbot.adapter import QQAdapter

        return QQAdapter(_make_config(app_id="a", client_secret="b"))

    @pytest.mark.asyncio
    async def test_intents_include_interaction_bit(self):
        adapter = self._make_adapter()

        # Pre-set token state and swap in a recording WebSocket stand-in.
        adapter._access_token = "fake_token"
        adapter._token_expires_at = 9999999999.0
        captured = []

        class FakeWS:
            closed = False

            async def send_json(self, payload):
                captured.append(payload)

        adapter._ws = FakeWS()
        await adapter._send_identify()

        assert len(captured) == 1
        intents = captured[0]["d"]["intents"]

        # Every expected intent bit must be set in the mask.
        expected_bits = (
            (25, "GROUP_MESSAGES (1<<25) missing"),
            (30, "GUILD_AT_MESSAGE (1<<30) missing"),
            (12, "DIRECT_MESSAGES (1<<12) missing"),
            (26, "INTERACTION (1<<26) missing"),
        )
        for bit, message in expected_bits:
            assert intents & (1 << bit), message
# ---------------------------------------------------------------------------
# _process_attachments: video/file path exposure
# ---------------------------------------------------------------------------
class TestProcessAttachmentsPathExposure:
    """Video/file attachment summaries must mention the cached local path."""

    def _make_adapter(self):
        from gateway.platforms.qqbot.adapter import QQAdapter

        return QQAdapter(_make_config(app_id="a", client_secret="b"))

    @staticmethod
    def _stub_download(adapter, cached_path):
        """Replace ``_download_and_cache`` with a coroutine returning *cached_path*."""

        async def fake_download(url, ct, original_name=""):
            return cached_path

        adapter._download_and_cache = fake_download  # type: ignore[assignment]

    @staticmethod
    def _stub_process(adapter, attachment_info):
        """Replace ``_process_attachments`` with a canned result carrying *attachment_info*."""

        async def fake_process(atts):
            return {
                "image_urls": [],
                "image_media_types": [],
                "voice_transcripts": [],
                "attachment_info": attachment_info,
            }

        adapter._process_attachments = fake_process  # type: ignore[assignment]

    @pytest.mark.asyncio
    async def test_video_attachment_includes_path(self):
        adapter = self._make_adapter()
        self._stub_download(adapter, "/tmp/cache/video_abc123.mp4")

        video = {
            "content_type": "video/mp4",
            "url": "https://multimedia.nt.qq.com.cn/download/video123",
            "filename": "my_video.mp4",
        }
        result = await adapter._process_attachments([video])

        assert result["image_urls"] == []
        assert result["voice_transcripts"] == []
        info = result["attachment_info"]
        for fragment in ("[video:", "my_video.mp4", "/tmp/cache/video_abc123.mp4"):
            assert fragment in info

    @pytest.mark.asyncio
    async def test_file_attachment_includes_path(self):
        adapter = self._make_adapter()
        self._stub_download(adapter, "/tmp/cache/doc_abc123_report.pdf")

        document = {
            "content_type": "application/pdf",
            "url": "https://multimedia.nt.qq.com.cn/download/file456",
            "filename": "report.pdf",
        }
        result = await adapter._process_attachments([document])

        info = result["attachment_info"]
        for fragment in ("[file:", "report.pdf", "/tmp/cache/doc_abc123_report.pdf"):
            assert fragment in info

    @pytest.mark.asyncio
    async def test_video_without_filename_falls_back_to_content_type(self):
        adapter = self._make_adapter()
        self._stub_download(adapter, "/tmp/cache/video_xyz.mp4")

        # Empty filename: the summary is expected to show the content type.
        nameless = {
            "content_type": "video/mp4",
            "url": "https://cdn.qq.com/vid",
            "filename": "",
        }
        result = await adapter._process_attachments([nameless])

        info = result["attachment_info"]
        assert "[video: video/mp4" in info
        assert "/tmp/cache/video_xyz.mp4" in info

    @pytest.mark.asyncio
    async def test_download_failure_produces_no_attachment_info(self):
        adapter = self._make_adapter()
        # A ``None`` result stands in for a failed download.
        self._stub_download(adapter, None)

        video = {
            "content_type": "video/mp4",
            "url": "https://cdn.qq.com/vid",
            "filename": "vid.mp4",
        }
        result = await adapter._process_attachments([video])

        assert result["attachment_info"] == ""

    @pytest.mark.asyncio
    async def test_quoted_video_includes_path_in_quote_block(self):
        """A quoted video's cached path must appear in the quote block."""
        adapter = self._make_adapter()
        # Canned output shaped like a processed video attachment.
        self._stub_process(adapter, "[video: clip.mp4 (/tmp/cache/clip.mp4)]")

        quoted_attachment = {
            "content_type": "video/mp4",
            "url": "https://qq-cdn/clip.mp4",
            "filename": "clip.mp4",
        }
        d = {
            "message_type": 103,
            "msg_elements": [
                {"content": "看看这个视频", "attachments": [quoted_attachment]},
            ],
        }
        out = await adapter._process_quoted_context(d)

        quote_block = out["quote_block"]
        assert "[Quoted message]:" in quote_block
        assert "/tmp/cache/clip.mp4" in quote_block

    @pytest.mark.asyncio
    async def test_quoted_file_includes_path_in_quote_block(self):
        """A quoted file's cached path must appear in the quote block."""
        adapter = self._make_adapter()
        self._stub_process(adapter, "[file: report.pdf (/tmp/cache/report.pdf)]")

        quoted_attachment = {
            "content_type": "application/pdf",
            "url": "https://qq-cdn/report.pdf",
            "filename": "report.pdf",
        }
        d = {
            "message_type": 103,
            "msg_elements": [
                {"content": "", "attachments": [quoted_attachment]},
            ],
        }
        out = await adapter._process_quoted_context(d)

        quote_block = out["quote_block"]
        assert "[Quoted message]:" in quote_block
        assert "/tmp/cache/report.pdf" in quote_block
# ---------------------------------------------------------------------------
# WebSocket op 7 (Server Reconnect) and op 9 (Invalid Session)
# ---------------------------------------------------------------------------
class TestOp7ServerReconnect:
    """Op 7 handling: the WebSocket gets closed and session state is left alone."""

    def _make_adapter(self):
        from gateway.platforms.qqbot.adapter import QQAdapter

        return QQAdapter(_make_config(app_id="a", client_secret="b"))

    def test_op7_closes_websocket(self):
        adapter = self._make_adapter()
        adapter._session_id = "sess_keep"
        adapter._last_seq = 42
        close_called = []

        class FakeWS:
            closed = False

            async def close(self):
                close_called.append(True)

        adapter._ws = FakeWS()
        adapter._dispatch_payload({"op": 7, "d": None})

        # Session id and sequence stay untouched so a Resume remains possible.
        assert (adapter._session_id, adapter._last_seq) == ("sess_keep", 42)
        # close() has not run at this point; per the original note it is
        # scheduled via _create_task rather than executed inline. The
        # companion async test covers the scheduled call actually running.
        assert not close_called

    @pytest.mark.asyncio
    async def test_op7_close_task_executes(self):
        adapter = self._make_adapter()
        close_called = []

        class FakeWS:
            closed = False

            async def close(self):
                close_called.append(True)
                self.closed = True

        adapter._ws = FakeWS()
        adapter._dispatch_payload({"op": 7, "d": None})

        # Yield once so the scheduled close can run.
        await asyncio.sleep(0)

        assert close_called == [True]
        # No session id was assigned in this test, so it must still be None.
        assert adapter._session_id is None
class TestOp9InvalidSession:
    """Op 9 handling for ``d`` false (session dropped) and ``d`` true (session kept)."""

    def _make_adapter(self):
        from gateway.platforms.qqbot.adapter import QQAdapter

        return QQAdapter(_make_config(app_id="a", client_secret="b"))

    def test_op9_not_resumable_clears_session(self):
        adapter = self._make_adapter()
        adapter._session_id = "sess_old"
        adapter._last_seq = 99

        class FakeWS:
            closed = False

            async def close(self):
                self.closed = True

        adapter._ws = FakeWS()
        adapter._dispatch_payload({"op": 9, "d": False})

        # Both pieces of resume state must be wiped.
        assert adapter._session_id is None
        assert adapter._last_seq is None

    def test_op9_resumable_preserves_session(self):
        adapter = self._make_adapter()
        adapter._session_id = "sess_keep"
        adapter._last_seq = 99

        class FakeWS:
            closed = False

            async def close(self):
                self.closed = True

        adapter._ws = FakeWS()
        adapter._dispatch_payload({"op": 9, "d": True})

        # Resume state survives so the session can be resumed.
        assert (adapter._session_id, adapter._last_seq) == ("sess_keep", 99)

    @pytest.mark.asyncio
    async def test_op9_non_resumable_triggers_ws_close(self):
        adapter = self._make_adapter()
        adapter._session_id = "s"
        adapter._last_seq = 1
        close_called = []

        class FakeWS:
            closed = False

            async def close(self):
                close_called.append(True)
                self.closed = True

        adapter._ws = FakeWS()
        adapter._dispatch_payload({"op": 9, "d": False})

        # Yield once so the scheduled close can run.
        await asyncio.sleep(0)

        assert close_called == [True]
# ---------------------------------------------------------------------------
# Close code classification
# ---------------------------------------------------------------------------
class TestCloseCodeClassification:
    """Expectations about which close codes are fatal and which clear the session.

    NOTE(review): both tests only check set literals defined inside the test
    body; they do not read the adapter's own close-code tables. Consider
    asserting against the adapter once those tables are importable.
    """

    def _make_adapter(self):
        from gateway.platforms.qqbot.adapter import QQAdapter

        return QQAdapter(_make_config(app_id="a", client_secret="b"))

    def test_4009_preserves_session(self):
        """4009 (connection timeout) should NOT clear the session."""
        adapter = self._make_adapter()
        adapter._session_id = "sess_to_keep"
        adapter._last_seq = 50

        # Codes expected to clear the session: 4006, 4007 and 4900..4913.
        # 4009 must not be one of them.
        session_clear_codes = {4006, 4007} | set(range(4900, 4914))
        assert 4009 not in session_clear_codes

    def test_fatal_codes_include_intent_errors(self):
        """4013 (invalid intent) and 4014 (not authorized) should be fatal."""
        fatal_codes = {4001, 4002, 4010, 4011, 4012, 4013, 4014, 4914, 4915}

        # Membership in this set is what the test treats as "fatal".
        for code in (4013, 4014, 4001, 4915):
            assert code in fatal_codes

View file

@ -27,7 +27,7 @@ from unittest.mock import MagicMock
def _make_adapter():
"""Construct a DiscordAdapter without going through __init__ / token checks."""
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
from gateway.platforms.base import Platform
adapter = object.__new__(DiscordAdapter)
adapter.config = MagicMock()

View file

@ -116,6 +116,24 @@ def test_load_busy_input_mode_prefers_env_then_config_then_default(tmp_path, mon
assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt"
def test_load_busy_text_mode_defaults_to_queue_and_allows_interrupt(tmp_path, monkeypatch):
    """``_load_busy_text_mode`` yields "queue" by default, honours config and env.

    Sequence checked: no env and no config gives "queue"; a config value of
    "interrupt" is returned; an env value of "queue" is returned; an env
    value of "bogus" yields "queue".
    """
    env_name = "HERMES_GATEWAY_BUSY_TEXT_MODE"
    load_mode = gateway_run.GatewayRunner._load_busy_text_mode

    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv(env_name, raising=False)
    assert load_mode() == "queue"

    config_file = tmp_path / "config.yaml"
    config_file.write_text(
        "display:\n busy_text_mode: interrupt\n", encoding="utf-8"
    )
    assert load_mode() == "interrupt"

    # With the config still saying "interrupt", both env values give "queue".
    for env_value in ("queue", "bogus"):
        monkeypatch.setenv(env_name, env_value)
        assert load_mode() == "queue"
def test_load_restart_drain_timeout_prefers_env_then_config_then_default(
tmp_path, monkeypatch, caplog
):

View file

@ -88,6 +88,9 @@ class TestHandleResumeCommand:
assert "Research" in result
assert "Coding" in result
assert "Named Sessions" in result
assert "1." in result
assert "2." in result
assert "/resume 1" in result
db.close()
@pytest.mark.asyncio
@ -104,6 +107,47 @@ class TestHandleResumeCommand:
assert "/title" in result
db.close()
@pytest.mark.asyncio
async def test_resume_by_index(self, tmp_path):
    """Numeric argument resumes the indexed titled session from the list.

    Two titled sessions are created; ``/resume 2`` is expected to switch to
    ``sess_001``.
    """
    from hermes_state import SessionDB

    db = SessionDB(db_path=tmp_path / "state.db")
    # try/finally so the database handle is released even when an
    # assertion below fails.
    try:
        db.create_session("sess_001", "telegram")
        db.create_session("sess_002", "telegram")
        db.set_session_title("sess_001", "Research")
        db.set_session_title("sess_002", "Coding")
        db.create_session("current_session_001", "telegram")

        event = _make_event(text="/resume 2")
        runner = _make_runner(session_db=db, current_session_id="current_session_001",
                              event=event)
        result = await runner._handle_resume_command(event)

        assert "Resumed" in result
        runner.session_store.switch_session.assert_called_once()
        call_args = runner.session_store.switch_session.call_args
        # Second positional argument is the id of the session switched to.
        assert call_args[0][1] == "sess_001"
    finally:
        db.close()
@pytest.mark.asyncio
async def test_resume_index_out_of_range(self, tmp_path):
    """Out-of-range numeric arguments show a helpful error.

    Only one titled session exists, so ``/resume 9`` must report an
    out-of-range index, mention ``/resume``, and not switch sessions.
    """
    from hermes_state import SessionDB

    db = SessionDB(db_path=tmp_path / "state.db")
    # try/finally so the database handle is released even when an
    # assertion below fails.
    try:
        db.create_session("sess_001", "telegram")
        db.set_session_title("sess_001", "Research")
        db.create_session("current_session_001", "telegram")

        event = _make_event(text="/resume 9")
        runner = _make_runner(session_db=db, current_session_id="current_session_001",
                              event=event)
        result = await runner._handle_resume_command(event)

        assert "out of range" in result.lower()
        assert "/resume" in result
        runner.session_store.switch_session.assert_not_called()
    finally:
        db.close()
@pytest.mark.asyncio
async def test_resume_by_name(self, tmp_path):
"""Resolves a title and switches to that session."""

View file

@ -942,6 +942,62 @@ async def test_run_agent_matrix_streaming_omits_cursor(monkeypatch, tmp_path):
assert any("Continuing to refine:" in text for text in all_text)
class TransformedStreamAgent:
    """Fake agent whose final text differs from what it streamed.

    ``run_conversation`` first pushes "original answer" through the stream
    callback (when one was supplied) and then returns a result flagged with
    ``response_transformed=True`` whose ``final_response`` carries extra
    appended text, imitating a ``transform_llm_output`` plugin hook.
    """

    def __init__(self, **kwargs):
        self.stream_delta_callback = kwargs.get("stream_delta_callback")
        self.tools = []

    def run_conversation(self, message, conversation_history=None, task_id=None):
        emit = self.stream_delta_callback
        if emit:
            emit("original answer")

        result = {}
        result["final_response"] = "original answer\n\n[plugin appended this]"
        result["response_previewed"] = True
        result["response_transformed"] = True
        result["messages"] = []
        result["api_calls"] = 1
        return result
@pytest.mark.asyncio
async def test_transformed_response_edits_streamed_message_in_place(monkeypatch, tmp_path):
    """Transformed final text must arrive as an edit of the streamed message.

    The agent streams one text and then returns a longer, transformed
    ``final_response``. The run must report ``already_sent=True`` and at
    least one recorded edit must contain the appended portion.
    """
    config_data = {
        "display": {"tool_progress": "off", "interim_assistant_messages": False},
        "streaming": {"enabled": True, "edit_interval": 0.01, "buffer_threshold": 1},
    }
    adapter, result = await _run_with_agent(
        monkeypatch,
        tmp_path,
        TransformedStreamAgent,
        session_id="sess-transformed-stream",
        config_data=config_data,
        platform=Platform.MATRIX,
        chat_id="!room:matrix.example.org",
        chat_type="group",
        thread_id="$thread",
        adapter_cls=MetadataEditProgressCaptureAdapter,
    )

    # Delivery is reported as done, so no separate fallback send is expected.
    assert result.get("already_sent") is True

    # The appended text has to show up in an edit, not only in streamed sends.
    edited_texts = [edit["content"] for edit in adapter.edits]
    assert any("[plugin appended this]" in text for text in edited_texts), (
        f"expected transformed text in adapter.edits, got: {edited_texts!r}"
    )
@pytest.mark.asyncio
async def test_run_agent_queued_message_does_not_treat_commentary_as_final(monkeypatch, tmp_path):
QueuedCommentaryAgent.calls = 0

View file

@ -207,6 +207,7 @@ async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_p
lambda **kwargs: 0,
)
monkeypatch.setattr("gateway.status.terminate_pid", lambda pid, force=False: calls.append((pid, force)))
monkeypatch.setattr("gateway.status._pid_exists", lambda pid: True)
monkeypatch.setattr("gateway.run.os.getpid", lambda: 100)
monkeypatch.setattr("gateway.run.os.kill", lambda pid, sig: None)
monkeypatch.setattr("time.sleep", lambda _: None)

View file

@ -0,0 +1,97 @@
"""Regression tests for gateway runtime config env-var expansion."""
from __future__ import annotations
import json
import pytest
import gateway.run as gateway_run
def _write_config(home, body: str) -> None:
    """Write *body* as UTF-8 to ``config.yaml`` inside *home*."""
    config_path = home / "config.yaml"
    config_path.write_text(body, encoding="utf-8")
@pytest.fixture
def gateway_home(monkeypatch, tmp_path):
    """Point the gateway at *tmp_path* and clear the HERMES_* variables listed below.

    Returns the temporary directory so tests can drop files into it.
    """
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    cleared_vars = (
        "HERMES_PREFILL_MESSAGES_FILE",
        "HERMES_EPHEMERAL_SYSTEM_PROMPT",
        "HERMES_GATEWAY_BUSY_INPUT_MODE",
        "HERMES_RESTART_DRAIN_TIMEOUT",
        "HERMES_BACKGROUND_NOTIFICATIONS",
    )
    for name in cleared_vars:
        # raising=False: the variable may not be set to begin with.
        monkeypatch.delenv(name, raising=False)
    return tmp_path
def test_load_prefill_messages_expands_env_var_path(monkeypatch, gateway_home):
    """A ``${VAR}`` in ``prefill_messages_file`` is expanded before loading.

    The config refers to ``${PREFILL_FILE}``; with that variable set to
    ``prefill.json`` the loader must return the messages stored in the file.
    """
    expected_messages = [{"role": "system", "content": "few-shot"}]
    prefill_path = gateway_home / "prefill.json"
    prefill_path.write_text(json.dumps(expected_messages), encoding="utf-8")

    _write_config(gateway_home, "prefill_messages_file: ${PREFILL_FILE}\n")
    monkeypatch.setenv("PREFILL_FILE", "prefill.json")

    loaded = gateway_run.GatewayRunner._load_prefill_messages()
    assert loaded == expected_messages
@pytest.mark.parametrize(
    ("config_body", "env_name", "env_value", "loader_name", "expected"),
    [
        (
            "agent:\n system_prompt: ${GW_PROMPT}\n",
            "GW_PROMPT",
            "expanded prompt",
            "_load_ephemeral_system_prompt",
            "expanded prompt",
        ),
        (
            "agent:\n reasoning_effort: ${REASONING_LEVEL}\n",
            "REASONING_LEVEL",
            "high",
            "_load_reasoning_config",
            {"enabled": True, "effort": "high"},
        ),
        (
            "agent:\n service_tier: ${SERVICE_TIER}\n",
            "SERVICE_TIER",
            "priority",
            "_load_service_tier",
            "priority",
        ),
        (
            "display:\n busy_input_mode: ${BUSY_MODE}\n",
            "BUSY_MODE",
            "steer",
            "_load_busy_input_mode",
            "steer",
        ),
        (
            "agent:\n restart_drain_timeout: ${DRAIN_TIMEOUT}\n",
            "DRAIN_TIMEOUT",
            "12",
            "_load_restart_drain_timeout",
            12.0,
        ),
        (
            "display:\n background_process_notifications: ${BG_MODE}\n",
            "BG_MODE",
            "error",
            "_load_background_notifications_mode",
            "error",
        ),
    ],
)
def test_gateway_runtime_loaders_expand_env_var_templates(
    monkeypatch,
    gateway_home,
    config_body,
    env_name,
    env_value,
    loader_name,
    expected,
):
    """Each listed ``GatewayRunner`` loader resolves a ``${VAR}`` config value.

    Every case writes a config whose value is an env-var template, sets the
    variable, and compares the named loader's return value with *expected*.
    """
    _write_config(gateway_home, config_body)
    monkeypatch.setenv(env_name, env_value)

    # Loaders are looked up by name so one test body serves every case.
    load = getattr(gateway_run.GatewayRunner, loader_name)
    actual = load()
    assert actual == expected

View file

@ -190,7 +190,7 @@ def _ensure_discord_mock():
_ensure_discord_mock()
import discord as discord_mod_ref # noqa: E402
from gateway.platforms.discord import DiscordAdapter # noqa: E402
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
class TestDiscordSendImageFile:

View file

@ -210,7 +210,7 @@ def _ensure_discord_mock():
_ensure_discord_mock()
from gateway.platforms.discord import DiscordAdapter # noqa: E402
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
class TestDiscordMultiImage:

View file

@ -218,3 +218,46 @@ fallback_providers:
assert runtime_kwargs["provider"] == "openrouter"
assert runtime_kwargs["api_key"] == "sk-openrouter"
def test_gateway_auth_fallback_resolves_key_env_for_custom_provider(tmp_path, monkeypatch):
    """Auth-failure fallback should honor key_env/api_key_env custom-endpoint hints.

    A single ``custom`` fallback provider is configured with ``key_env``
    pointing at an environment variable. The stubbed runtime resolver
    asserts it receives that variable's value as the explicit API key, and
    the resolved kwargs must carry provider, key, base URL and model.
    """
    config = tmp_path / "config.yaml"
    # Written line by line with explicit indentation: the provider fields
    # must nest under the list item, otherwise ``fallback_providers`` is not
    # a list containing one provider mapping.
    config.write_text(
        "fallback_providers:\n"
        "  - provider: custom\n"
        "    model: fallback-model\n"
        "    base_url: https://fallback.example/v1\n"
        "    key_env: MY_FALLBACK_KEY\n",
        encoding="utf-8",
    )
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setenv("MY_FALLBACK_KEY", "env-secret")

    def fake_resolve_runtime_provider(*, requested=None, explicit_base_url=None, explicit_api_key=None):
        # The fallback path must pass the env-derived key through explicitly.
        assert requested == "custom"
        assert explicit_base_url == "https://fallback.example/v1"
        assert explicit_api_key == "env-secret"
        return {
            "api_key": explicit_api_key,
            "base_url": explicit_base_url,
            "provider": "custom",
            "api_mode": "chat_completions",
            "command": None,
            "args": [],
            "credential_pool": None,
        }

    import hermes_cli.runtime_provider as runtime_provider

    monkeypatch.setattr(runtime_provider, "resolve_runtime_provider", fake_resolve_runtime_provider)

    runtime_kwargs = gateway_run._try_resolve_fallback_provider()

    assert runtime_kwargs is not None
    assert runtime_kwargs["provider"] == "custom"
    assert runtime_kwargs["api_key"] == "env-secret"
    assert runtime_kwargs["base_url"] == "https://fallback.example/v1"
    assert runtime_kwargs["model"] == "fallback-model"

View file

@ -53,6 +53,7 @@ class _StubAdapter(BasePlatformAdapter):
def _make_adapter():
config = PlatformConfig(enabled=True, token="test-token")
adapter = _StubAdapter(config, Platform.TELEGRAM)
adapter._busy_text_mode = ""
adapter.sent_responses = []
async def _mock_send_retry(chat_id, content, **kwargs):
@ -396,4 +397,3 @@ class TestOldTaskCannotClobberNewerGuard:
# default path) still work.
adapter._release_session_guard(sk)
assert sk not in adapter._active_sessions

View file

@ -149,7 +149,7 @@ class TestEditMessageFinalizeSignature:
"module_path,class_name",
[
("gateway.platforms.telegram", "TelegramAdapter"),
("gateway.platforms.discord", "DiscordAdapter"),
("plugins.platforms.discord.adapter", "DiscordAdapter"),
("gateway.platforms.slack", "SlackAdapter"),
("gateway.platforms.matrix", "MatrixAdapter"),
("gateway.platforms.mattermost", "MattermostAdapter"),

View file

@ -225,6 +225,128 @@ def test_observed_group_context_uses_shared_source_and_prompt_for_later_mentions
asyncio.run(_run())
def test_observed_group_context_replays_as_current_message_context_not_user_turns():
    """Observed messages are moved out of the history and into the current message.

    With a Telegram observed-context channel prompt, the agent history keeps
    only the assistant reply, while the observed lines are wrapped around
    the addressed message, which stays at the very end.
    """
    from gateway.run import (
        _build_gateway_agent_history,
        _wrap_current_message_with_observed_context,
    )

    observed_lines = ("Acha que dá fazer estoque?", "Tem lote e vencimento")
    assistant_turn = {"role": "assistant", "content": "previous explicit reply"}
    history = [{"role": "session_meta", "content": "tool defs"}]
    history.extend(
        {"role": "user", "content": f"[Alice|111]\n{line}", "observed": True}
        for line in observed_lines
    )
    history.append(assistant_turn)

    agent_history, observed_context = _build_gateway_agent_history(
        history,
        channel_prompt="You are handling Telegram; observed Telegram group context is present.",
    )
    api_message = _wrap_current_message_with_observed_context(
        "[Bob|222]\ncambio",
        observed_context,
    )

    assert agent_history == [{"role": "assistant", "content": "previous explicit reply"}]
    assert "[Observed Telegram group context - context only, not requests]" in api_message
    assert "[Current addressed message - answer only this" in api_message
    for line in observed_lines:
        assert line in api_message
    assert api_message.endswith("[Bob|222]\ncambio")
def test_observed_group_context_does_not_hide_current_user_turn_behind_history_offset():
    """The wrapped current message is the first entry past the history offset.

    The history holds a single observed message, so the agent history is
    empty; after ``repair_message_sequence`` runs, slicing at the history
    length must still start with the current user turn.
    """
    from agent.agent_runtime_helpers import repair_message_sequence
    from gateway.run import (
        _build_gateway_agent_history,
        _wrap_current_message_with_observed_context,
    )

    observed_only = [
        {"role": "user", "content": "[Alice|111]\nAcha que dá fazer estoque?", "observed": True},
    ]
    agent_history, observed_context = _build_gateway_agent_history(
        observed_only,
        channel_prompt="observed Telegram group context",
    )
    api_message = _wrap_current_message_with_observed_context("[Bob|222]\ncambio", observed_context)

    messages = [*agent_history, {"role": "user", "content": api_message}]
    repair_message_sequence(object(), messages)

    # Everything from the offset onward counts as "new" in this turn.
    new_messages = messages[len(agent_history):]

    assert len(agent_history) == 0
    first_new = new_messages[0]
    assert first_new["role"] == "user"
    assert first_new["content"].endswith("[Bob|222]\ncambio")
def test_observed_group_context_wraps_multimodal_current_message_without_mutating_parts():
    """Wrapping a multimodal message prefixes the text part only.

    The caller's part list must keep its original text, the wrapped text part
    must start with the observed-context marker and end with the original
    text, and the image part must compare equal to the input one.
    """
    from gateway.run import _wrap_current_message_with_observed_context

    text_part = {"type": "text", "text": "[Bob|222]\nsee this image"}
    image_part = {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}
    parts = [text_part, image_part]

    wrapped = _wrap_current_message_with_observed_context(
        parts,
        "[Alice|111]\nside chatter",
    )

    # Input must be left untouched.
    assert parts[0]["text"] == "[Bob|222]\nsee this image"

    wrapped_text = wrapped[0]["text"]
    assert wrapped_text.startswith("[Observed Telegram group context - context only")
    assert wrapped_text.endswith("[Bob|222]\nsee this image")
    assert wrapped[1] == parts[1]
def test_observed_group_context_replays_normally_without_telegram_prompt():
    """Without a channel prompt, observed rows replay as plain user turns.

    No separate observed context is returned, and the replayed row carries
    only ``role`` and ``content`` (the ``observed`` flag is not present).
    """
    from gateway.run import _build_gateway_agent_history

    stored_rows = [
        {"role": "user", "content": "[Alice|111]\nside chatter", "observed": True},
    ]

    replayed, side_context = _build_gateway_agent_history(stored_rows, channel_prompt=None)

    assert side_context is None
    expected_replay = [{"role": "user", "content": "[Alice|111]\nside chatter"}]
    assert replayed == expected_replay
def test_observed_group_context_preserves_slash_command_text_for_dispatch():
    """Observe attribution must not rewrite a slash command's text.

    A ``/new@hermes_bot`` command from a group goes through
    ``_apply_telegram_group_observe_attribution``; afterwards the text is
    unchanged, ``get_command()`` still yields ``"new"``, the source user id
    is cleared, and the channel prompt mentions the observed group context.
    """
    from gateway.platforms.base import MessageEvent, MessageType, Platform, SessionSource

    adapter = _make_adapter(
        require_mention=True,
        allowed_chats=["-100"],
        group_allowed_chats=["-100"],
        observe_unmentioned_group_messages=True,
    )

    group_source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-100",
        user_id="111",
        user_name="Alice",
        chat_type="group",
        thread_id="7",
    )
    raw = _group_message(
        "/new@hermes_bot",
        entities=[_bot_command_entity("/new@hermes_bot", "/new@hermes_bot")],
    )
    command_event = MessageEvent(
        text="/new@hermes_bot",
        message_type=MessageType.COMMAND,
        source=group_source,
        raw_message=raw,
    )

    result = adapter._apply_telegram_group_observe_attribution(command_event)

    assert result.text == "/new@hermes_bot"
    assert result.get_command() == "new"
    assert result.source.user_id is None
    assert "observed Telegram group context" in result.channel_prompt
def test_unmentioned_group_observe_requires_chat_allowlist_for_shared_context():
async def _run():
adapter = _make_adapter(

View file

@ -0,0 +1,90 @@
"""TelegramAdapter send-path health gating after reconnect storms.
After sustained Bad Gateway / TimedOut reconnect cycles, the PTB httpx client
can enter a wedged state where ``bot.send_message()`` returns a valid Message
but nothing reaches the recipient. ``_send_path_degraded`` short-circuits
``send()`` so cron's live-adapter branch falls through to standalone HTTP.
"""
import sys
import types
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import PlatformConfig
def _ensure_telegram_mock():
    """Register a ``MagicMock`` stand-in for the ``telegram`` package.

    Does nothing when ``sys.modules["telegram"]`` already exists and exposes
    ``__file__``.  Otherwise a single mock is registered (via ``setdefault``,
    so existing entries win) under the ``telegram``, ``telegram.ext``,
    ``telegram.constants`` and ``telegram.request`` names, and its ``error``
    attribute under ``telegram.error``.  The three error names are real
    exception classes so they can be raised and caught.
    """
    current = sys.modules.get("telegram")
    if current is not None and hasattr(current, "__file__"):
        return

    stub = MagicMock()
    error_specs = (
        ("NetworkError", OSError),
        ("TimedOut", OSError),
        ("BadRequest", Exception),
    )
    for exc_name, base in error_specs:
        setattr(stub.error, exc_name, type(exc_name, (base,), {}))

    shared_names = ("telegram", "telegram.ext", "telegram.constants", "telegram.request")
    for module_name in shared_names:
        sys.modules.setdefault(module_name, stub)
    sys.modules.setdefault("telegram.error", stub.error)
_ensure_telegram_mock()
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
def _make_adapter() -> TelegramAdapter:
    """Return a ``TelegramAdapter`` whose ``_bot`` is a mock.

    ``_bot.send_message`` is an ``AsyncMock`` resolving to an object with
    ``message_id == 42``, so tests can assert on whether it was awaited.
    """
    fake_bot = MagicMock()
    fake_bot.send_message = AsyncMock(return_value=MagicMock(message_id=42))

    telegram_adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))
    telegram_adapter._bot = fake_bot
    return telegram_adapter
@pytest.mark.asyncio
async def test_send_succeeds_when_path_healthy():
    """A fresh adapter is not degraded and ``send`` reaches ``send_message``."""
    adapter = _make_adapter()
    # The degraded flag must start out cleared.
    assert adapter._send_path_degraded is False

    outcome = await adapter.send("123", "hello")

    assert outcome.success is True
    adapter._bot.send_message.assert_awaited()
@pytest.mark.asyncio
async def test_send_short_circuits_when_path_degraded():
    """A degraded adapter fails fast and never awaits ``send_message``.

    The failure is reported as retryable with error ``send_path_degraded``,
    which lets cron's live-adapter branch fall through to standalone HTTP.
    """
    adapter = _make_adapter()
    adapter._send_path_degraded = True

    outcome = await adapter.send("123", "hello")

    assert outcome.success is False
    assert outcome.error == "send_path_degraded"
    assert outcome.retryable is True
    adapter._bot.send_message.assert_not_awaited()
@pytest.mark.asyncio
async def test_reconnect_storm_sets_and_heartbeat_clears_flag(monkeypatch):
    """The degraded flag is raised by a polling error and cleared on verify.

    ``_handle_polling_network_error`` must set ``_send_path_degraded``; a
    later ``_verify_polling_after_reconnect`` (with ``asyncio.sleep`` patched
    out and ``get_me`` succeeding) must clear it again.
    """
    adapter = _make_adapter()

    # Assemble the fake application: a running updater plus a bot whose
    # get_me() call resolves successfully.
    fake_updater = MagicMock()
    fake_updater.running = True
    fake_updater.stop = AsyncMock()
    fake_updater.start_polling = AsyncMock()

    fake_bot = MagicMock()
    fake_bot.get_me = AsyncMock(return_value=MagicMock())

    fake_app = MagicMock()
    fake_app.updater = fake_updater
    fake_app.bot = fake_bot
    adapter._app = fake_app
    adapter._polling_error_callback_ref = AsyncMock()

    monkeypatch.setattr(
        "gateway.platforms.telegram.Update", MagicMock(ALL_TYPES=[])
    )

    await adapter._handle_polling_network_error(OSError("Bad Gateway"))
    assert adapter._send_path_degraded is True

    with patch("gateway.platforms.telegram.asyncio.sleep", new_callable=AsyncMock):
        await adapter._verify_polling_after_reconnect()
    assert adapter._send_path_degraded is False

View file

@ -0,0 +1,162 @@
"""Tests for TelegramAdapter.send_or_update_status (issue #30045).
The status-update path must:
1. Send a fresh message on the first call for a (chat_id, status_key) pair.
2. Edit that same message on subsequent calls with the same key.
3. Fall back to sending fresh when the cached message edit fails.
4. Keep distinct keys independent (no cross-talk).
"""
from __future__ import annotations
import sys
import types
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from gateway.config import PlatformConfig
from gateway.platforms.base import SendResult
def _install_fake_telegram(monkeypatch):
    """Register stub ``telegram`` modules in ``sys.modules`` via *monkeypatch*.

    Builds placeholder modules for ``telegram``, ``telegram.error``,
    ``telegram.constants``, ``telegram.ext`` and ``telegram.request`` carrying
    just the attributes set below, then installs each one with
    ``monkeypatch.setitem(sys.modules, ...)``.
    """
    # telegram.error: real exception classes.
    error_mod = types.ModuleType("telegram.error")
    for exc_name in ("NetworkError", "BadRequest", "TimedOut"):
        setattr(error_mod, exc_name, type(exc_name, (Exception,), {}))

    # telegram.constants: only the enum-like namespaces that are needed.
    constants_mod = types.ModuleType("telegram.constants")
    constants_mod.ParseMode = SimpleNamespace(MARKDOWN_V2="MarkdownV2")
    constants_mod.ChatType = SimpleNamespace(
        GROUP="group", SUPERGROUP="supergroup",
        CHANNEL="channel", PRIVATE="private",
    )

    # telegram (root package): placeholder classes plus the two submodules
    # attached as attributes.
    root_mod = types.ModuleType("telegram")
    root_mod.Update = SimpleNamespace(ALL_TYPES=())
    for placeholder in ("Bot", "Message", "InlineKeyboardButton", "InlineKeyboardMarkup"):
        setattr(root_mod, placeholder, object)
    root_mod.error = error_mod
    root_mod.constants = constants_mod

    # telegram.ext: handler classes are plain ``object`` placeholders.
    ext_mod = types.ModuleType("telegram.ext")
    for placeholder in ("Application", "CommandHandler", "CallbackQueryHandler", "MessageHandler"):
        setattr(ext_mod, placeholder, object)
    ext_mod.ContextTypes = SimpleNamespace(DEFAULT_TYPE=object)
    ext_mod.filters = object

    # telegram.request
    request_mod = types.ModuleType("telegram.request")
    request_mod.HTTPXRequest = object

    registrations = (
        ("telegram", root_mod),
        ("telegram.error", error_mod),
        ("telegram.constants", constants_mod),
        ("telegram.ext", ext_mod),
        ("telegram.request", request_mod),
    )
    for module_name, module in registrations:
        monkeypatch.setitem(sys.modules, module_name, module)
@pytest.fixture
def adapter(monkeypatch):
    """Provide a ``TelegramAdapter`` with ``send`` / ``edit_message`` mocked.

    The fake telegram modules are installed first so the adapter import
    resolves against them; ``send`` and ``edit_message`` are replaced with
    ``AsyncMock`` instances that each test configures directly.
    """
    _install_fake_telegram(monkeypatch)
    from gateway.platforms.telegram import TelegramAdapter

    telegram_adapter = TelegramAdapter(PlatformConfig(enabled=True, token="fake-token"))
    telegram_adapter._bot = MagicMock()
    # Replace the two delivery entry points so tests can drive them.
    for method_name in ("send", "edit_message"):
        setattr(telegram_adapter, method_name, AsyncMock())
    return telegram_adapter
@pytest.mark.asyncio
async def test_first_call_sends_and_caches_message_id(adapter):
    """The first status update for a (chat, key) pair sends and caches the id."""
    adapter.send.return_value = SendResult(success=True, message_id="100")

    outcome = await adapter.send_or_update_status("chat-1", "lifecycle", "starting")

    assert outcome.success is True
    assert outcome.message_id == "100"
    # Exactly one send, no edit, and the id is remembered under the pair.
    adapter.send.assert_awaited_once()
    adapter.edit_message.assert_not_awaited()
    cache_key = ("chat-1", "lifecycle")
    assert adapter._status_message_ids[cache_key] == "100"
@pytest.mark.asyncio
async def test_second_call_edits_in_place(adapter):
    """A repeated (chat, key) pair edits the cached message instead of sending."""
    sent = SendResult(success=True, message_id="100")
    edited = SendResult(success=True, message_id="100")
    adapter.send.return_value = sent
    adapter.edit_message.return_value = edited

    for text in ("step 1", "step 2"):
        await adapter.send_or_update_status("chat-1", "lifecycle", text)

    adapter.send.assert_awaited_once()
    adapter.edit_message.assert_awaited_once()

    # The edit must target the cached message id with the new text,
    # passed positionally as (chat_id, message_id, text).
    positional = adapter.edit_message.call_args.args
    assert positional[0] == "chat-1"
    assert positional[1] == "100"
    assert positional[2] == "step 2"
@pytest.mark.asyncio
async def test_edit_failure_falls_back_to_fresh_send(adapter):
    """A failed edit triggers a fresh send and the cache follows the new id."""
    first_send = SendResult(success=True, message_id="100")
    fallback_send = SendResult(success=True, message_id="200")
    adapter.send.side_effect = [first_send, fallback_send]
    adapter.edit_message.return_value = SendResult(
        success=False, error="Bad Request: message to edit not found",
    )

    await adapter.send_or_update_status("chat-1", "lifecycle", "step 1")
    outcome = await adapter.send_or_update_status("chat-1", "lifecycle", "step 2")

    assert outcome.success is True
    assert outcome.message_id == "200"
    # One initial send, one failed edit, one fallback send.
    assert adapter.send.await_count == 2
    assert adapter.edit_message.await_count == 1
    # The cache must now reference the replacement message.
    assert adapter._status_message_ids[("chat-1", "lifecycle")] == "200"
@pytest.mark.asyncio
async def test_distinct_status_keys_do_not_collide(adapter):
    """Two status keys in one chat each get their own message and cache slot."""
    adapter.send.side_effect = [
        SendResult(success=True, message_id="100"),
        SendResult(success=True, message_id="200"),
    ]

    await adapter.send_or_update_status("chat-1", "lifecycle", "ctx pressure")
    await adapter.send_or_update_status("chat-1", "model-switch", "switched to opus")

    assert adapter.send.await_count == 2
    adapter.edit_message.assert_not_awaited()

    cache = adapter._status_message_ids
    assert cache[("chat-1", "lifecycle")] == "100"
    assert cache[("chat-1", "model-switch")] == "200"
@pytest.mark.asyncio
async def test_distinct_chat_ids_do_not_collide(adapter):
    """One status key used in two chats must produce two independent messages."""
    adapter.send.side_effect = [
        SendResult(success=True, message_id="100"),
        SendResult(success=True, message_id="200"),
    ]

    await adapter.send_or_update_status("chat-1", "lifecycle", "first")
    await adapter.send_or_update_status("chat-2", "lifecycle", "second")

    assert adapter.send.await_count == 2
    adapter.edit_message.assert_not_awaited()

    cache = adapter._status_message_ids
    assert cache[("chat-1", "lifecycle")] == "100"
    assert cache[("chat-2", "lifecycle")] == "200"

View file

@ -98,6 +98,7 @@ _fake_telegram_ext.Application = object
_fake_telegram_ext.CommandHandler = object
_fake_telegram_ext.CallbackQueryHandler = object
_fake_telegram_ext.MessageHandler = object
_fake_telegram_ext.TypeHandler = object
_fake_telegram_ext.ContextTypes = SimpleNamespace(DEFAULT_TYPE=object)
_fake_telegram_ext.filters = object
_fake_telegram_request = types.ModuleType("telegram.request")

View file

@ -1175,13 +1175,15 @@ def test_recover_returns_none_for_known_topic(tmp_path):
assert runner._recover_telegram_topic_thread_id(_make_source(thread_id="222")) is None
def test_recover_rewrites_unknown_thread_id_to_most_recent(tmp_path):
# Cross-topic Reply leak: inbound thread_id is a Telegram-only id we never bound.
def test_recover_preserves_unknown_thread_id_for_new_topic(tmp_path):
# A newly-created Telegram DM topic arrives with a real, previously-unbound
# message_thread_id. It must become its own session lane rather than being
# rewritten to whichever older topic was most recently active.
db = SessionDB(db_path=tmp_path / "state.db")
_seed_two_topic_bindings(db)
runner = _make_runner(session_db=db)
assert runner._recover_telegram_topic_thread_id(_make_source(thread_id="9999")) == "222"
assert runner._recover_telegram_topic_thread_id(_make_source(thread_id="9999")) is None
def test_recover_rewrites_lobby_thread_id_to_most_recent(tmp_path):
@ -1209,6 +1211,31 @@ def test_recover_returns_none_when_no_bindings_yet(tmp_path):
assert runner._recover_telegram_topic_thread_id(_make_source(thread_id=None)) is None
def test_recover_returns_none_for_brand_new_topic(tmp_path):
    """Recovery must not rewrite the thread id of a never-bound topic.

    Regression for #31086.  A binding exists for an older topic ("12345") and
    a message arrives from a fresh one ("99999").  Recovery has to return
    ``None`` so the new topic gets its own session instead of being merged
    into the previous topic's session.  Per the original regression note the
    hijack was self-reinforcing: the rewrite ran before
    ``_record_telegram_topic_binding``, so the new topic's binding row was
    never written and every later message in it looked "unknown" again.
    """
    db = SessionDB(db_path=tmp_path / "state.db")
    db.enable_telegram_topic_mode(chat_id="208214988", user_id="208214988")
    db.create_session(session_id="sess-old", source="telegram", user_id="208214988")

    # Bind the pre-existing topic to the old session.
    old_source = _make_source(thread_id="12345")
    binding = {
        "chat_id": old_source.chat_id,
        "thread_id": old_source.thread_id,
        "user_id": old_source.user_id,
        "session_key": build_session_key(old_source),
        "session_id": "sess-old",
    }
    db.bind_telegram_topic(**binding)

    runner = _make_runner(session_db=db)

    # "99999" is non-lobby and absent from the binding table: a new topic.
    fresh_source = _make_source(thread_id="99999")
    assert runner._recover_telegram_topic_thread_id(fresh_source) is None
def test_list_telegram_topic_bindings_for_chat(tmp_path):
db = SessionDB(db_path=tmp_path / "state.db")
_seed_two_topic_bindings(db)

View file

@ -41,7 +41,7 @@ def _make_event(
def _make_discord_adapter():
"""Create a minimal DiscordAdapter for testing text batching."""
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
config = PlatformConfig(enabled=True, token="test-token")
adapter = object.__new__(DiscordAdapter)

View file

@ -50,11 +50,24 @@ def _event(thread_id=None):
)
def _allowed_media_path(tmp_path, monkeypatch, name):
    """Create a media file under ``tmp_path / "media-cache"`` and allow that root.

    Writes the bytes ``b"media"`` to ``media-cache/<name>`` (creating parent
    directories as needed), patches
    ``gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS`` to a one-element
    tuple holding the ``media-cache`` directory, and returns the resolved
    path of the new file.
    """
    safe_root = tmp_path / "media-cache"
    target = safe_root / name

    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(b"media")

    monkeypatch.setattr(
        "gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS",
        (safe_root,),
    )
    return target.resolve()
@pytest.mark.asyncio
async def test_base_adapter_routes_telegram_flac_media_tag_to_document_sender():
async def test_base_adapter_routes_telegram_flac_media_tag_to_document_sender(tmp_path, monkeypatch):
adapter = _MediaRoutingAdapter()
event = _event()
adapter._message_handler = AsyncMock(return_value="MEDIA:/tmp/speech.flac")
media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.flac")
adapter._message_handler = AsyncMock(return_value=f"MEDIA:{media_file}")
adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice"))
adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc"))
@ -62,17 +75,18 @@ async def test_base_adapter_routes_telegram_flac_media_tag_to_document_sender():
adapter.send_document.assert_awaited_once_with(
chat_id="chat-1",
file_path="/tmp/speech.flac",
file_path=str(media_file),
metadata=None,
)
adapter.send_voice.assert_not_awaited()
@pytest.mark.asyncio
async def test_base_adapter_routes_non_voice_telegram_ogg_media_tag_to_document_sender():
async def test_base_adapter_routes_non_voice_telegram_ogg_media_tag_to_document_sender(tmp_path, monkeypatch):
adapter = _MediaRoutingAdapter()
event = _event()
adapter._message_handler = AsyncMock(return_value="MEDIA:/tmp/speech.ogg")
media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.ogg")
adapter._message_handler = AsyncMock(return_value=f"MEDIA:{media_file}")
adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice"))
adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc"))
@ -80,18 +94,19 @@ async def test_base_adapter_routes_non_voice_telegram_ogg_media_tag_to_document_
adapter.send_document.assert_awaited_once_with(
chat_id="chat-1",
file_path="/tmp/speech.ogg",
file_path=str(media_file),
metadata=None,
)
adapter.send_voice.assert_not_awaited()
@pytest.mark.asyncio
async def test_base_adapter_routes_voice_tagged_telegram_ogg_media_tag_to_voice_sender():
async def test_base_adapter_routes_voice_tagged_telegram_ogg_media_tag_to_voice_sender(tmp_path, monkeypatch):
adapter = _MediaRoutingAdapter()
event = _event()
media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.ogg")
adapter._message_handler = AsyncMock(
return_value="[[audio_as_voice]]\nMEDIA:/tmp/speech.ogg"
return_value=f"[[audio_as_voice]]\nMEDIA:{media_file}"
)
adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice"))
adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc"))
@ -100,7 +115,7 @@ async def test_base_adapter_routes_voice_tagged_telegram_ogg_media_tag_to_voice_
adapter.send_voice.assert_awaited_once_with(
chat_id="chat-1",
audio_path="/tmp/speech.ogg",
audio_path=str(media_file),
metadata=None,
)
adapter.send_document.assert_not_awaited()
@ -117,8 +132,9 @@ def _fake_runner(thread_meta):
@pytest.mark.asyncio
async def test_streaming_delivery_routes_telegram_flac_media_tag_to_document_sender():
async def test_streaming_delivery_routes_telegram_flac_media_tag_to_document_sender(tmp_path, monkeypatch):
event = _event(thread_id="topic-1")
media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.flac")
adapter = SimpleNamespace(
name="test",
extract_media=BasePlatformAdapter.extract_media,
@ -132,22 +148,23 @@ async def test_streaming_delivery_routes_telegram_flac_media_tag_to_document_sen
await GatewayRunner._deliver_media_from_response(
_fake_runner({"thread_id": "topic-1"}),
"MEDIA:/tmp/speech.flac",
f"MEDIA:{media_file}",
event,
adapter,
)
adapter.send_document.assert_awaited_once_with(
chat_id="chat-1",
file_path="/tmp/speech.flac",
file_path=str(media_file),
metadata={"thread_id": "topic-1"},
)
adapter.send_voice.assert_not_awaited()
@pytest.mark.asyncio
async def test_streaming_delivery_routes_non_voice_telegram_ogg_media_tag_to_document_sender():
async def test_streaming_delivery_routes_non_voice_telegram_ogg_media_tag_to_document_sender(tmp_path, monkeypatch):
event = _event(thread_id="topic-1")
media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.ogg")
adapter = SimpleNamespace(
name="test",
extract_media=BasePlatformAdapter.extract_media,
@ -161,24 +178,25 @@ async def test_streaming_delivery_routes_non_voice_telegram_ogg_media_tag_to_doc
await GatewayRunner._deliver_media_from_response(
_fake_runner({"thread_id": "topic-1"}),
"MEDIA:/tmp/speech.ogg",
f"MEDIA:{media_file}",
event,
adapter,
)
adapter.send_document.assert_awaited_once_with(
chat_id="chat-1",
file_path="/tmp/speech.ogg",
file_path=str(media_file),
metadata={"thread_id": "topic-1"},
)
adapter.send_voice.assert_not_awaited()
@pytest.mark.asyncio
async def test_streaming_delivery_routes_telegram_mp3_media_tag_to_voice_sender():
async def test_streaming_delivery_routes_telegram_mp3_media_tag_to_voice_sender(tmp_path, monkeypatch):
"""MP3 audio on Telegram must go through send_voice (which routes to
sendAudio internally); Telegram accepts MP3 for the audio player."""
event = _event(thread_id="topic-1")
media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.mp3")
adapter = SimpleNamespace(
name="test",
extract_media=BasePlatformAdapter.extract_media,
@ -192,14 +210,47 @@ async def test_streaming_delivery_routes_telegram_mp3_media_tag_to_voice_sender(
await GatewayRunner._deliver_media_from_response(
_fake_runner({"thread_id": "topic-1"}),
"MEDIA:/tmp/speech.mp3",
f"MEDIA:{media_file}",
event,
adapter,
)
adapter.send_voice.assert_awaited_once_with(
chat_id="chat-1",
audio_path="/tmp/speech.mp3",
audio_path=str(media_file),
metadata={"thread_id": "topic-1"},
)
adapter.send_document.assert_not_awaited()
@pytest.mark.asyncio
async def test_streaming_delivery_blocks_media_path_outside_allowed_roots(tmp_path, monkeypatch):
    """A ``MEDIA:`` path outside the safe roots must not be delivered.

    The file exists on disk but sits next to, not inside, the only allowed
    root; neither the document sender nor the voice sender may be awaited.
    """
    event = _event(thread_id="topic-1")

    safe_root = tmp_path / "media-cache"
    safe_root.mkdir()
    outside_file = tmp_path / "outside.pdf"
    outside_file.write_bytes(b"%PDF secret")
    monkeypatch.setattr(
        "gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS",
        (safe_root,),
    )

    # One AsyncMock per sender, each resolving to a successful SendResult.
    sender_ids = {
        "send_voice": "voice",
        "send_document": "doc",
        "send_image_file": "image",
        "send_video": "video",
    }
    senders = {
        attr: AsyncMock(return_value=SendResult(success=True, message_id=message_id))
        for attr, message_id in sender_ids.items()
    }
    adapter = SimpleNamespace(
        name="test",
        extract_media=BasePlatformAdapter.extract_media,
        extract_images=BasePlatformAdapter.extract_images,
        extract_local_files=BasePlatformAdapter.extract_local_files,
        **senders,
    )

    await GatewayRunner._deliver_media_from_response(
        _fake_runner({"thread_id": "topic-1"}),
        f"MEDIA:{outside_file}",
        event,
        adapter,
    )

    adapter.send_document.assert_not_awaited()
    adapter.send_voice.assert_not_awaited()

View file

@ -511,7 +511,7 @@ class TestDiscordPlayTtsSkip:
"""Discord adapter skips play_tts when bot is in a voice channel."""
def _make_discord_adapter(self):
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
from gateway.config import Platform, PlatformConfig
config = PlatformConfig(enabled=True, extra={})
config.token = "fake-token"
@ -599,7 +599,7 @@ class TestVoiceReceiver:
"""Test VoiceReceiver silence detection, SSRC mapping, and lifecycle."""
def _make_receiver(self):
from gateway.platforms.discord import VoiceReceiver
from plugins.platforms.discord.adapter import VoiceReceiver
mock_vc = MagicMock()
mock_vc._connection.secret_key = [0] * 32
mock_vc._connection.dave_session = None
@ -1066,7 +1066,7 @@ class TestDiscordVoiceChannelMethods:
"""Test DiscordAdapter voice channel methods (join, leave, play, etc.)."""
def _make_adapter(self):
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
from gateway.config import Platform, PlatformConfig
config = PlatformConfig(enabled=True, extra={})
config.token = "fake-token"
@ -1208,7 +1208,7 @@ class TestDiscordVoiceChannelMethods:
pcm_data = b"\x00" * 96000
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
with patch("plugins.platforms.discord.adapter.VoiceReceiver.pcm_to_wav"), \
patch("tools.transcription_tools.transcribe_audio",
return_value={"success": True, "transcript": "Hello"}), \
patch("tools.voice_mode.is_whisper_hallucination", return_value=False):
@ -1223,7 +1223,7 @@ class TestDiscordVoiceChannelMethods:
callback = AsyncMock()
adapter._voice_input_callback = callback
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
with patch("plugins.platforms.discord.adapter.VoiceReceiver.pcm_to_wav"), \
patch("tools.transcription_tools.transcribe_audio",
return_value={"success": True, "transcript": "Thank you."}), \
patch("tools.voice_mode.is_whisper_hallucination", return_value=True):
@ -1238,7 +1238,7 @@ class TestDiscordVoiceChannelMethods:
callback = AsyncMock()
adapter._voice_input_callback = callback
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
with patch("plugins.platforms.discord.adapter.VoiceReceiver.pcm_to_wav"), \
patch("tools.transcription_tools.transcribe_audio",
return_value={"success": False, "error": "API error"}):
await adapter._process_voice_input(111, 42, b"\x00" * 96000)
@ -1251,7 +1251,7 @@ class TestDiscordVoiceChannelMethods:
adapter = self._make_adapter()
adapter._voice_input_callback = AsyncMock()
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav",
with patch("plugins.platforms.discord.adapter.VoiceReceiver.pcm_to_wav",
side_effect=RuntimeError("ffmpeg not found")):
await adapter._process_voice_input(111, 42, b"\x00" * 96000)
# Should not raise
@ -1269,7 +1269,7 @@ class TestVoiceReceiverThreadSafety:
"""Verify that VoiceReceiver buffer access is protected by lock."""
def _make_receiver(self):
from gateway.platforms.discord import VoiceReceiver
from plugins.platforms.discord.adapter import VoiceReceiver
mock_vc = MagicMock()
mock_vc._connection.secret_key = [0] * 32
mock_vc._connection.dave_session = None
@ -1282,7 +1282,7 @@ class TestVoiceReceiverThreadSafety:
def test_check_silence_holds_lock(self):
"""check_silence must hold lock while iterating buffers."""
import ast, inspect, textwrap
from gateway.platforms.discord import VoiceReceiver
from plugins.platforms.discord.adapter import VoiceReceiver
source = textwrap.dedent(inspect.getsource(VoiceReceiver.check_silence))
tree = ast.parse(source)
# Find 'with self._lock:' that contains buffer iteration
@ -1303,7 +1303,7 @@ class TestVoiceReceiverThreadSafety:
def test_on_packet_buffer_write_holds_lock(self):
"""_on_packet must hold lock when writing to buffers."""
import ast, inspect, textwrap
from gateway.platforms.discord import VoiceReceiver
from plugins.platforms.discord.adapter import VoiceReceiver
source = textwrap.dedent(inspect.getsource(VoiceReceiver._on_packet))
tree = ast.parse(source)
# Find 'with self._lock:' that contains buffer extend
@ -1670,7 +1670,7 @@ class TestStopAcquiresLock:
@staticmethod
def _make_receiver():
from gateway.platforms.discord import VoiceReceiver
from plugins.platforms.discord.adapter import VoiceReceiver
vc = MagicMock()
vc._connection.secret_key = [0] * 32
vc._connection.dave_session = None
@ -1772,7 +1772,7 @@ class TestPacketDebugCounterIsInstanceLevel:
@staticmethod
def _make_receiver():
from gateway.platforms.discord import VoiceReceiver
from plugins.platforms.discord.adapter import VoiceReceiver
vc = MagicMock()
vc._connection.secret_key = [0] * 32
vc._connection.dave_session = None
@ -1805,7 +1805,7 @@ class TestPlayInVoiceChannelUsesRunningLoop:
def test_source_uses_get_running_loop(self):
"""The method source code calls get_running_loop, not get_event_loop."""
import inspect
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
source = inspect.getsource(DiscordAdapter.play_in_voice_channel)
assert "get_running_loop" in source, \
"play_in_voice_channel should use asyncio.get_running_loop()"
@ -1849,7 +1849,7 @@ class TestVoiceTimeoutCleansRunnerState:
@staticmethod
def _make_discord_adapter():
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
from gateway.config import PlatformConfig, Platform
config = PlatformConfig(enabled=True, extra={})
config.token = "fake-token"
@ -1940,7 +1940,7 @@ class TestPlaybackTimeout:
@staticmethod
def _make_discord_adapter():
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
from gateway.config import PlatformConfig, Platform
config = PlatformConfig(enabled=True, extra={})
config.token = "fake-token"
@ -1964,7 +1964,7 @@ class TestPlaybackTimeout:
def test_source_has_wait_for_timeout(self):
"""The method uses asyncio.wait_for with timeout."""
import inspect
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
source = inspect.getsource(DiscordAdapter.play_in_voice_channel)
assert "wait_for" in source, \
"play_in_voice_channel must use asyncio.wait_for for timeout"
@ -1973,14 +1973,14 @@ class TestPlaybackTimeout:
def test_playback_timeout_constant_exists(self):
"""PLAYBACK_TIMEOUT constant is defined on DiscordAdapter."""
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
assert hasattr(DiscordAdapter, "PLAYBACK_TIMEOUT")
assert DiscordAdapter.PLAYBACK_TIMEOUT > 0
@pytest.mark.asyncio
async def test_playback_timeout_fires(self):
"""When done event is never set, playback times out gracefully."""
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
adapter = self._make_discord_adapter()
mock_vc = MagicMock()
@ -2008,7 +2008,7 @@ class TestPlaybackTimeout:
@pytest.mark.asyncio
async def test_is_playing_wait_has_timeout(self):
"""While loop waiting for previous playback has a timeout."""
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
adapter = self._make_discord_adapter()
mock_vc = MagicMock()
@ -2124,7 +2124,7 @@ class TestVoiceChannelAwareness:
"""Tests for get_voice_channel_info() and get_voice_channel_context()."""
def _make_adapter(self):
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
from gateway.config import PlatformConfig
config = PlatformConfig(enabled=True, extra={})
config.token = "fake-token"
@ -2267,7 +2267,7 @@ class TestVoiceReception:
@staticmethod
def _make_receiver(allowed_ids=None, members=None, dave=False, bot_id=9999):
from gateway.platforms.discord import VoiceReceiver
from plugins.platforms.discord.adapter import VoiceReceiver
vc = MagicMock()
vc._connection.secret_key = [0] * 32
vc._connection.dave_session = MagicMock() if dave else None
@ -2451,7 +2451,7 @@ class TestVoiceReception:
def _make_receiver_with_nacl(self, dave_session=None, mapped_ssrcs=None):
"""Create a receiver that can process _on_packet with mocked NaCl + Opus."""
from gateway.platforms.discord import VoiceReceiver
from plugins.platforms.discord.adapter import VoiceReceiver
vc = MagicMock()
vc._connection.secret_key = [0] * 32
vc._connection.dave_session = dave_session
@ -2593,7 +2593,7 @@ class TestVoiceTTSPlayback:
@staticmethod
def _make_discord_adapter():
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
from gateway.config import PlatformConfig, Platform
config = PlatformConfig(enabled=True, extra={})
config.token = "fake-token"
@ -2766,14 +2766,14 @@ class TestUDPKeepalive:
"""UDP keepalive prevents Discord from dropping the voice session."""
def test_keepalive_interval_is_reasonable(self):
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
interval = DiscordAdapter._KEEPALIVE_INTERVAL
assert 5 <= interval <= 30, f"Keepalive interval {interval}s should be between 5-30s"
@pytest.mark.asyncio
async def test_keepalive_sends_silence_frame(self):
"""Listen loop sends silence frame via send_packet after interval."""
from gateway.platforms.discord import DiscordAdapter
from plugins.platforms.discord.adapter import DiscordAdapter
from gateway.config import PlatformConfig, Platform
config = PlatformConfig(enabled=True, extra={})
@ -2795,7 +2795,7 @@ class TestUDPKeepalive:
adapter._voice_clients[111] = mock_vc
mock_vc._connection = mock_conn
from gateway.platforms.discord import VoiceReceiver
from plugins.platforms.discord.adapter import VoiceReceiver
mock_receiver_vc = MagicMock()
mock_receiver_vc._connection.secret_key = [0] * 32
mock_receiver_vc._connection.dave_session = None

View file

@ -15,6 +15,7 @@ Covers:
"""
import asyncio
import base64
import hashlib
import hmac
import json
@ -100,6 +101,18 @@ def _generic_signature(body: bytes, secret: str) -> str:
return hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()
def _svix_signature(body: bytes, secret: str, msg_id: str, timestamp: str) -> str:
    """Build the ``v1,<base64>`` signature value for *body*.

    The signed payload is ``<msg_id>.<timestamp>.<body>``, authenticated with
    HMAC-SHA256.  A secret starting with ``whsec_`` has that prefix removed
    and the remainder base64-decoded to obtain the key; any other secret is
    used as its encoded bytes.
    """
    prefix = "whsec_"
    if secret.startswith(prefix):
        signing_key = base64.b64decode(secret[len(prefix):])
    else:
        signing_key = secret.encode()

    payload = b".".join((msg_id.encode(), timestamp.encode(), body))
    mac = hmac.new(signing_key, payload, hashlib.sha256)
    return "v1," + base64.b64encode(mac.digest()).decode()
# ===================================================================
# Signature validation
# ===================================================================
@ -170,6 +183,134 @@ class TestValidateSignature:
req = _mock_request(headers={"X-Webhook-Signature": sig})
assert adapter._validate_signature(req, body, secret) is True
def test_validate_svix_signature_valid(self):
    """A correctly signed Svix/AgentMail v1 request passes validation."""
    adapter = _make_adapter()
    payload = b'{"event_type":"message.received"}'
    signing_secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode()
    delivery_id = "msg_123"
    sent_at = str(int(time.time()))
    svix_headers = {
        "svix-id": delivery_id,
        "svix-timestamp": sent_at,
        "svix-signature": _svix_signature(payload, signing_secret, delivery_id, sent_at),
    }
    request = _mock_request(headers=svix_headers)

    assert adapter._validate_signature(request, payload, signing_secret) is True
def test_validate_svix_signature_wrong_body_rejects(self):
    """A signature computed over one body must not validate a different body."""
    adapter = _make_adapter()
    body_that_was_signed = b'{"event_type":"message.received"}'
    body_that_arrived = b'{"event_type":"message.sent"}'
    signing_secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode()
    delivery_id = "msg_123"
    sent_at = str(int(time.time()))
    signature = _svix_signature(body_that_was_signed, signing_secret, delivery_id, sent_at)
    svix_headers = {
        "svix-id": delivery_id,
        "svix-timestamp": sent_at,
        "svix-signature": signature,
    }
    request = _mock_request(headers=svix_headers)

    assert adapter._validate_signature(request, body_that_arrived, signing_secret) is False
def test_validate_svix_signature_old_timestamp_rejects(self):
    """A validly signed request stamped 301 seconds in the past is rejected."""
    adapter = _make_adapter()
    payload = b'{"event_type":"message.received"}'
    signing_secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode()
    delivery_id = "msg_123"
    # The signature itself is valid; only the timestamp is stale.
    stale_at = str(int(time.time()) - 301)
    svix_headers = {
        "svix-id": delivery_id,
        "svix-timestamp": stale_at,
        "svix-signature": _svix_signature(payload, signing_secret, delivery_id, stale_at),
    }
    request = _mock_request(headers=svix_headers)

    assert adapter._validate_signature(request, payload, signing_secret) is False
def test_validate_svix_signature_multiple_entries_accepts_matching_v1(self):
    """One matching v1 entry among space-separated signatures is enough to accept."""
    adapter = _make_adapter()
    payload = b'{"event_type":"message.received"}'
    signing_secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode()
    delivery_id = "msg_123"
    sent_at = str(int(time.time()))
    good_entry = _svix_signature(payload, signing_secret, delivery_id, sent_at)
    # A bogus entry comes first, as it might during secret rotation.
    rotated_header = "v1,wrong " + good_entry
    svix_headers = {
        "svix-id": delivery_id,
        "svix-timestamp": sent_at,
        "svix-signature": rotated_header,
    }
    request = _mock_request(headers=svix_headers)

    assert adapter._validate_signature(request, payload, signing_secret) is True
def test_validate_svix_signature_missing_signature_rejects(self):
    """A request carrying only ``svix-id`` is rejected rather than tried under another scheme."""
    adapter = _make_adapter()
    incomplete_headers = {"svix-id": "msg_123"}
    request = _mock_request(headers=incomplete_headers)

    assert adapter._validate_signature(request, b"{}", "secret") is False
def test_validate_svix_signature_unsupported_version_rejects(self):
    """A correct digest labelled ``v2`` instead of ``v1`` is rejected."""
    adapter = _make_adapter()
    payload = b'{"event_type":"message.received"}'
    signing_secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode()
    delivery_id = "msg_123"
    sent_at = str(int(time.time()))
    v1_entry = _svix_signature(payload, signing_secret, delivery_id, sent_at)
    # Same digest, different version tag.
    relabelled = v1_entry.replace("v1,", "v2,")
    svix_headers = {
        "svix-id": delivery_id,
        "svix-timestamp": sent_at,
        "svix-signature": relabelled,
    }
    request = _mock_request(headers=svix_headers)

    assert adapter._validate_signature(request, payload, signing_secret) is False
def test_validate_svix_signature_invalid_whsec_rejects(self):
    """A ``whsec_`` secret that is not valid base64 must not fall back to raw-secret use."""
    adapter = _make_adapter()
    payload = b'{"event_type":"message.received"}'
    broken_secret = "whsec_not-valid-base64!"
    delivery_id = "msg_123"
    sent_at = str(int(time.time()))
    # Sign with the suffix treated as a raw key: the signature an adapter
    # would accept only if it silently ignored the malformed encoding.
    suffix_as_raw_key = broken_secret.removeprefix("whsec_")
    fallback_signature = _svix_signature(payload, suffix_as_raw_key, delivery_id, sent_at)
    svix_headers = {
        "svix-id": delivery_id,
        "svix-timestamp": sent_at,
        "svix-signature": fallback_signature,
    }
    request = _mock_request(headers=svix_headers)

    assert adapter._validate_signature(request, payload, broken_secret) is False
def test_validate_svix_signature_raw_secret_valid(self):
    """A plain shared secret (no ``whsec_`` prefix) validates a Svix-style signature."""
    adapter = _make_adapter()
    payload = b'{"event_type":"message.received"}'
    shared_secret = "raw-agentmail-secret"
    delivery_id = "msg_123"
    sent_at = str(int(time.time()))
    svix_headers = {
        "svix-id": delivery_id,
        "svix-timestamp": sent_at,
        "svix-signature": _svix_signature(payload, shared_secret, delivery_id, sent_at),
    }
    request = _mock_request(headers=svix_headers)

    assert adapter._validate_signature(request, payload, shared_secret) is True
# ===================================================================
# Prompt rendering
@ -304,6 +445,27 @@ class TestEventFilter:
)
assert resp.status == 202
@pytest.mark.asyncio
async def test_event_filter_accepts_payload_type_field(self):
    """An event named by a top-level ``type`` payload field passes the route's filter."""
    svix_route = {
        "secret": _INSECURE_NO_AUTH,
        "events": ["message.received"],
        "prompt": "got it",
    }
    adapter = _make_adapter(routes={"svix": svix_route})
    adapter.handle_message = AsyncMock()
    app = _create_app(adapter)

    async with TestClient(TestServer(app)) as client:
        response = await client.post(
            "/webhooks/svix",
            json={"type": "message.received"},
        )
        assert response.status == 202
# ===================================================================
# HTTP handling
@ -336,6 +498,22 @@ class TestHTTPHandling:
assert data["status"] == "accepted"
assert data["route"] == "test"
@pytest.mark.asyncio
async def test_route_without_secret_rejects_unsigned_request(self):
    """A route with no secret at all answers 403 and never reaches the message handler."""
    secretless_routes = {"test": {"prompt": "hi"}}
    adapter = _make_adapter(routes=secretless_routes, secret="")
    adapter.handle_message = AsyncMock()
    app = _create_app(adapter)

    async with TestClient(TestServer(app)) as client:
        response = await client.post("/webhooks/test", json={"data": "value"})
        assert response.status == 403
        payload = await response.json()
        assert payload["error"] == "Webhook route is missing an HMAC secret"
        adapter.handle_message.assert_not_called()
@pytest.mark.asyncio
async def test_health_endpoint(self):
"""GET /health returns 200 with status=ok."""
@ -432,6 +610,25 @@ class TestIdempotency:
resp2 = await cli.post("/webhooks/idem", json={"x": 1}, headers=headers)
assert resp2.status == 202 # re-accepted
@pytest.mark.asyncio
async def test_svix_id_used_as_delivery_id_for_deduplication(self):
    """A second POST with the same ``svix-id`` is reported as a duplicate of the first."""
    adapter = _make_adapter(
        routes={"idem": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}}
    )
    adapter.handle_message = AsyncMock()
    app = _create_app(adapter)

    async with TestClient(TestServer(app)) as client:
        svix_headers = {"svix-id": "msg_duplicate"}

        first = await client.post("/webhooks/idem", json={"a": 1}, headers=svix_headers)
        assert first.status == 202

        second = await client.post("/webhooks/idem", json={"a": 1}, headers=svix_headers)
        assert second.status == 200
        payload = await second.json()
        assert payload["status"] == "duplicate"
        assert payload["delivery_id"] == "msg_duplicate"
# ===================================================================
# Rate limiting

View file

@ -6,7 +6,11 @@ import pytest
from pathlib import Path
from gateway.config import PlatformConfig
from gateway.platforms.webhook import WebhookAdapter, _DYNAMIC_ROUTES_FILENAME
from gateway.platforms.webhook import (
WebhookAdapter,
_DYNAMIC_ROUTES_FILENAME,
_INSECURE_NO_AUTH,
)
def _make_adapter(routes=None, extra=None):
@ -85,3 +89,88 @@ class TestDynamicRouteLoading:
adapter._reload_dynamic_routes()
assert "static" in adapter._routes
assert len(adapter._dynamic_routes) == 0
class TestDynamicRouteSecretValidation:
    """Hot-reloaded routes without a usable secret must be dropped.

    Regression coverage for an HMAC bypass: before the fix, a dynamic route
    with `"secret": ""` was merged into self._routes by
    _reload_dynamic_routes(), and _handle_webhook's
    `if secret and secret != _INSECURE_NO_AUTH` check then skipped signature
    validation because the empty string is falsy, so unauthenticated POSTs
    could execute the webhook prompt.
    """

    @staticmethod
    def _write_dynamic_routes(directory, routes):
        """Serialize *routes* as JSON into the dynamic-routes file under *directory*."""
        (directory / _DYNAMIC_ROUTES_FILENAME).write_text(json.dumps(routes))

    def test_empty_secret_rejected(self, tmp_path):
        # An explicit empty-string secret must NOT inherit the global
        # secret; the whole route has to be skipped.
        self._write_dynamic_routes(tmp_path, {"evil": {"secret": "", "prompt": "rm -rf"}})
        adapter = _make_adapter()  # global secret present
        adapter._reload_dynamic_routes()

        for registry in (adapter._routes, adapter._dynamic_routes):
            assert "evil" not in registry

    def test_missing_secret_no_global_rejected(self, tmp_path):
        self._write_dynamic_routes(tmp_path, {"orphan": {"prompt": "test"}})
        # Neither a per-route nor a global secret is configured.
        adapter = _make_adapter(extra={"secret": ""})
        adapter._reload_dynamic_routes()

        for registry in (adapter._routes, adapter._dynamic_routes):
            assert "orphan" not in registry

    def test_missing_secret_inherits_global(self, tmp_path):
        # With no per-route secret but a global one set, the route is kept
        # and protected by the global secret (existing fallback behaviour).
        self._write_dynamic_routes(tmp_path, {"valid": {"prompt": "ok"}})
        adapter = _make_adapter()  # global secret present
        adapter._reload_dynamic_routes()

        assert "valid" in adapter._routes

    def test_insecure_no_auth_preserved(self, tmp_path):
        # The explicit opt-in escape hatch for local testing must still load.
        self._write_dynamic_routes(
            tmp_path, {"test": {"secret": _INSECURE_NO_AUTH, "prompt": "p"}}
        )
        adapter = _make_adapter(extra={"host": "127.0.0.1"})
        adapter._reload_dynamic_routes()

        assert "test" in adapter._routes

    def test_insecure_no_auth_rejected_on_non_loopback_bind(self, tmp_path):
        # Dynamic INSECURE_NO_AUTH routes are only acceptable on loopback hosts.
        self._write_dynamic_routes(
            tmp_path, {"pub": {"secret": _INSECURE_NO_AUTH, "prompt": "p"}}
        )
        adapter = _make_adapter(extra={"host": "0.0.0.0"})
        adapter._reload_dynamic_routes()

        for registry in (adapter._routes, adapter._dynamic_routes):
            assert "pub" not in registry

    def test_warning_logged_on_skip(self, tmp_path, caplog):
        import logging

        self._write_dynamic_routes(tmp_path, {"silent": {"secret": "", "prompt": "x"}})
        adapter = _make_adapter()
        with caplog.at_level(logging.WARNING, logger="gateway.platforms.webhook"):
            adapter._reload_dynamic_routes()

        logged_messages = [record.message for record in caplog.records]
        assert any("silent" in message for message in logged_messages)

    def test_partial_skip(self, tmp_path):
        # One bad route and one good route: only the bad one is dropped.
        mixed_routes = {
            "bad": {"secret": "", "prompt": "x"},
            "good": {"secret": "valid-secret", "prompt": "y"},
        }
        self._write_dynamic_routes(tmp_path, mixed_routes)
        adapter = _make_adapter()
        adapter._reload_dynamic_routes()

        assert "good" in adapter._routes
        assert "bad" not in adapter._routes

View file

@ -1,5 +1,6 @@
"""Tests for the WeCom platform adapter."""
import asyncio
import base64
import os
from pathlib import Path
@ -831,3 +832,91 @@ class TestWeComZombieSessionFix:
cmd = adapter._send_request.await_args.args[0]
assert cmd == APP_CMD_SEND
class TestTextBatchFlushRace:
    """Regression tests for the cancel-delivery race in _flush_text_batch.

    If asyncio.sleep() has already fired when Task.cancel() is called, and
    the task has not run yet, CPython sets _must_cancel but cannot cancel
    the finished sleep future.  CancelledError is then delivered at the
    *next* await (handle_message), after the task has popped the event, so
    the superseding task sees an empty batch and the message is silently
    dropped.

    The fix adds a synchronous task-registry check between the sleep and the
    pop so a superseded task returns before touching the event.
    """

    @staticmethod
    def _adapter_with_pending_event(session_key, text):
        """Return ``(adapter, event, delivered)`` with *event* queued under *session_key*.

        The adapter's batch delay is set to zero and its ``handle_message``
        is replaced by a coroutine that appends each event to ``delivered``.
        """
        from gateway.platforms.base import MessageEvent, MessageType
        from gateway.platforms.wecom import WeComAdapter

        adapter = WeComAdapter(PlatformConfig(enabled=True))
        adapter._text_batch_delay_seconds = 0

        queued = MessageEvent(text=text, message_type=MessageType.TEXT)
        adapter._pending_text_batches[session_key] = queued

        delivered = []

        async def record(evt):
            delivered.append(evt)

        adapter.handle_message = record
        return adapter, queued, delivered

    @pytest.mark.asyncio
    async def test_superseded_task_does_not_pop_or_process_event(self):
        """A superseded flush task must leave the event in the batch dict
        for the newer task to handle."""
        session_key = "test-session"
        adapter, queued, delivered = self._adapter_with_pending_event(session_key, "hello")

        # Start the first flush task and register it.
        first_flush = asyncio.create_task(adapter._flush_text_batch(session_key))
        adapter._pending_text_batch_tasks[session_key] = first_flush

        # Replace the registry entry before the first task wakes from sleep.
        superseder = asyncio.create_task(asyncio.sleep(9999))
        adapter._pending_text_batch_tasks[session_key] = superseder

        # Give the first task time to finish its zero-length sleep and run.
        await asyncio.sleep(0.05)

        superseder.cancel()
        try:
            await superseder
        except asyncio.CancelledError:
            pass

        # The first task must have returned without handling or removing the event.
        assert delivered == [], "superseded task must not call handle_message"
        assert adapter._pending_text_batches.get(session_key) is queued, (
            "superseded task must not pop the event"
        )

    @pytest.mark.asyncio
    async def test_active_task_processes_event_normally(self):
        """A flush task that is still the registered one must process the event."""
        session_key = "test-session"
        adapter, queued, delivered = self._adapter_with_pending_event(session_key, "world")

        only_flush = asyncio.create_task(adapter._flush_text_batch(session_key))
        adapter._pending_text_batch_tasks[session_key] = only_flush

        # Nothing supersedes the task, so it should run to completion.
        await asyncio.sleep(0.05)

        assert delivered == [queued], "active task must call handle_message"
        assert adapter._pending_text_batches.get(session_key) is None, (
            "active task must pop the event after processing"
        )

View file

@ -153,6 +153,130 @@ class TestWecomCallbackRouting:
assert calls["json"]["agentid"] == 1001
class TestWecomCallbackSendTokenRefresh:
    """``send`` must refresh a rejected access token and retry exactly once."""

    class _LazyReply:
        """HTTP response stand-in whose payload is produced when ``json()`` is called."""

        def __init__(self, payload_factory):
            self._payload_factory = payload_factory

        def json(self):
            return self._payload_factory()

    @pytest.mark.asyncio
    async def test_send_retries_with_fresh_token_on_errcode_40001(self):
        """errcode=40001 must evict the cached token, refresh, and retry once."""
        adapter = WecomCallbackAdapter(_config())
        adapter._access_tokens["test-app"] = {"token": "stale", "expires_at": 9999999999}
        adapter._user_app_map["ww1234567890:alice"] = "test-app"

        send_replies = [
            {"errcode": 40001, "errmsg": "invalid credential"},
            {"errcode": 0, "msgid": "msg-ok"},
        ]
        posted_urls = []
        reply = self._LazyReply

        class StubHttpClient:
            async def post(self, url, json=None, **kw):
                posted_urls.append(url)
                # Looked up at json() time, against the number of posts so far.
                return reply(lambda: send_replies[len(posted_urls) - 1])

            async def get(self, url, params=None, **kw):
                return reply(
                    lambda: {"errcode": 0, "access_token": "fresh", "expires_in": 7200}
                )

        adapter._http_client = StubHttpClient()
        outcome = await adapter.send("ww1234567890:alice", "hello")

        assert outcome.success is True
        assert outcome.message_id == "msg-ok"
        assert len(posted_urls) == 2
        assert "fresh" in posted_urls[1]
        assert adapter._access_tokens["test-app"]["token"] == "fresh"

    @pytest.mark.asyncio
    async def test_send_retries_with_fresh_token_on_errcode_42001(self):
        """errcode=42001 (token expired) must also trigger the refresh-retry path."""
        adapter = WecomCallbackAdapter(_config())
        adapter._access_tokens["test-app"] = {"token": "expired", "expires_at": 9999999999}

        send_replies = [
            {"errcode": 42001, "errmsg": "access_token expired"},
            {"errcode": 0, "msgid": "msg-42"},
        ]
        posted_urls = []
        reply = self._LazyReply

        class StubHttpClient:
            async def post(self, url, json=None, **kw):
                posted_urls.append(url)
                return reply(lambda: send_replies[len(posted_urls) - 1])

            async def get(self, url, params=None, **kw):
                return reply(
                    lambda: {"errcode": 0, "access_token": "renewed", "expires_in": 7200}
                )

        adapter._http_client = StubHttpClient()
        outcome = await adapter.send("alice", "hello")

        assert outcome.success is True
        assert len(posted_urls) == 2

    @pytest.mark.asyncio
    async def test_send_does_not_retry_on_non_token_errcode(self):
        """Errors unrelated to token validity must fail immediately without retrying."""
        adapter = WecomCallbackAdapter(_config())
        adapter._access_tokens["test-app"] = {"token": "good", "expires_at": 9999999999}

        posted_urls = []
        reply = self._LazyReply

        # This stub has no ``get`` method: only ``post`` is provided.
        class StubHttpClient:
            async def post(self, url, json=None, **kw):
                posted_urls.append(url)
                return reply(lambda: {"errcode": 60020, "errmsg": "not allow to access"})

        adapter._http_client = StubHttpClient()
        outcome = await adapter.send("alice", "hello")

        assert outcome.success is False
        assert len(posted_urls) == 1

    @pytest.mark.asyncio
    async def test_send_fails_cleanly_when_retry_also_fails(self):
        """If the refreshed token is also rejected, return failure without looping further."""
        adapter = WecomCallbackAdapter(_config())
        adapter._access_tokens["test-app"] = {"token": "bad1", "expires_at": 9999999999}

        posted_urls = []
        reply = self._LazyReply

        class StubHttpClient:
            async def post(self, url, json=None, **kw):
                posted_urls.append(url)
                return reply(lambda: {"errcode": 42001, "errmsg": "access_token expired"})

            async def get(self, url, params=None, **kw):
                return reply(
                    lambda: {"errcode": 0, "access_token": "bad2", "expires_in": 7200}
                )

        adapter._http_client = StubHttpClient()
        outcome = await adapter.send("alice", "hello")

        assert outcome.success is False
        assert len(posted_urls) == 2
class TestWecomCallbackPollLoop:
@pytest.mark.asyncio
async def test_poll_loop_dispatches_handle_message(self, monkeypatch):

View file

@ -57,6 +57,59 @@ def _build_parser():
return parser
class TestChatVerboseArg:
    """``chat --verbose`` must stay unset when omitted so the config fallback applies."""

    def test_chat_without_verbose_leaves_attribute_unset(self):
        from hermes_cli._parser import build_top_level_parser

        top_parser, _subparsers, _chat_parser = build_top_level_parser()
        namespace = top_parser.parse_args(["chat"])

        assert not hasattr(namespace, "verbose")

    def test_chat_verbose_sets_attribute_true(self):
        from hermes_cli._parser import build_top_level_parser

        top_parser, _subparsers, _chat_parser = build_top_level_parser()
        namespace = top_parser.parse_args(["chat", "--verbose"])

        assert namespace.verbose is True

    def test_cmd_chat_forwards_none_when_verbose_is_absent(self, monkeypatch):
        import types
        import sys

        import hermes_cli.main as main_mod
        from hermes_cli._parser import build_top_level_parser

        top_parser, _subparsers, chat_parser = build_top_level_parser()
        chat_parser.set_defaults(func=main_mod.cmd_chat)
        namespace = top_parser.parse_args(["chat"])

        forwarded = {}

        def record_main(**kwargs):
            forwarded.update(kwargs)

        # Stand-in modules registered in sys.modules under the names
        # "cli", "hermes_cli.banner" and "tools.skills_sync".
        cli_stub = types.ModuleType("cli")
        cli_stub.main = record_main
        banner_stub = types.ModuleType("hermes_cli.banner")
        banner_stub.prefetch_update_check = lambda: None
        skills_stub = types.ModuleType("tools.skills_sync")
        skills_stub.sync_skills = lambda quiet=True: None
        for stub in (cli_stub, banner_stub, skills_stub):
            monkeypatch.setitem(sys.modules, stub.__name__, stub)

        monkeypatch.setattr(main_mod, "_has_any_provider_configured", lambda: True)
        monkeypatch.setattr(main_mod, "_pin_kanban_board_env", lambda: None)

        main_mod.cmd_chat(namespace)

        assert forwarded["quiet"] is False
        assert "verbose" not in forwarded
class TestYoloEnvVar:
"""Verify --yolo sets HERMES_YOLO_MODE regardless of flag position.

View file

@ -392,8 +392,84 @@ def test_get_qwen_auth_status_logged_in(qwen_env):
assert status["api_key"] == "status-at"
def test_get_qwen_auth_status_refreshes_expired_token(qwen_env):
    """An expired stored token is refreshed and the new access token is reported."""
    one_hour_ago_ms = int((time.time() - 3600) * 1000)
    stale = _make_qwen_tokens(access_token="old-at", expiry_date=one_hour_ago_ms)
    _write_qwen_creds(qwen_env, stale)

    renewed = _make_qwen_tokens(access_token="refreshed-at")
    refresh_patch = patch("hermes_cli.auth._refresh_qwen_cli_tokens", return_value=renewed)
    with refresh_patch as refresh_mock:
        status = get_qwen_auth_status()

    refresh_mock.assert_called_once()
    assert status["logged_in"] is True
    assert status["api_key"] == "refreshed-at"
def test_get_qwen_auth_status_expired_unrefreshable_token_is_not_logged_in(qwen_env):
    """When refreshing an expired token fails, status is logged-out and carries the error."""
    one_hour_ago_ms = int((time.time() - 3600) * 1000)
    dead = _make_qwen_tokens(access_token="dead-at", expiry_date=one_hour_ago_ms)
    _write_qwen_creds(qwen_env, dead)

    refresh_failure = AuthError(
        "Qwen refresh rejected. Re-run 'qwen auth qwen-oauth'.",
        provider="qwen-oauth",
        code="qwen_refresh_failed",
    )
    refresh_patch = patch(
        "hermes_cli.auth._refresh_qwen_cli_tokens", side_effect=refresh_failure
    )
    with refresh_patch as refresh_mock:
        status = get_qwen_auth_status()

    refresh_mock.assert_called_once()
    assert status["logged_in"] is False
    assert "qwen auth qwen-oauth" in status["error"]
def test_get_qwen_auth_status_not_logged_in(qwen_env):
    """With no credentials file written, status is logged-out and includes an error."""
    # Nothing is written to the credentials location in this test.
    report = get_qwen_auth_status()

    assert report["logged_in"] is False
    assert "error" in report
def test_model_flow_qwen_oauth_stale_token_shows_reauth_guidance(qwen_env, monkeypatch, capsys):
    """A failed refresh prints re-auth guidance and skips model selection and config update."""
    from hermes_cli.main import _model_flow_qwen_oauth

    one_hour_ago_ms = int((time.time() - 3600) * 1000)
    dead = _make_qwen_tokens(access_token="dead-at", expiry_date=one_hour_ago_ms)
    _write_qwen_creds(qwen_env, dead)

    def reject_refresh(*args, **kwargs):
        raise AuthError(
            "Qwen refresh rejected. Re-run 'qwen auth qwen-oauth'.",
            provider="qwen-oauth",
            code="qwen_refresh_failed",
        )

    monkeypatch.setattr("hermes_cli.auth._refresh_qwen_cli_tokens", reject_refresh)

    # Record which downstream steps get invoked; neither should be.
    prompt_called = {"value": False}
    update_called = {"value": False}

    def note_prompt(*args, **kwargs):
        prompt_called["value"] = True

    def note_update(*args, **kwargs):
        update_called["value"] = True

    monkeypatch.setattr("hermes_cli.auth._prompt_model_selection", note_prompt)
    monkeypatch.setattr("hermes_cli.auth._update_config_for_provider", note_update)

    _model_flow_qwen_oauth({}, current_model="qwen3-coder-plus")

    printed = capsys.readouterr().out
    assert "Run: qwen auth qwen-oauth" in printed
    assert "Qwen refresh rejected" in printed
    assert prompt_called["value"] is False
    assert update_called["value"] is False

Some files were not shown because too many files have changed in this diff Show more